diff --git a/lib/ansible/module_utils/basic.py b/lib/ansible/module_utils/basic.py index c27fa416218..59e808afebe 100644 --- a/lib/ansible/module_utils/basic.py +++ b/lib/ansible/module_utils/basic.py @@ -303,6 +303,26 @@ def get_distribution_version(): distribution_version = None return distribution_version +def get_all_subclasses(cls): + ''' + used by modules like Hardware or Network fact classes to retrieve all subclasses of a given class. + __subclasses__ return only direct sub classes. This one go down into the class tree. + ''' + # Retrieve direct subclasses + subclasses = cls.__subclasses__() + to_visit = list(subclasses) + # Then visit all subclasses + while to_visit: + for sc in to_visit: + # The current class is now visited, so remove it from list + to_visit.remove(sc) + # Appending all subclasses to visit and keep a reference of available class + for ssc in sc.__subclasses__(): + subclasses.append(ssc) + to_visit.append(ssc) + return subclasses + + def load_platform_subclass(cls, *args, **kwargs): ''' used by modules like User to have different implementations based on detected platform. See User @@ -315,11 +335,11 @@ def load_platform_subclass(cls, *args, **kwargs): # get the most specific superclass for this platform if distribution is not None: - for sc in cls.__subclasses__(): + for sc in get_all_subclasses(cls): if sc.distribution is not None and sc.distribution == distribution and sc.platform == this_platform: subclass = sc if subclass is None: - for sc in cls.__subclasses__(): + for sc in get_all_subclasses(cls): if sc.platform == this_platform and sc.distribution is None: subclass = sc if subclass is None: diff --git a/lib/ansible/module_utils/facts.py b/lib/ansible/module_utils/facts.py index 420dbf573d2..8c1c4be7a94 100644 --- a/lib/ansible/module_utils/facts.py +++ b/lib/ansible/module_utils/facts.py @@ -32,6 +32,7 @@ import datetime import getpass import pwd import ConfigParser +from basic import get_all_subclasses # py2 vs py3; replace with six via ziploader try: @@ -867,7 +868,7 @@ class Hardware(Facts): def __new__(cls, *arguments, **keyword): subclass = cls - for sc in Hardware.__subclasses__(): + for sc in get_all_subclasses(Hardware): if sc.platform == platform.system(): subclass = sc return super(cls, subclass).__new__(subclass, *arguments, **keyword) @@ -1949,23 +1950,9 @@ class Network(Facts): def __new__(cls, *arguments, **keyword): subclass = cls - # Retrieve direct subclasses - to_visit = Network.__subclasses__() - # Then visit all subclasses - while to_visit: - for sc in to_visit: - # Check if current class is the good one - if sc.platform == platform.system(): - subclass = sc - to_visit = [] - break - # The current class is now visited, so remove it from list - to_visit.remove(sc) - # Appending all subclasses to visit and keep a reference of available class - for ssc in sc.__subclasses__(): - to_visit.append(ssc) - - # Now, return corresponding subclass + for sc in get_all_subclasses(Network): + if sc.platform == platform.system(): + subclass = sc return super(cls, subclass).__new__(subclass, *arguments, **keyword) def populate(self): @@ -2725,7 +2712,7 @@ class Virtual(Facts): def __new__(cls, *arguments, **keyword): subclass = cls - for sc in Virtual.__subclasses__(): + for sc in get_all_subclasses(Virtual): if sc.platform == platform.system(): subclass = sc return super(cls, subclass).__new__(subclass, *arguments, **keyword)