diff --git a/lib/ansible/module_utils/eos.py b/lib/ansible/module_utils/eos.py index 39771d0f9c4..3655820d321 100644 --- a/lib/ansible/module_utils/eos.py +++ b/lib/ansible/module_utils/eos.py @@ -19,12 +19,15 @@ import re -from ansible.module_utils.basic import json, get_exception, AnsibleModule -from ansible.module_utils.network import Command, NetCli, NetworkError, get_module +from ansible.module_utils.basic import json, AnsibleModule, get_exception +from ansible.module_utils.network import NetCli, NetworkError, NetworkModule, Command from ansible.module_utils.network import add_argument, register_transport, to_list from ansible.module_utils.netcfg import NetworkConfig from ansible.module_utils.urls import fetch_url, url_argument_spec +# temporary fix until modules are update. to be removed before 2.2 final +from ansible.module_utils.network import get_module + EAPI_FORMATS = ['json', 'text'] add_argument('use_ssl', dict(default=True, type='bool')) diff --git a/lib/ansible/module_utils/network.py b/lib/ansible/module_utils/network.py index 4d80de7b039..0a8b9d83cbb 100644 --- a/lib/ansible/module_utils/network.py +++ b/lib/ansible/module_utils/network.py @@ -48,24 +48,6 @@ def to_list(val): else: return list() -def connect(module): - try: - if not module.connected: - module.connection.connect(module.params) - if module.params['authorize']: - module.connection.authorize(module.params) - except NetworkError: - exc = get_exception() - module.fail_json(msg=exc.message) - -def disconnect(module): - try: - if module.connected: - module.connection.disconnect() - except NetworkError: - exc = get_exception() - module.fail_json(msg=exc.message) - class Command(object): @@ -157,15 +139,39 @@ class NetworkError(Exception): class NetworkModule(AnsibleModule): def __init__(self, *args, **kwargs): + connect_on_load = kwargs.pop('connect_on_load', True) + + argument_spec = NET_TRANSPORT_ARGS.copy() + argument_spec['transport']['choices'] = NET_CONNECTIONS.keys() + argument_spec.update(NET_CONNECTION_ARGS.copy()) + + if kwargs.get('argument_spec'): + argument_spec.update(kwargs['argument_spec']) + kwargs['argument_spec'] = argument_spec + super(NetworkModule, self).__init__(*args, **kwargs) + self.connection = None self._cli = None self._config = None + try: + transport = self.params['transport'] or '__default__' + cls = NET_CONNECTIONS[transport] + self.connection = cls() + except KeyError: + self.fail_json(msg='Unknown transport or no default transport specified') + except (TypeError, NetworkError): + exc = get_exception() + self.fail_json(msg=exc.message) + + if connect_on_load: + self.connect() + @property def cli(self): if not self.connected: - connect(self) + self.connect() if self._cli: return self._cli self._cli = Cli(self.connection) @@ -174,7 +180,7 @@ class NetworkModule(AnsibleModule): @property def config(self): if not self.connected: - connect(self) + self.connect() if self._config: return self._config self._config = Config(self.connection) @@ -193,6 +199,24 @@ class NetworkModule(AnsibleModule): if self.params.get(key) is None and value is not None: self.params[key] = value + def connect(self): + try: + if not self.connected: + self.connection.connect(self.params) + if self.params['authorize']: + self.connection.authorize(self.params) + except NetworkError: + exc = get_exception() + self.fail_json(msg=exc.message) + + def disconnect(self): + try: + if self.connected: + self.connection.disconnect() + except NetworkError: + exc = get_exception() + self.fail_json(msg=exc.message) + class NetCli(object): """Basic paramiko-based ssh transport any NetworkModule can use.""" @@ -249,32 +273,6 @@ class NetCli(object): raise NetworkError(exc.message, commands=commands) -def get_module(connect_on_load=True, **kwargs): - argument_spec = NET_TRANSPORT_ARGS.copy() - argument_spec['transport']['choices'] = NET_CONNECTIONS.keys() - argument_spec.update(NET_CONNECTION_ARGS.copy()) - - if kwargs.get('argument_spec'): - argument_spec.update(kwargs['argument_spec']) - kwargs['argument_spec'] = argument_spec - - module = NetworkModule(**kwargs) - - try: - transport = module.params['transport'] or '__default__' - cls = NET_CONNECTIONS[transport] - module.connection = cls() - except KeyError: - module.fail_json(msg='Unknown transport or no default transport specified') - except (TypeError, NetworkError): - exc = get_exception() - module.fail_json(msg=exc.message) - - if connect_on_load: - connect(module) - - return module - def register_transport(transport, default=False): def register(cls): NET_CONNECTIONS[transport] = cls @@ -286,3 +284,9 @@ def register_transport(transport, default=False): def add_argument(key, value): NET_CONNECTION_ARGS[key] = value +def get_module(*args, **kwargs): + # This is a temporary factory function to avoid break all modules + # until the modules are updated. This function *will* be removed + # before 2.2 final + return NetworkModule(*args, **kwargs) +