diff --git a/lib/ansible/module_utils/vyos.py b/lib/ansible/module_utils/vyos.py index d18e5f7b8e3..143871e17ef 100644 --- a/lib/ansible/module_utils/vyos.py +++ b/lib/ansible/module_utils/vyos.py @@ -34,11 +34,13 @@ _DEVICE_CONFIGS = {} vyos_argument_spec = { 'host': dict(), 'port': dict(type='int'), + 'username': dict(fallback=(env_fallback, ['ANSIBLE_NET_USERNAME'])), 'password': dict(fallback=(env_fallback, ['ANSIBLE_NET_PASSWORD']), no_log=True), 'ssh_keyfile': dict(fallback=(env_fallback, ['ANSIBLE_NET_SSH_KEYFILE']), type='path'), - 'timeout': dict(type='int', default=10), - 'provider': dict(type='dict'), + + 'timeout': dict(type='int'), + 'provider': dict(type='dict', no_log=True), } def check_args(module, warnings): diff --git a/lib/ansible/plugins/action/vyos.py b/lib/ansible/plugins/action/vyos.py index a97cad69569..8a7aefc158d 100644 --- a/lib/ansible/plugins/action/vyos.py +++ b/lib/ansible/plugins/action/vyos.py @@ -31,13 +31,19 @@ from ansible.module_utils.vyos import vyos_argument_spec from ansible.module_utils.basic import AnsibleFallbackNotFound from ansible.module_utils._text import to_bytes +try: + from __main__ import display +except ImportError: + from ansible.utils.display import Display + display = Display() + class ActionModule(_ActionModule): def run(self, tmp=None, task_vars=None): if self._play_context.connection != 'local': return dict( - fail=True, + failed=True, msg='invalid connection specified, expected connection=local, ' 'got %s' % self._play_context.connection ) @@ -46,18 +52,29 @@ class ActionModule(_ActionModule): pc = copy.deepcopy(self._play_context) pc.connection = 'network_cli' + pc.network_os = 'vyos' pc.port = provider['port'] or self._play_context.port or 22 pc.remote_user = provider['username'] or self._play_context.connection_user pc.password = provider['password'] or self._play_context.password + pc.private_key_file = provider['ssh_keyfile'] or self._play_context.private_key_file + pc.timeout = provider['timeout'] or self._play_context.timeout + + connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) socket_path = self._get_socket_path(pc) if not os.path.exists(socket_path): # start the connection if it isn't started - connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) - connection.exec_command('EXEC: show version') + rc, out, err = connection.exec_command('open_shell()') + if rc != 0: + return {'failed': True, 'msg': 'unable to connect to control socket'} task_vars['ansible_socket'] = socket_path + result = super(ActionModule, self).run(tmp, task_vars) + + display.vvv('closing cli shell', self._play_context.remote_addr) + connection.exec_command('close_shell()') + return super(ActionModule, self).run(tmp, task_vars) def _get_socket_path(self, play_context): @@ -68,7 +85,7 @@ class ActionModule(_ActionModule): def load_provider(self): provider = self._task.args.get('provider', {}) - for key, value in iteritems(ios_argument_spec): + for key, value in iteritems(vyos_argument_spec): if key != 'provider' and key not in provider: if key in self._task.args: provider[key] = self._task.args[key]