diff --git a/system/hostname.py b/system/hostname.py index 5ed9e0a7c2c..5e98d5bdb22 100644 --- a/system/hostname.py +++ b/system/hostname.py @@ -48,6 +48,7 @@ from distutils.version import LooseVersion # import module snippets from ansible.module_utils.basic import * +from ansible.module_utils.facts import * class UnimplementedStrategy(object): @@ -94,9 +95,12 @@ class Hostname(object): return load_platform_subclass(Hostname, args, kwargs) def __init__(self, module): - self.module = module - self.name = module.params['name'] - self.strategy = self.strategy_class(module) + self.module = module + self.name = module.params['name'] + if self.platform == 'Linux' and Facts(module).is_systemd_managed(): + self.strategy = SystemdStrategy(module) + else: + self.strategy = self.strategy_class(module) def get_current_hostname(self): return self.strategy.get_current_hostname() @@ -512,9 +516,7 @@ class SLESHostname(Hostname): platform = 'Linux' distribution = 'Suse linux enterprise server ' distribution_version = get_distribution_version() - if distribution_version and LooseVersion(distribution_version) >= LooseVersion("12"): - strategy_class = SystemdStrategy - elif distribution_version and LooseVersion("10") <= LooseVersion(distribution_version) <= LooseVersion("12"): + if distribution_version and LooseVersion("10") <= LooseVersion(distribution_version) <= LooseVersion("12"): strategy_class = SLESStrategy else: strategy_class = UnimplementedStrategy @@ -537,65 +539,37 @@ class RedHat5Hostname(Hostname): class RedHatServerHostname(Hostname): platform = 'Linux' distribution = 'Red hat enterprise linux server' - distribution_version = get_distribution_version() - if distribution_version and LooseVersion(distribution_version) >= LooseVersion("7"): - strategy_class = SystemdStrategy - else: - strategy_class = RedHatStrategy + strategy_class = RedHatStrategy class RedHatWorkstationHostname(Hostname): platform = 'Linux' distribution = 'Red hat enterprise linux workstation' - distribution_version = get_distribution_version() - if distribution_version and LooseVersion(distribution_version) >= LooseVersion("7"): - strategy_class = SystemdStrategy - else: - strategy_class = RedHatStrategy + strategy_class = RedHatStrategy class CentOSHostname(Hostname): platform = 'Linux' distribution = 'Centos' - distribution_version = get_distribution_version() - if distribution_version and LooseVersion(distribution_version) >= LooseVersion("7"): - strategy_class = SystemdStrategy - else: - strategy_class = RedHatStrategy + strategy_class = RedHatStrategy class CentOSLinuxHostname(Hostname): platform = 'Linux' distribution = 'Centos linux' - distribution_version = get_distribution_version() - if distribution_version and LooseVersion(distribution_version) >= LooseVersion("7"): - strategy_class = SystemdStrategy - else: - strategy_class = RedHatStrategy + strategy_class = RedHatStrategy class ScientificHostname(Hostname): platform = 'Linux' distribution = 'Scientific' - distribution_version = get_distribution_version() - if distribution_version and LooseVersion(distribution_version) >= LooseVersion("7"): - strategy_class = SystemdStrategy - else: - strategy_class = RedHatStrategy + strategy_class = RedHatStrategy class ScientificLinuxHostname(Hostname): platform = 'Linux' distribution = 'Scientific linux' - distribution_version = get_distribution_version() - if distribution_version and LooseVersion(distribution_version) >= LooseVersion("7"): - strategy_class = SystemdStrategy - else: - strategy_class = RedHatStrategy + strategy_class = RedHatStrategy class OracleLinuxHostname(Hostname): platform = 'Linux' distribution = 'Oracle linux server' - distribution_version = get_distribution_version() - if distribution_version and LooseVersion(distribution_version) >= LooseVersion("7"): - strategy_class = SystemdStrategy - else: - strategy_class = RedHatStrategy + strategy_class = RedHatStrategy class AmazonLinuxHostname(Hostname): platform = 'Linux' @@ -658,7 +632,7 @@ class FreeBSDHostname(Hostname): def main(): module = AnsibleModule( argument_spec = dict( - name=dict(required=True, type='str') + name=dict(required=True) ) ) @@ -682,4 +656,5 @@ def main(): ansible_fqdn=socket.getfqdn(), ansible_domain='.'.join(socket.getfqdn().split('.')[1:]))) -main() +if __name__ == '__main__': + main()