How can I find all subclasses of a class given its name?

Posted by: admin, November 1, 2017

Questions:

I need a working approach for getting all classes that inherit from a given base class in Python.

Answers:

New-style classes (i.e. those subclassed from object, which is the default in Python 3) have a __subclasses__ method that returns their immediate subclasses:

class Foo(object): pass
class Bar(Foo): pass
class Baz(Foo): pass
class Bing(Bar): pass

Here are the names of the subclasses:

print([cls.__name__ for cls in vars()['Foo'].__subclasses__()])
# ['Bar', 'Baz']

Here are the subclasses themselves:

print(vars()['Foo'].__subclasses__())
# [<class '__main__.Bar'>, <class '__main__.Baz'>]

Confirmation that the subclasses do indeed list Foo as their base:

for cls in vars()['Foo'].__subclasses__():
    print(cls.__base__)
# <class '__main__.Foo'>
# <class '__main__.Foo'>

Note that if you want sub-subclasses, you'll have to recurse:

def all_subclasses(cls):
    return cls.__subclasses__() + [g for s in cls.__subclasses__()
                                   for g in all_subclasses(s)]

print(all_subclasses(vars()['Foo']))
# [<class '__main__.Bar'>, <class '__main__.Baz'>, <class '__main__.Bing'>]
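With multiple inheritance, the same class can be reached along more than one path, so the list-based recursion above can return duplicates. A minimal sketch that collects into a set instead; all_subclasses_unique and the Both class are hypothetical names used only for illustration:

```python
class Foo: pass
class Bar(Foo): pass
class Baz(Foo): pass
class Both(Bar, Baz): pass  # reachable from Foo via Bar and via Baz

def all_subclasses(cls):
    # the recursive list version from the answer above
    return cls.__subclasses__() + [g for s in cls.__subclasses__()
                                   for g in all_subclasses(s)]

def all_subclasses_unique(cls):
    """Return a set of every direct and indirect subclass of cls."""
    result = set()
    for sub in cls.__subclasses__():
        result.add(sub)
        # a shared subclass may be visited twice, but the set keeps one copy
        result |= all_subclasses_unique(sub)
    return result

print([c.__name__ for c in all_subclasses(Foo)])
# ['Bar', 'Baz', 'Both', 'Both']
print(sorted(c.__name__ for c in all_subclasses_unique(Foo)))
# ['Bar', 'Baz', 'Both']
```

The trade-off is that a set loses the ordering the list version provides; sort the result if you need a stable order.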

Answers:

If you just want direct subclasses, then .__subclasses__() works fine. If you want all subclasses, subclasses of subclasses, and so on, you'll need a function to do that for you.

Here’s a simple, readable function that recursively finds all subclasses of a given class:

def get_all_subclasses(cls):
    all_subclasses = []

    for subclass in cls.__subclasses__():
        all_subclasses.append(subclass)
        all_subclasses.extend(get_all_subclasses(subclass))

    return all_subclasses
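A quick check of the order get_all_subclasses produces, assuming the Foo/Bar/Baz/Bing hierarchy from the first answer: each subclass is followed immediately by its own descendants (a depth-first, pre-order walk).

```python
class Foo: pass
class Bar(Foo): pass
class Baz(Foo): pass
class Bing(Bar): pass

def get_all_subclasses(cls):
    all_subclasses = []

    for subclass in cls.__subclasses__():
        all_subclasses.append(subclass)                       # the subclass itself first...
        all_subclasses.extend(get_all_subclasses(subclass))   # ...then its descendants

    return all_subclasses

print([c.__name__ for c in get_all_subclasses(Foo)])
# ['Bar', 'Bing', 'Baz']
```

By contrast, unutbu's version lists all direct subclasses first and gives ['Bar', 'Baz', 'Bing'] for the same hierarchy.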

Answers:

The simplest solution in general form:

def get_subclasses(cls):
    for subclass in cls.__subclasses__():
        yield from get_subclasses(subclass)
        yield subclass

And a classmethod version, for when all your classes inherit from a single base class:

@classmethod
def get_subclasses(cls):
    for subclass in cls.__subclasses__():
        yield from subclass.get_subclasses()
        yield subclass
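The classmethod variant only works once it is defined on the common base class, so every subclass inherits it. A sketch of that placement using a hypothetical Animal hierarchy; because yield from comes before yield subclass, descendants are produced before their parent (post-order):

```python
class Animal:
    @classmethod
    def get_subclasses(cls):
        for subclass in cls.__subclasses__():
            # every subclass inherits this classmethod, so recurse through it
            yield from subclass.get_subclasses()
            yield subclass

class Mammal(Animal): pass
class Dog(Mammal): pass
class Bird(Animal): pass

print([c.__name__ for c in Animal.get_subclasses()])
# ['Dog', 'Mammal', 'Bird']
print([c.__name__ for c in Mammal.get_subclasses()])
# ['Dog']
```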

Answers:

Python 3.6: __init_subclass__

As other answers mention, you can call __subclasses__() to get the list of direct subclasses. Since Python 3.6 you can also hook into subclass creation by overriding the __init_subclass__ method, and use that hook to maintain your own registry:

class PluginBase:
    subclasses = []

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        cls.subclasses.append(cls)

class Plugin1(PluginBase):
    pass

class Plugin2(PluginBase):
    pass

This way, if you know what you're doing, you can control which subclasses get recorded, adding or omitting entries, which the built-in __subclasses__() does not let you do.
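A usage sketch of the PluginBase registry. Two things it illustrates, both assumptions beyond the original snippet: the subclasses list is defined once on PluginBase and shared through inheritance, so indirect subclasses are recorded too; and a hypothetical register class keyword shows one way to omit a class (class keywords are passed on to __init_subclass__):

```python
class PluginBase:
    subclasses = []  # one list, shared by PluginBase and all its subclasses

    def __init_subclass__(cls, register=True, **kwargs):
        # 'register' is a hypothetical class keyword used here to opt out
        super().__init_subclass__(**kwargs)
        if register:
            cls.subclasses.append(cls)  # resolves to PluginBase.subclasses

class Plugin1(PluginBase):
    pass

class Plugin2(PluginBase):
    pass

class Plugin1a(Plugin1):  # indirect subclass: also recorded
    pass

class Hidden(PluginBase, register=False):  # deliberately left out
    pass

print([c.__name__ for c in PluginBase.subclasses])
# ['Plugin1', 'Plugin2', 'Plugin1a']
print([c.__name__ for c in PluginBase.__subclasses__()])
# ['Plugin1', 'Plugin2', 'Hidden']
```

The built-in __subclasses__() still reports Hidden (and only direct subclasses), which is the difference the custom registry buys you.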

Answers:

FWIW, here's what I meant about @unutbu's answer only working with locally defined classes, and how using eval() instead of vars() would make it work with any accessible class, not only those defined in the current scope.

For those who dislike using eval(), a way to avoid it is also shown.

First here’s a concrete example demonstrating the potential problem with using vars():

class Foo(object): pass
class Bar(Foo): pass
class Baz(Foo): pass
class Bing(Bar): pass

# unutbu's approach
def all_subclasses(cls):
    return cls.__subclasses__() + [g for s in cls.__subclasses__()
                                   for g in all_subclasses(s)]

print(all_subclasses(vars()['Foo']))  # Fine because Foo is in scope
# -> [<class '__main__.Bar'>, <class '__main__.Baz'>, <class '__main__.Bing'>]

def func():  # won't work because Foo class is not locally defined
    print(all_subclasses(vars()['Foo']))

try:
    func()  # not OK because Foo is not local to func()
except Exception as e:
    print('calling func() raised exception: {!r}'.format(e))
    # -> calling func() raised exception: KeyError('Foo',)

print(all_subclasses(eval('Foo')))  # OK
# -> [<class '__main__.Bar'>, <class '__main__.Baz'>, <class '__main__.Bing'>]

# using eval('xxx') instead of vars()['xxx']
def func2():
    print(all_subclasses(eval('Foo')))

func2()  # Works
# -> [<class '__main__.Bar'>, <class '__main__.Baz'>, <class '__main__.Bing'>]

This can be improved by moving the eval('ClassName') call into the function itself, which makes it easier to use without losing the generality gained by using eval(): unlike vars(), which inside a function sees only that function's local names, eval() also looks up the module's global names.

# easier to use version
def all_subclasses2(classname):
    direct_subclasses = eval(classname).__subclasses__()
    return direct_subclasses + [g for s in direct_subclasses
                                    for g in all_subclasses2(s.__name__)]

# pass 'xxx' instead of eval('xxx')
def func_ez():
    print(all_subclasses2('Foo'))  # simpler

func_ez()
# -> [<class '__main__.Bar'>, <class '__main__.Baz'>, <class '__main__.Bing'>]

Lastly, it's possible (and in some cases important, for security reasons) to avoid eval() altogether, so here's a version without it:

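def get_all_subclasses(cls):
    """ Generate all of a class's subclasses, recursively. """
    try:
        for subclass in cls.__subclasses__():
            yield subclass
            # use a separate name so the outer loop variable isn't rebound
            for subsubclass in get_all_subclasses(subclass):
                yield subsubclass
    except TypeError:
        # type (and any metaclass) needs an argument for __subclasses__(),
        # so calling it bare raises TypeError; skip that branch
        return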

def all_subclasses3(classname):
    for cls in get_all_subclasses(object):
        if cls.__name__ == classname:
            break
    else:
        raise ValueError('class %s not found' % classname)
    direct_subclasses = cls.__subclasses__()
    return direct_subclasses + [g for s in direct_subclasses
                                    for g in all_subclasses3(s.__name__)]

# no eval('xxx')
def func3():
    print(all_subclasses3('Foo'))

func3()  # Also works
# -> [<class '__main__.Bar'>, <class '__main__.Baz'>, <class '__main__.Bing'>]
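Because all_subclasses3 re-resolves every subclass by its bare __name__, two unrelated classes that share a name (for example in different modules) can be confused, and the whole class tree is rescanned on each recursive call. A sketch that looks the name up once and then walks class objects directly; find_class, all_subclasses4 and the Vehicle hierarchy are hypothetical names for illustration:

```python
def get_all_subclasses(cls):
    """Generate all of a class's subclasses, recursively (pre-order)."""
    try:
        for subclass in cls.__subclasses__():
            yield subclass
            yield from get_all_subclasses(subclass)
    except TypeError:
        # type (and any metaclass) needs an argument for __subclasses__()
        return

def find_class(classname):
    """Return the first class named classname reachable from object."""
    for cls in get_all_subclasses(object):
        if cls.__name__ == classname:
            return cls
    raise ValueError('class %s not found' % classname)

def all_subclasses4(classname):
    root = find_class(classname)           # a single name lookup, no eval()
    return list(get_all_subclasses(root))  # then recurse on class objects

class Vehicle: pass
class Car(Vehicle): pass
class Truck(Vehicle): pass
class SportsCar(Car): pass

print([c.__name__ for c in all_subclasses4('Vehicle')])
```

The name lookup is still ambiguous if several classes share the root's name; only the recursion stops depending on names.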

Answers:

A much shorter version for getting a list of all subclasses:

from itertools import chain

def subclasses(cls):
    return list(
        chain.from_iterable(
            [list(chain.from_iterable([[x], subclasses(x)])) for x in cls.__subclasses__()]
        )
    )

Answers:

This isn't as good an answer as using the special built-in __subclasses__() method that @unutbu mentions, so I present it merely as an exercise. The subclasses() function defined below returns a dictionary that maps the names of the subclasses created with the tracing metaclass to the subclasses themselves.

def traced_subclass(baseclass):
    class _SubclassTracer(type):
        def __new__(cls, classname, bases, classdict):
            obj = type(classname, bases, classdict)
            if baseclass in bases:  # sanity check: only direct subclasses
                attrname = '_%s__derived' % baseclass.__name__
                derived = getattr(baseclass, attrname, {})
                derived[classname] = obj
                setattr(baseclass, attrname, derived)
            return obj
    return _SubclassTracer

def subclasses(baseclass):
    attrname = '_%s__derived' % baseclass.__name__
    return getattr(baseclass, attrname, None)

class BaseClass(object):
    pass

# Python 3 syntax; Python 2 used a class attribute instead:
#     __metaclass__ = traced_subclass(BaseClass)
class SubclassA(BaseClass, metaclass=traced_subclass(BaseClass)):
    pass

class SubclassB(BaseClass, metaclass=traced_subclass(BaseClass)):
    pass

print(subclasses(BaseClass))

Output:

{'SubclassA': <class '__main__.SubclassA'>, 'SubclassB': <class '__main__.SubclassB'>}

Answers:

Here’s a version without recursion:

def get_subclasses_gen(cls):

    def _subclasses(classes, seen):
        while True:
            subclasses = sum((x.__subclasses__() for x in classes), [])
            yield from classes
            yield from seen
            found = []
            if not subclasses:
                return

            classes = subclasses
            seen = found

    return _subclasses([cls], [])

This differs from the other implementations in that it also yields the original class.
That keeps the code simpler, and it is consistent with issubclass(), which considers every class a subclass of itself:

class Ham(object):
    pass

assert issubclass(Ham, Ham)  # passes

If get_subclasses_gen looks a bit weird, that's because it was created by converting a tail-recursive implementation into a looping generator:

def get_subclasses(cls):

    def _subclasses(classes, seen):
        # concatenate the direct subclasses of every class on this level
        subclasses = sum((x.__subclasses__() for x in classes), [])
        found = seen + classes  # same level order as get_subclasses_gen
        if not subclasses:
            return found

        return _subclasses(subclasses, found)

    return _subclasses([cls], [])
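For comparison, a sketch of the same level-by-level (breadth-first) walk written with collections.deque and a seen set. Unlike get_subclasses_gen, it never yields a class twice when multiple inheritance creates a diamond; like it, it includes the starting class. iter_subclasses_bfs and the Spam/Eggs classes are hypothetical names:

```python
from collections import deque

def iter_subclasses_bfs(cls):
    """Yield cls and all of its subclasses, level by level, without repeats."""
    seen = {cls}
    queue = deque([cls])
    while queue:
        current = queue.popleft()
        yield current
        for sub in current.__subclasses__():
            if sub not in seen:  # a diamond can reach the same class twice
                seen.add(sub)
                queue.append(sub)

class Ham: pass
class Spam(Ham): pass
class Eggs(Ham): pass
class SpamAndEggs(Spam, Eggs): pass

print([c.__name__ for c in iter_subclasses_bfs(Ham)])
# ['Ham', 'Spam', 'Eggs', 'SpamAndEggs']
```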

Answers:

I cannot imagine a real-world use case for it, but a robust way (one that works even with Python 2 old-style classes) would be to scan the globals namespace:

def has_children(cls):
    g = globals().copy()    # use a copy so it cannot change during iteration
    for v in g.values():    # iterate over every global object
        try:
            if (v is not cls) and issubclass(v, cls):  # found a strict subclass?
                return True
        except TypeError:   # issubclass raises TypeError if v is not a class
            pass
    return False

It works on Python 2 new-style classes and Python 3 classes, as well as on Python 2 classic classes.
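If you want the matching classes rather than a yes/no answer, the same namespace-scanning idea can collect them. A Python 3 sketch, assuming the classes of interest are bound to names in the namespace you pass in (globals(), or vars(some_module)); subclasses_in_namespace is a hypothetical helper, and it uses inspect.isclass instead of catching TypeError:

```python
import inspect

def subclasses_in_namespace(cls, namespace):
    """Return the strict subclasses of cls found among namespace's values."""
    return [obj for obj in dict(namespace).values()  # copy, as in has_children
            if inspect.isclass(obj) and obj is not cls and issubclass(obj, cls)]

class Foo: pass
class Bar(Foo): pass
class Bing(Bar): pass

found = subclasses_in_namespace(Foo, globals())
print(sorted(c.__name__ for c in found))
# ['Bar', 'Bing']
```

Like has_children, this only sees classes bound to a name in that namespace; classes defined elsewhere or created dynamically without a name binding are missed, which __subclasses__() would still find.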