diff --git a/bin/ansible-connection b/bin/ansible-connection index c49354a30d2..de118aecb81 100755 --- a/bin/ansible-connection +++ b/bin/ansible-connection @@ -1,6 +1,6 @@ #!/usr/bin/env python -# (c) 2016, Ansible, Inc. +# (c) 2017, Ansible, Inc. # # This file is part of Ansible # @@ -33,18 +33,17 @@ import os import shlex import signal import socket -import struct import sys import time import traceback -import syslog import datetime -import logging +import errno from ansible import constants as C from ansible.module_utils._text import to_bytes, to_native from ansible.module_utils.six import PY3 from ansible.module_utils.six.moves import cPickle +from ansible.module_utils.connection import send_data, recv_data from ansible.playbook.play_context import PlayContext from ansible.plugins import connection_loader from ansible.utils.path import unfrackpath, makedirs_safe @@ -88,33 +87,11 @@ def do_fork(): except OSError as e: sys.exit(1) -def send_data(s, data): - packed_len = struct.pack('!Q', len(data)) - return s.sendall(packed_len + data) - -def recv_data(s): - header_len = 8 # size of a packed unsigned long long - data = b"" - while len(data) < header_len: - d = s.recv(header_len - len(data)) - if not d: - return None - data += d - data_len = struct.unpack('!Q', data[:header_len])[0] - data = data[header_len:] - while len(data) < data_len: - d = s.recv(data_len - len(data)) - if not d: - return None - data += d - return data - class Server(): - def __init__(self, path, play_context): - - self.path = path + def __init__(self, socket_path, play_context): + self.socket_path = socket_path self.play_context = play_context display.display( @@ -123,135 +100,163 @@ class Server(): log_only=True ) - display.display('control socket path is %s' % path, log_only=True) + display.display('control socket path is %s' % socket_path, log_only=True) display.display('current working directory is %s' % os.getcwd(), log_only=True) self._start_time = datetime.datetime.now() display.display("using connection plugin %s" % self.play_context.connection, log_only=True) - self.conn = connection_loader.get(play_context.connection, play_context, sys.stdin) - self.conn._connect() - if not self.conn.connected: + self.connection = connection_loader.get(play_context.connection, play_context, sys.stdin) + self.connection._connect() + + if not self.connection.connected: raise AnsibleConnectionFailure('unable to connect to remote host %s' % self._play_context.remote_addr) connection_time = datetime.datetime.now() - self._start_time display.display('connection established to %s in %s' % (play_context.remote_addr, connection_time), log_only=True) self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - self.socket.bind(path) + self.socket.bind(self.socket_path) self.socket.listen(1) - - signal.signal(signal.SIGALRM, self.alarm_handler) - - def dispatch(self, obj, name, *args, **kwargs): - meth = getattr(obj, name, None) - if meth: - return meth(*args, **kwargs) - - def alarm_handler(self, signum, frame): - ''' - Alarm handler - ''' - # FIXME: this should also set internal flags for other - # areas of code to check, so they can terminate - # earlier than the socket going back to the accept - # call and failing there. - # - # hooks the connection plugin to handle any cleanup - self.dispatch(self.conn, 'alarm_handler', signum, frame) - self.socket.close() + display.display('local socket is set to listening', log_only=True) def run(self): try: while True: - # set the alarm, if we don't get an accept before it - # goes off we exit (via an exception caused by the socket - # getting closed while waiting on accept()) - # FIXME: is this the best way to exit? as noted above in the - # handler we should probably be setting a flag to check - # here and in other parts of the code + signal.signal(signal.SIGALRM, self.connect_timeout) + signal.signal(signal.SIGTERM, self.handler) signal.alarm(C.PERSISTENT_CONNECT_TIMEOUT) - try: - (s, addr) = self.socket.accept() - display.display('incoming request accepted on persistent socket', log_only=True) - # clear the alarm - # FIXME: potential race condition here between the accept and - # time to this call. - signal.alarm(0) - except: - break + + (s, addr) = self.socket.accept() + display.display('incoming request accepted on persistent socket', log_only=True) + signal.alarm(0) while True: data = recv_data(s) if not data: break + signal.signal(signal.SIGALRM, self.command_timeout) signal.alarm(self.play_context.timeout) + op = data.split(':')[0] + display.display('socket operation is %s' % op, log_only=True) + + method = getattr(self, 'do_%s' % op, None) + rc = 255 - try: - if data.startswith(b'EXEC: '): - display.display("socket operation is EXEC", log_only=True) - cmd = data.split(b'EXEC: ')[1] - (rc, stdout, stderr) = self.conn.exec_command(cmd) - elif data.startswith(b'PUT: ') or data.startswith(b'FETCH: '): - (op, src, dst) = shlex.split(to_native(data)) - stdout = stderr = '' - try: - if op == 'FETCH:': - display.display("socket operation is FETCH", log_only=True) - self.conn.fetch_file(src, dst) - elif op == 'PUT:': - display.display("socket operation is PUT", log_only=True) - self.conn.put_file(src, dst) - rc = 0 - except: - pass - elif data.startswith(b'CONTEXT: '): - display.display("socket operation is CONTEXT", log_only=True) - pc_data = data.split(b'CONTEXT: ', 1)[1] - - if PY3: - pc_data = cPickle.loads(pc_data, encoding='bytes') - else: - pc_data = cPickle.loads(pc_data) - - pc = PlayContext() - pc.deserialize(pc_data) - - self.dispatch(self.conn, 'update_play_context', pc) - continue - else: - display.display("socket operation is UNKNOWN", log_only=True) - stdout = '' - stderr = 'Invalid action specified' - except: - stdout = '' - stderr = traceback.format_exc() + stdout = stderr = '' + + if not method: + stderr = 'Invalid action specified' + else: + rc, stdout, stderr = method(data) signal.alarm(0) - display.display("socket operation completed with rc %s" % rc, log_only=True) + display.display('socket operation completed with rc %s' % rc, log_only=True) send_data(s, to_bytes(rc)) send_data(s, to_bytes(stdout)) send_data(s, to_bytes(stderr)) + s.close() + except Exception as e: - display.display(traceback.format_exc(), log_only=True) + # socket.accept() will raise EINTR if the socket.close() is called + if e.errno != errno.EINTR: + display.display(traceback.format_exc(), log_only=True) + finally: # when done, close the connection properly and cleanup # the socket file so it can be recreated + self.shutdown() end_time = datetime.datetime.now() delta = end_time - self._start_time - display.display('shutting down control socket, connection was active for %s secs' % delta, log_only=True) - try: - self.conn.close() + display.display('shutdown local socket, connection was active for %s secs' % delta, log_only=True) + + def connect_timeout(self, signum, frame): + display.display('connect timeout triggered, timeout value is %s secs' % C.PERSISTENT_CONNECT_TIMEOUT, log_only=True) + self.shutdown() + + def command_timeout(self, signum, frame): + display.display('commnad timeout triggered, timeout value is %s secs' % self.play_context.timeout, log_only=True) + self.shutdown() + + def handler(self, signum, frame): + display.display('signal handler called with signal %s' % signum, log_only=True) + self.shutdown() + + def shutdown(self): + display.display('shutdown persistent connection requested', log_only=True) + + if not os.path.exists(self.socket_path): + display.display('persistent connection is not active', log_only=True) + return + + try: + if self.socket: + display.display('closing local listener', log_only=True) self.socket.close() - except Exception as e: - pass - os.remove(self.path) + if self.connection: + display.display('closing the connection', log_only=True) + self.close() + except: + pass + finally: + if os.path.exists(self.socket_path): + display.display('removing the local control socket', log_only=True) + os.remove(self.socket_path) + + display.display('shutdown complete', log_only=True) + + def do_EXEC(self, data): + cmd = data.split(b'EXEC: ')[1] + return self.connection.exec_command(cmd) + + def do_PUT(self, data): + (op, src, dst) = shlex.split(to_native(data)) + return self.connection.fetch_file(src, dst) + + def do_FETCH(self, data): + (op, src, dst) = shlex.split(to_native(data)) + return self.connection.put_file(src, dst) + + def do_CONTEXT(self, data): + pc_data = data.split(b'CONTEXT: ', 1)[1] + + if PY3: + pc_data = cPickle.loads(pc_data, encoding='bytes') + else: + pc_data = cPickle.loads(pc_data) + + pc = PlayContext() + pc.deserialize(pc_data) + + try: + self.connection.update_play_context(pc) + except AttributeError: + pass + + return (0, 'ok', '') + + def do_RUN(self, data): + timeout = self.play_context.timeout + while bool(timeout): + if os.path.exists(self.socket_path): + break + time.sleep(1) + timeout -= 1 + return (0, self.socket_path, '') + + +def communicate(sock, data): + send_data(sock, data) + rc = int(recv_data(sock), 10) + stdout = recv_data(sock) + stderr = recv_data(sock) + return (rc, stdout, stderr) def main(): # Need stdin as a byte stream @@ -279,30 +284,32 @@ def main(): pc = PlayContext() pc.deserialize(pc_data) + except Exception as e: # FIXME: better error message/handling/logging sys.stderr.write(traceback.format_exc()) sys.exit("FAIL: %s" % e) ssh = connection_loader.get('ssh', class_only=True) - m = ssh._create_control_path(pc.remote_addr, pc.port, pc.remote_user) + cp = ssh._create_control_path(pc.remote_addr, pc.connection, pc.remote_user) # create the persistent connection dir if need be and create the paths # which we will be using later - tmp_path = unfrackpath("$HOME/.ansible/pc") + tmp_path = unfrackpath(C.PERSISTENT_CONTROL_PATH_DIR) makedirs_safe(tmp_path) - lk_path = unfrackpath("%s/.ansible_pc_lock" % tmp_path) - sf_path = unfrackpath(m % dict(directory=tmp_path)) + lock_path = unfrackpath("%s/.ansible_pc_lock" % tmp_path) + socket_path = unfrackpath(cp % dict(directory=tmp_path)) # if the socket file doesn't exist, spin up the daemon process - lock_fd = os.open(lk_path, os.O_RDWR|os.O_CREAT, 0o600) + lock_fd = os.open(lock_path, os.O_RDWR|os.O_CREAT, 0o600) fcntl.lockf(lock_fd, fcntl.LOCK_EX) - if not os.path.exists(sf_path): + + if not os.path.exists(socket_path): pid = do_fork() if pid == 0: rc = 0 try: - server = Server(sf_path, pc) + server = Server(socket_path, pc) except AnsibleConnectionFailure as exc: display.display('connecting to host %s returned an error' % pc.remote_addr, log_only=True) display.display(str(exc), log_only=True) @@ -318,50 +325,57 @@ def main(): sys.exit(rc) else: display.display('re-using existing socket for %s@%s:%s' % (pc.remote_user, pc.remote_addr, pc.port), log_only=True) + fcntl.lockf(lock_fd, fcntl.LOCK_UN) os.close(lock_fd) + timeout = pc.timeout + while bool(timeout): + if os.path.exists(socket_path): + display.vvvv('connected to local socket in %s' % (pc.timeout - timeout), pc.remote_addr) + break + time.sleep(1) + timeout -= 1 + else: + raise AnsibleConnectionFailure('timeout waiting for local socket', pc.remote_addr) + # now connect to the daemon process # FIXME: if the socket file existed but the daemonized process was killed, # the connection will timeout here. Need to make this more resilient. - rc = 0 - while rc == 0: + while True: data = stdin.readline() if data == b'': break if data.strip() == b'': continue - sf = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - attempts = 1 - while True: + + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + + attempts = C.PERSISTENT_CONNECT_RETRIES + while bool(attempts): try: - sf.connect(sf_path) + sock.connect(socket_path) break except socket.error: - # FIXME: better error handling/logging/message here time.sleep(C.PERSISTENT_CONNECT_INTERVAL) - attempts += 1 - if attempts > C.PERSISTENT_CONNECT_RETRIES: - display.display('number of connection attempts exceeded, unable to connect to control socket', pc.remote_addr, pc.remote_user, log_only=True) - display.display('persistent_connect_interval=%s, persistent_connect_retries=%s' % (C.PERSISTENT_CONNECT_INTERVAL, C.PERSISTENT_CONNECT_RETRIES), pc.remote_addr, pc.remote_user, log_only=True) - sys.stderr.write('failed to connect to control socket') - sys.exit(255) + attempts -= 1 + else: + display.display('number of connection attempts exceeded, unable to connect to control socket', pc.remote_addr, pc.remote_user, log_only=True) + display.display('persistent_connect_interval=%s, persistent_connect_retries=%s' % (C.PERSISTENT_CONNECT_INTERVAL, C.PERSISTENT_CONNECT_RETRIES), pc.remote_addr, pc.remote_user, log_only=True) + sys.stderr.write('failed to connect to control socket') + sys.exit(255) # send the play_context back into the connection so the connection # can handle any privilege escalation activities pc_data = b'CONTEXT: %s' % init_data - send_data(sf, pc_data) - - send_data(sf, data.strip()) + communicate(sock, pc_data) - rc = int(recv_data(sf), 10) - stdout = recv_data(sf) - stderr = recv_data(sf) + rc, stdout, stderr = communicate(sock, data.strip()) sys.stdout.write(to_native(stdout)) sys.stderr.write(to_native(stderr)) - sf.close() + sock.close() break sys.exit(rc) diff --git a/lib/ansible/constants.py b/lib/ansible/constants.py index de4875d6468..fe01ed6ff79 100644 --- a/lib/ansible/constants.py +++ b/lib/ansible/constants.py @@ -394,6 +394,7 @@ PARAMIKO_LOOK_FOR_KEYS = get_config(p, 'paramiko_connection', 'look_for_keys', ' PERSISTENT_CONNECT_TIMEOUT = get_config(p, 'persistent_connection', 'connect_timeout', 'ANSIBLE_PERSISTENT_CONNECT_TIMEOUT', 30, value_type='integer') PERSISTENT_CONNECT_RETRIES = get_config(p, 'persistent_connection', 'connect_retries', 'ANSIBLE_PERSISTENT_CONNECT_RETRIES', 30, value_type='integer') PERSISTENT_CONNECT_INTERVAL = get_config(p, 'persistent_connection', 'connect_interval', 'ANSIBLE_PERSISTENT_CONNECT_INTERVAL', 1, value_type='integer') +PERSISTENT_CONTROL_PATH_DIR = get_config(p, 'persistent_connection', 'control_path_dir', 'ANSIBLE_PERSISTENT_CONTROL_PATH_DIR', u'~/.ansible/pc') # obsolete -- will be formally removed ACCELERATE_PORT = get_config(p, 'accelerate', 'accelerate_port', 'ACCELERATE_PORT', 5099, value_type='integer') diff --git a/lib/ansible/module_utils/connection.py b/lib/ansible/module_utils/connection.py index 785af210ba0..821716c0c1f 100644 --- a/lib/ansible/module_utils/connection.py +++ b/lib/ansible/module_utils/connection.py @@ -29,9 +29,13 @@ import signal import socket import struct +import os +import uuid + +from functools import partial from ansible.module_utils.basic import get_exception -from ansible.module_utils._text import to_bytes, to_native +from ansible.module_utils._text import to_bytes, to_native, to_text def send_data(s, data): @@ -75,4 +79,63 @@ def exec_command(module, command): sf.close() - return (rc, to_native(stdout), to_native(stderr)) + return rc, to_native(stdout), to_native(stderr) + + +class Connection: + + def __init__(self, module): + self._module = module + + def __getattr__(self, name): + try: + return self.__dict__[name] + except KeyError: + if name.startswith('_'): + raise AttributeError("'%s' object has no attribute '%s'" % (self.__class__.__name__, name)) + return partial(self.__rpc__, name) + + def __rpc__(self, name, *args, **kwargs): + """Executes the json-rpc and returns the output received + from remote device. + :name: rpc method to be executed over connection plugin that implements jsonrpc 2.0 + :args: Ordered list of params passed as arguments to rpc method + :kwargs: Dict of valid key, value pairs passed as arguments to rpc method + + For usage refer the respective connection plugin docs. + """ + + reqid = str(uuid.uuid4()) + req = {'jsonrpc': '2.0', 'method': name, 'id': reqid} + + params = list(args) or kwargs or None + if params: + req['params'] = params + + if not self._module._socket_path: + self._module.fail_json(msg='provider support not available for this host') + + if not os.path.exists(self._module._socket_path): + self._module.fail_json(msg='provider socket does not exist, is the provider running?') + + try: + data = self._module.jsonify(req) + rc, out, err = exec_command(self._module, data) + + except socket.error: + exc = get_exception() + self._module.fail_json(msg='unable to connect to socket', err=str(exc)) + + try: + response = self._module.from_json(to_text(out, errors='surrogate_then_replace')) + except ValueError as exc: + self._module.fail_json(msg=to_text(exc, errors='surrogate_then_replace')) + + if response['id'] != reqid: + self._module.fail_json(msg='invalid id received') + + if 'error' in response: + msg = response['error'].get('data') or response['error']['message'] + self._module.fail_json(msg=to_text(msg, errors='surrogate_then_replace')) + + return response['result'] diff --git a/lib/ansible/plugins/__init__.py b/lib/ansible/plugins/__init__.py index 2a4110eaae3..d9ab31c0723 100644 --- a/lib/ansible/plugins/__init__.py +++ b/lib/ansible/plugins/__init__.py @@ -550,3 +550,19 @@ vars_loader = PluginLoader( C.DEFAULT_VARS_PLUGIN_PATH, 'vars_plugins', ) + +cliconf_loader = PluginLoader( + 'Cliconf', + 'ansible.plugins.cliconf', + 'cliconf_plugins', + 'cliconf_plugins', + required_base_class='CliconfBase' +) + +netconf_loader = PluginLoader( + 'Netconf', + 'ansible.plugins.netconf', + 'netconf_plugins', + 'netconf_plugins', + required_base_class='NetconfBase' +) diff --git a/lib/ansible/plugins/action/ce.py b/lib/ansible/plugins/action/ce.py index f10ff293d77..e3d1d57dcb8 100644 --- a/lib/ansible/plugins/action/ce.py +++ b/lib/ansible/plugins/action/ce.py @@ -19,17 +19,13 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type -import os import sys import copy from ansible.plugins.action.normal import ActionModule as _ActionModule -from ansible.utils.path import unfrackpath -from ansible.plugins import connection_loader from ansible.module_utils.six import iteritems from ansible.module_utils.ce import ce_argument_spec from ansible.module_utils.basic import AnsibleFallbackNotFound -from ansible.module_utils._text import to_bytes try: from __main__ import display @@ -71,26 +67,21 @@ class ActionModule(_ActionModule): ) display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) - socket_path = self._get_socket_path(pc) - display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) - if not os.path.exists(socket_path): - # start the connection if it isn't started - rc, out, err = connection.exec_command('open_shell()') - display.vvvv('open_shell() returned %s %s %s' % (rc, out, err)) - if rc != 0: - return {'failed': True, - 'msg': 'unable to open shell. Please see: ' + - 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell', - 'rc': rc} - else: - # make sure we are in the right cli context which should be - # enable mode and not config module + socket_path = connection.run() + display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) + if not socket_path: + return {'failed': True, + 'msg': 'unable to open shell. Please see: ' + + 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} + + # make sure we are in the right cli context which should be + # enable mode and not config module + rc, out, err = connection.exec_command('prompt()') + while str(out).strip().endswith(']'): + display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) + connection.exec_command('return') rc, out, err = connection.exec_command('prompt()') - while str(out).strip().endswith(']'): - display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) - connection.exec_command('return') - rc, out, err = connection.exec_command('prompt()') task_vars['ansible_socket'] = socket_path @@ -100,12 +91,6 @@ class ActionModule(_ActionModule): result = super(ActionModule, self).run(tmp, task_vars) return result - def _get_socket_path(self, play_context): - ssh = connection_loader.get('ssh', class_only=True) - cp = ssh._create_control_path(play_context.remote_addr, play_context.port, play_context.remote_user) - path = unfrackpath("$HOME/.ansible/pc") - return cp % dict(directory=path) - def load_provider(self): provider = self._task.args.get('provider', {}) for key, value in iteritems(ce_argument_spec): diff --git a/lib/ansible/plugins/action/dellos10.py b/lib/ansible/plugins/action/dellos10.py index 171a917beef..b58cf154156 100644 --- a/lib/ansible/plugins/action/dellos10.py +++ b/lib/ansible/plugins/action/dellos10.py @@ -21,17 +21,13 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type -import os import sys import copy from ansible.plugins.action.normal import ActionModule as _ActionModule -from ansible.utils.path import unfrackpath -from ansible.plugins import connection_loader from ansible.module_utils.six import iteritems from ansible.module_utils.dellos10 import dellos10_argument_spec from ansible.module_utils.basic import AnsibleFallbackNotFound -from ansible.module_utils._text import to_bytes try: from __main__ import display @@ -67,26 +63,20 @@ class ActionModule(_ActionModule): display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) - socket_path = self._get_socket_path(pc) + socket_path = connection.run() display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) - - if not os.path.exists(socket_path): - # start the connection if it isn't started - rc, out, err = connection.exec_command('open_shell()') - display.vvvv('open_shell() returned %s %s %s' % (rc, out, err)) - if not rc == 0: - return {'failed': True, - 'msg': 'unable to open shell. Please see: ' + - 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell', - 'rc': rc} - else: - # make sure we are in the right cli context which should be - # enable mode and not config module + if not socket_path: + return {'failed': True, + 'msg': 'unable to open shell. Please see: ' + + 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} + + # make sure we are in the right cli context which should be + # enable mode and not config module + rc, out, err = connection.exec_command('prompt()') + while str(out).strip().endswith(')#'): + display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) + connection.exec_command('exit') rc, out, err = connection.exec_command('prompt()') - while str(out).strip().endswith(')#'): - display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) - connection.exec_command('exit') - rc, out, err = connection.exec_command('prompt()') task_vars['ansible_socket'] = socket_path @@ -97,12 +87,6 @@ class ActionModule(_ActionModule): result = super(ActionModule, self).run(tmp, task_vars) return result - def _get_socket_path(self, play_context): - ssh = connection_loader.get('ssh', class_only=True) - cp = ssh._create_control_path(play_context.remote_addr, play_context.port, play_context.remote_user) - path = unfrackpath("$HOME/.ansible/pc") - return cp % dict(directory=path) - def load_provider(self): provider = self._task.args.get('provider', {}) for key, value in iteritems(dellos10_argument_spec): diff --git a/lib/ansible/plugins/action/dellos6.py b/lib/ansible/plugins/action/dellos6.py index 944e6b0a0b8..85c1d638b37 100644 --- a/lib/ansible/plugins/action/dellos6.py +++ b/lib/ansible/plugins/action/dellos6.py @@ -18,17 +18,13 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type -import os import sys import copy from ansible.plugins.action.normal import ActionModule as _ActionModule -from ansible.utils.path import unfrackpath -from ansible.plugins import connection_loader from ansible.module_utils.six import iteritems from ansible.module_utils.dellos6 import dellos6_argument_spec from ansible.module_utils.basic import AnsibleFallbackNotFound -from ansible.module_utils._text import to_bytes try: from __main__ import display @@ -63,26 +59,20 @@ class ActionModule(_ActionModule): display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) - socket_path = self._get_socket_path(pc) + socket_path = connection.run() display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) - - if not os.path.exists(socket_path): - # start the connection if it isn't started - rc, out, err = connection.exec_command('open_shell()') - display.vvvv('open_shell() returned %s %s %s' % (rc, out, err)) - if not rc == 0: - return {'failed': True, - 'msg': 'unable to open shell. Please see: ' + - 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell', - 'rc': rc} - else: - # make sure we are in the right cli context which should be - # enable mode and not config module + if not socket_path: + return {'failed': True, + 'msg': 'unable to open shell. Please see: ' + + 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} + + # make sure we are in the right cli context which should be + # enable mode and not config module + rc, out, err = connection.exec_command('prompt()') + while str(out).strip().endswith(')#'): + display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) + connection.exec_command('exit') rc, out, err = connection.exec_command('prompt()') - while str(out).strip().endswith(')#'): - display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) - connection.exec_command('exit') - rc, out, err = connection.exec_command('prompt()') task_vars['ansible_socket'] = socket_path @@ -93,12 +83,6 @@ class ActionModule(_ActionModule): result = super(ActionModule, self).run(tmp, task_vars) return result - def _get_socket_path(self, play_context): - ssh = connection_loader.get('ssh', class_only=True) - cp = ssh._create_control_path(play_context.remote_addr, play_context.port, play_context.remote_user) - path = unfrackpath("$HOME/.ansible/pc") - return cp % dict(directory=path) - def load_provider(self): provider = self._task.args.get('provider', {}) for key, value in iteritems(dellos6_argument_spec): diff --git a/lib/ansible/plugins/action/dellos9.py b/lib/ansible/plugins/action/dellos9.py index d5ecdb161ef..795043d1421 100644 --- a/lib/ansible/plugins/action/dellos9.py +++ b/lib/ansible/plugins/action/dellos9.py @@ -21,17 +21,13 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type -import os import sys import copy from ansible.plugins.action.normal import ActionModule as _ActionModule -from ansible.utils.path import unfrackpath -from ansible.plugins import connection_loader from ansible.module_utils.six import iteritems from ansible.module_utils.dellos9 import dellos9_argument_spec from ansible.module_utils.basic import AnsibleFallbackNotFound -from ansible.module_utils._text import to_bytes try: from __main__ import display @@ -67,26 +63,20 @@ class ActionModule(_ActionModule): display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) - socket_path = self._get_socket_path(pc) + socket_path = connection.run() display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) - - if not os.path.exists(socket_path): - # start the connection if it isn't started - rc, out, err = connection.exec_command('open_shell()') - display.vvvv('open_shell() returned %s %s %s' % (rc, out, err)) - if not rc == 0: - return {'failed': True, - 'msg': 'unable to open shell. Please see: ' + - 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell', - 'rc': rc} - else: - # make sure we are in the right cli context which should be - # enable mode and not config module + if not socket_path: + return {'failed': True, + 'msg': 'unable to open shell. Please see: ' + + 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} + + # make sure we are in the right cli context which should be + # enable mode and not config module + rc, out, err = connection.exec_command('prompt()') + while str(out).strip().endswith(')#'): + display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) + connection.exec_command('exit') rc, out, err = connection.exec_command('prompt()') - while str(out).strip().endswith(')#'): - display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) - connection.exec_command('exit') - rc, out, err = connection.exec_command('prompt()') task_vars['ansible_socket'] = socket_path @@ -97,12 +87,6 @@ class ActionModule(_ActionModule): result = super(ActionModule, self).run(tmp, task_vars) return result - def _get_socket_path(self, play_context): - ssh = connection_loader.get('ssh', class_only=True) - cp = ssh._create_control_path(play_context.remote_addr, play_context.port, play_context.remote_user) - path = unfrackpath("$HOME/.ansible/pc") - return cp % dict(directory=path) - def load_provider(self): provider = self._task.args.get('provider', {}) for key, value in iteritems(dellos9_argument_spec): diff --git a/lib/ansible/plugins/action/eos.py b/lib/ansible/plugins/action/eos.py index 965536932b4..3645ad7c47a 100644 --- a/lib/ansible/plugins/action/eos.py +++ b/lib/ansible/plugins/action/eos.py @@ -19,16 +19,13 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type -import os import sys import copy from ansible.module_utils.basic import AnsibleFallbackNotFound from ansible.module_utils.eos import ARGS_DEFAULT_VALUE, eos_argument_spec from ansible.module_utils.six import iteritems -from ansible.plugins import connection_loader from ansible.plugins.action.normal import ActionModule as _ActionModule -from ansible.utils.path import unfrackpath try: from __main__ import display @@ -68,25 +65,20 @@ class ActionModule(_ActionModule): display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) - socket_path = self._get_socket_path(pc) + socket_path = connection.run() display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) - - if not os.path.exists(socket_path): - # start the connection if it isn't started - rc, out, err = connection.exec_command('open_shell()') - display.vvvv('open_shell() returned %s %s %s' % (rc, out, err)) - if not rc == 0: - return {'failed': True, - 'msg': 'unable to open shell. Please see: ' + - 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} - else: - # make sure we are in the right cli context which should be - # enable mode and not config module + if not socket_path: + return {'failed': True, + 'msg': 'unable to open shell. Please see: ' + + 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} + + # make sure we are in the right cli context which should be + # enable mode and not config module + rc, out, err = connection.exec_command('prompt()') + while str(out).strip().endswith(')#'): + display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) + connection.exec_command('exit') rc, out, err = connection.exec_command('prompt()') - while str(out).strip().endswith(')#'): - display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) - connection.exec_command('exit') - rc, out, err = connection.exec_command('prompt()') task_vars['ansible_socket'] = socket_path @@ -123,12 +115,6 @@ class ActionModule(_ActionModule): result = super(ActionModule, self).run(tmp, task_vars) return result - def _get_socket_path(self, play_context): - ssh = connection_loader.get('ssh', class_only=True) - cp = ssh._create_control_path(play_context.remote_addr, play_context.port, play_context.remote_user) - path = unfrackpath("$HOME/.ansible/pc") - return cp % dict(directory=path) - def load_provider(self): provider = self._task.args.get('provider', {}) for key, value in iteritems(eos_argument_spec): diff --git a/lib/ansible/plugins/action/ios.py b/lib/ansible/plugins/action/ios.py index fa3fb8bddea..ec6c5480987 100644 --- a/lib/ansible/plugins/action/ios.py +++ b/lib/ansible/plugins/action/ios.py @@ -19,17 +19,13 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type -import os import sys import copy from ansible.plugins.action.normal import ActionModule as _ActionModule -from ansible.utils.path import unfrackpath -from ansible.plugins import connection_loader from ansible.module_utils.basic import AnsibleFallbackNotFound from ansible.module_utils.ios import ios_argument_spec from ansible.module_utils.six import iteritems -from ansible.module_utils._text import to_bytes try: from __main__ import display @@ -66,25 +62,19 @@ class ActionModule(_ActionModule): display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) - socket_path = self._get_socket_path(pc) + socket_path = connection.run() display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) - - if not os.path.exists(socket_path): - # start the connection if it isn't started - rc, out, err = connection.exec_command('open_shell()') - display.vvvv('open_shell() returned %s %s %s' % (rc, out, err)) - if not rc == 0: - return {'failed': True, - 'msg': 'unable to open shell. Please see: ' + - 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell', - 'rc': rc} - else: - # make sure we are in the right cli context which should be - # enable mode and not config module - rc, out, err = connection.exec_command('prompt()') - if str(out).strip().endswith(')#'): - display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) - connection.exec_command('exit') + if not socket_path: + return {'failed': True, + 'msg': 'unable to open shell. Please see: ' + + 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} + + # make sure we are in the right cli context which should be + # enable mode and not config module + rc, out, err = connection.exec_command('prompt()') + if str(out).strip().endswith(')#'): + display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) + connection.exec_command('exit') task_vars['ansible_socket'] = socket_path @@ -95,12 +85,6 @@ class ActionModule(_ActionModule): result = super(ActionModule, self).run(tmp, task_vars) return result - def _get_socket_path(self, play_context): - ssh = connection_loader.get('ssh', class_only=True) - cp = ssh._create_control_path(play_context.remote_addr, play_context.port, play_context.remote_user) - path = unfrackpath("$HOME/.ansible/pc") - return cp % dict(directory=path) - def load_provider(self): provider = self._task.args.get('provider', {}) for key, value in iteritems(ios_argument_spec): diff --git a/lib/ansible/plugins/action/iosxr.py b/lib/ansible/plugins/action/iosxr.py index aa2ae69332e..28d81e209cd 100644 --- a/lib/ansible/plugins/action/iosxr.py +++ b/lib/ansible/plugins/action/iosxr.py @@ -19,17 +19,13 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type -import os import sys import copy from ansible.module_utils.basic import AnsibleFallbackNotFound from ansible.module_utils.iosxr import iosxr_argument_spec from ansible.module_utils.six import iteritems -from ansible.module_utils._text import to_bytes -from ansible.plugins import connection_loader from ansible.plugins.action.normal import ActionModule as _ActionModule -from ansible.utils.path import unfrackpath try: from __main__ import display @@ -63,38 +59,26 @@ class ActionModule(_ActionModule): display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) - socket_path = self._get_socket_path(pc) + socket_path = connection.run() display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) - - if not os.path.exists(socket_path): - # start the connection if it isn't started - rc, out, err = connection.exec_command('open_shell()') - display.vvvv('open_shell() returned %s %s %s' % (rc, out, err)) - if rc != 0: - return {'failed': True, - 'msg': 'unable to open shell. Please see: ' + - 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell', - 'rc': rc} - else: - # make sure we are in the right cli context which should be - # enable mode and not config module + if not socket_path: + return {'failed': True, + 'msg': 'unable to open shell. Please see: ' + + 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} + + # make sure we are in the right cli context which should be + # enable mode and not config module + rc, out, err = connection.exec_command('prompt()') + while str(out).strip().endswith(')#'): + display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) + connection.exec_command('exit') rc, out, err = connection.exec_command('prompt()') - while str(out).strip().endswith(')#'): - display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) - connection.exec_command('exit') - rc, out, err = connection.exec_command('prompt()') task_vars['ansible_socket'] = socket_path result = super(ActionModule, self).run(tmp, task_vars) return result - def _get_socket_path(self, play_context): - ssh = connection_loader.get('ssh', class_only=True) - cp = ssh._create_control_path(play_context.remote_addr, play_context.port, play_context.remote_user) - path = unfrackpath("$HOME/.ansible/pc") - return cp % dict(directory=path) - def load_provider(self): provider = self._task.args.get('provider', {}) for key, value in iteritems(iosxr_argument_spec): diff --git a/lib/ansible/plugins/action/junos.py b/lib/ansible/plugins/action/junos.py index 1d63c4a5ef7..e3972899cdd 100644 --- a/lib/ansible/plugins/action/junos.py +++ b/lib/ansible/plugins/action/junos.py @@ -19,7 +19,6 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type -import os import sys import copy @@ -28,7 +27,6 @@ from ansible.module_utils.junos import junos_argument_spec from ansible.module_utils.six import iteritems from ansible.plugins import connection_loader, module_loader from ansible.plugins.action.normal import ActionModule as _ActionModule -from ansible.utils.path import unfrackpath try: from __main__ import display @@ -75,25 +73,14 @@ class ActionModule(_ActionModule): display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) - socket_path = self._get_socket_path(pc) + socket_path = connection.run() display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) + if not socket_path: + return {'failed': True, + 'msg': 'unable to open shell. Please see: ' + + 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} - if not os.path.exists(socket_path): - # start the connection if it isn't started - if pc.connection == 'netconf': - rc, out, err = connection.exec_command('open_session()') - display.vvvv('open_session() returned %s %s %s' % (rc, out, err)) - else: - rc, out, err = connection.exec_command('open_shell()') - display.vvvv('open_shell() returned %s %s %s' % (rc, out, err)) - - if rc != 0: - return {'failed': True, - 'msg': 'unable to open shell. Please see: ' + - 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell', - 'rc': rc} - - elif pc.connection == 'network_cli': + if pc.connection == 'network_cli': # make sure we are in the right cli context which should be # enable mode and not config module rc, out, err = connection.exec_command('prompt()') @@ -107,15 +94,6 @@ class ActionModule(_ActionModule): result = super(ActionModule, self).run(tmp, task_vars) return result - def _get_socket_path(self, play_context): - ssh = connection_loader.get('ssh', class_only=True) - path = unfrackpath("$HOME/.ansible/pc") - # use play_context.connection instea of play_context.port to avoid - # collision if netconf is listening on port 22 - # cp = ssh._create_control_path(play_context.remote_addr, play_context.connection, play_context.remote_user) - cp = ssh._create_control_path(play_context.remote_addr, play_context.port, play_context.remote_user) - return cp % dict(directory=path) - def load_provider(self): provider = self._task.args.get('provider', {}) for key, value in iteritems(junos_argument_spec): diff --git a/lib/ansible/plugins/action/junos_config.py b/lib/ansible/plugins/action/junos_config.py index b166f23b3ac..aa9fa617ca5 100644 --- a/lib/ansible/plugins/action/junos_config.py +++ b/lib/ansible/plugins/action/junos_config.py @@ -55,7 +55,7 @@ class ActionModule(_ActionModule): # strip out any keys that have two leading and two trailing # underscore characters - for key in result.keys(): + for key in list(result): if PRIVATE_KEYS_RE.match(key): del result[key] diff --git a/lib/ansible/plugins/action/net_base.py b/lib/ansible/plugins/action/net_base.py index 6df0bdc09c2..3bde543b864 100644 --- a/lib/ansible/plugins/action/net_base.py +++ b/lib/ansible/plugins/action/net_base.py @@ -17,16 +17,12 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type -import os import sys import copy from ansible.plugins.action import ActionBase -from ansible.utils.path import unfrackpath -from ansible.plugins import connection_loader from ansible.module_utils.basic import AnsibleFallbackNotFound from ansible.module_utils.six import iteritems -from ansible.module_utils._text import to_bytes from imp import find_module, load_module @@ -99,25 +95,19 @@ class ActionModule(ActionBase): connection = self._shared_loader_obj.connection_loader.get('persistent', play_context, sys.stdin) - socket_path = self._get_socket_path(play_context) + socket_path = connection.run() display.vvvv('socket_path: %s' % socket_path, play_context.remote_addr) - - if not os.path.exists(socket_path): - # start the connection if it isn't started - rc, out, err = connection.exec_command('open_shell()') - display.vvvv('open_shell() returned %s %s %s' % (rc, out, err)) - if not rc == 0: - return {'failed': True, - 'msg': 'unable to open shell. Please see: ' + - 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell', - 'rc': rc} - else: - # make sure we are in the right cli context which should be - # enable mode and not config module - rc, out, err = connection.exec_command('prompt()') - if str(out).strip().endswith(')#'): - display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) - connection.exec_command('exit') + if not socket_path: + return {'failed': True, + 'msg': 'unable to open shell. Please see: ' + + 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} + + # make sure we are in the right cli context which should be + # enable mode and not config module + rc, out, err = connection.exec_command('prompt()') + if str(out).strip().endswith(')#'): + display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) + connection.exec_command('exit') if self._play_context.become_method == 'enable': self._play_context.become = False @@ -151,13 +141,6 @@ class ActionModule(ActionBase): return implementation_module - # this will be removed once the new connection work is done - def _get_socket_path(self, play_context): - ssh = connection_loader.get('ssh', class_only=True) - cp = ssh._create_control_path(play_context.remote_addr, play_context.port, play_context.remote_user) - path = unfrackpath("$HOME/.ansible/pc") - return cp % dict(directory=path) - def _load_provider(self, network_os): # we should be able to stream line this a bit by creating a common # provider argument spec in module_utils/network_common.py or another diff --git a/lib/ansible/plugins/action/nxos.py b/lib/ansible/plugins/action/nxos.py index 37a0b9eea1c..7b82f0ea07d 100644 --- a/lib/ansible/plugins/action/nxos.py +++ b/lib/ansible/plugins/action/nxos.py @@ -19,17 +19,13 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type -import os import sys import copy from ansible.plugins.action.normal import ActionModule as _ActionModule -from ansible.utils.path import unfrackpath -from ansible.plugins import connection_loader from ansible.module_utils.basic import AnsibleFallbackNotFound from ansible.module_utils.nxos import nxos_argument_spec from ansible.module_utils.six import iteritems -from ansible.module_utils._text import to_bytes try: from __main__ import display @@ -73,28 +69,23 @@ class ActionModule(_ActionModule): display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) - socket_path = self._get_socket_path(pc) + socket_path = connection.run() display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) - - if not os.path.exists(socket_path): - # start the connection if it isn't started - rc, out, err = connection.exec_command('open_shell()') - display.vvvv('open_shell() returned %s %s %s' % (rc, out, err)) - if rc != 0: - return {'failed': True, - 'msg': 'unable to open shell. Please see: ' + - 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell', - 'rc': rc} - else: - # make sure we are in the right cli context which should be - # enable mode and not config module + if not socket_path: + return {'failed': True, + 'msg': 'unable to open shell. Please see: ' + + 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} + + # make sure we are in the right cli context which should be + # enable mode and not config module + rc, out, err = connection.exec_command('prompt()') + while str(out).strip().endswith(')#'): + display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) + connection.exec_command('exit') rc, out, err = connection.exec_command('prompt()') - while str(out).strip().endswith(')#'): - display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) - connection.exec_command('exit') - rc, out, err = connection.exec_command('prompt()') task_vars['ansible_socket'] = socket_path + else: provider['transport'] = 'nxapi' if provider.get('host') is None: @@ -126,12 +117,6 @@ class ActionModule(_ActionModule): result = super(ActionModule, self).run(tmp, task_vars) return result - def _get_socket_path(self, play_context): - ssh = connection_loader.get('ssh', class_only=True) - cp = ssh._create_control_path(play_context.remote_addr, play_context.port, play_context.remote_user) - path = unfrackpath("$HOME/.ansible/pc") - return cp % dict(directory=path) - def load_provider(self): provider = self._task.args.get('provider', {}) for key, value in iteritems(nxos_argument_spec): diff --git a/lib/ansible/plugins/action/sros.py b/lib/ansible/plugins/action/sros.py index d510a773aff..8b1b45b5597 100644 --- a/lib/ansible/plugins/action/sros.py +++ b/lib/ansible/plugins/action/sros.py @@ -19,17 +19,13 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type -import os import sys import copy from ansible.plugins.action.normal import ActionModule as _ActionModule -from ansible.utils.path import unfrackpath -from ansible.plugins import connection_loader from ansible.module_utils.sros import sros_argument_spec from ansible.module_utils.basic import AnsibleFallbackNotFound from ansible.module_utils.six import iteritems -from ansible.module_utils._text import to_bytes try: from __main__ import display @@ -64,30 +60,18 @@ class ActionModule(_ActionModule): display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) - socket_path = self._get_socket_path(pc) + socket_path = connection.run() display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) - - if not os.path.exists(socket_path): - # start the connection if it isn't started - rc, out, err = connection.exec_command('open_shell()') - display.vvvv('open_shell() returned %s %s %s' % (rc, out, err)) - if not rc == 0: - return {'failed': True, - 'msg': 'unable to open shell. Please see: ' + - 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell', - 'rc': rc} + if not socket_path: + return {'failed': True, + 'msg': 'unable to open shell. Please see: ' + + 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} task_vars['ansible_socket'] = socket_path result = super(ActionModule, self).run(tmp, task_vars) return result - def _get_socket_path(self, play_context): - ssh = connection_loader.get('ssh', class_only=True) - cp = ssh._create_control_path(play_context.remote_addr, play_context.port, play_context.remote_user) - path = unfrackpath("$HOME/.ansible/pc") - return cp % dict(directory=path) - def load_provider(self): provider = self._task.args.get('provider', {}) for key, value in iteritems(sros_argument_spec): diff --git a/lib/ansible/plugins/action/vyos.py b/lib/ansible/plugins/action/vyos.py index f8668b89289..2af967dd6ef 100644 --- a/lib/ansible/plugins/action/vyos.py +++ b/lib/ansible/plugins/action/vyos.py @@ -19,16 +19,12 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type -import os import sys import copy from ansible.plugins.action.normal import ActionModule as _ActionModule -from ansible.utils.path import unfrackpath -from ansible.plugins import connection_loader from ansible.module_utils.basic import AnsibleFallbackNotFound from ansible.module_utils.six import iteritems -from ansible.module_utils._text import to_bytes from ansible.module_utils.vyos import vyos_argument_spec try: @@ -63,38 +59,26 @@ class ActionModule(_ActionModule): display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) - socket_path = self._get_socket_path(pc) + socket_path = connection.run() display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) - - if not os.path.exists(socket_path): - # start the connection if it isn't started - rc, out, err = connection.exec_command('open_shell()') - display.vvvv('open_shell() returned %s %s %s' % (rc, out, err)) - if not rc == 0: - return {'failed': True, - 'msg': 'unable to open shell. Please see: ' + - 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell', - 'rc': rc} - else: - # make sure we are in the right cli context which should be - # enable mode and not config module + if not socket_path: + return {'failed': True, + 'msg': 'unable to open shell. Please see: ' + + 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} + + # make sure we are in the right cli context which should be + # enable mode and not config module + rc, out, err = connection.exec_command('prompt()') + while str(out).strip().endswith('#'): + display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) + connection.exec_command('exit') rc, out, err = connection.exec_command('prompt()') - while str(out).strip().endswith('#'): - display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) - connection.exec_command('exit') - rc, out, err = connection.exec_command('prompt()') task_vars['ansible_socket'] = socket_path result = super(ActionModule, self).run(tmp, task_vars) return result - def _get_socket_path(self, play_context): - ssh = connection_loader.get('ssh', class_only=True) - cp = ssh._create_control_path(play_context.remote_addr, play_context.port, play_context.remote_user) - path = unfrackpath("$HOME/.ansible/pc") - return cp % dict(directory=path) - def load_provider(self): provider = self._task.args.get('provider', {}) for key, value in iteritems(vyos_argument_spec): diff --git a/lib/ansible/plugins/cliconf/__init__.py b/lib/ansible/plugins/cliconf/__init__.py new file mode 100644 index 00000000000..b7e3c2fd8a7 --- /dev/null +++ b/lib/ansible/plugins/cliconf/__init__.py @@ -0,0 +1,188 @@ +# +# (c) 2017 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . +# +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import signal + +from abc import ABCMeta, abstractmethod +from functools import wraps + +from ansible.errors import AnsibleError, AnsibleConnectionFailure +from ansible.module_utils.six import with_metaclass + +try: + from __main__ import display +except ImportError: + from ansible.utils.display import Display + display = Display() + + +def enable_mode(func): + @wraps(func) + def wrapped(self, *args, **kwargs): + prompt = self.get_prompt() + if not str(prompt).strip().endswith('#'): + raise AnsibleError('operation requires privilege escalation') + return func(self, *args, **kwargs) + return wrapped + + +class CliconfBase(with_metaclass(ABCMeta, object)): + """ + A base class for implementing cli connections + + .. note:: Unlike most of Ansible, nearly all strings in + :class:`CliconfBase` plugins are byte strings. This is because of + how close to the underlying platform these plugins operate. Remember + to mark literal strings as byte string (``b"string"``) and to use + :func:`~ansible.module_utils._text.to_bytes` and + :func:`~ansible.module_utils._text.to_text` to avoid unexpected + problems. + + List of supported rpc's: + :get_config: Retrieves the specified configuration from the device + :edit_config: Loads the specified commands into the remote device + :get: Execute specified command on remote device + :get_capabilities: Retrieves device information and supported rpc methods + :commit: Load configuration from candidate to running + :discard_changes: Discard changes to candidate datastore + + Note: List of supported rpc's for remote device can be extracted from + output of get_capabilities() + + :returns: Returns output received from remote device as byte string + + Usage: + from ansible.module_utils.connection import Connection + + conn = Connection() + conn.get('show lldp neighbors detail'') + conn.get_config('running') + conn.edit_config(['hostname test', 'netconf ssh']) + """ + + def __init__(self, connection): + self._connection = connection + + def _alarm_handler(self, signum, frame): + raise AnsibleConnectionFailure('timeout waiting for command to complete') + + def send_command(self, command, prompt=None, answer=None, sendonly=False): + """Executes a cli command and returns the results + This method will execute the CLI command on the connection and return + the results to the caller. The command output will be returned as a + string + """ + timeout = self._connection._play_context.timeout or 30 + signal.signal(signal.SIGALRM, self._alarm_handler) + signal.alarm(timeout) + display.display("command: %s" % command, log_only=True) + resp = self._connection.send(command, prompt, answer, sendonly) + signal.alarm(0) + return resp + + def get_prompt(self): + """Returns the current prompt from the device""" + return self._connection._matched_prompt + + def get_base_rpc(self): + """Returns list of base rpc method supported by remote device""" + return ['get_config', 'edit_config', 'get_capabilities', 'get'] + + @abstractmethod + def get_config(self, source='running', format='text'): + """Retrieves the specified configuration from the device + This method will retrieve the configuration specified by source and + return it to the caller as a string. Subsequent calls to this method + will retrieve a new configuration from the device + :args: + arg[0] source: Datastore from which configuration should be retrieved eg: running/candidate/startup. (optional) + default is running. + arg[1] format: Output format in which configuration is retrieved + Note: Specified datastore should be supported by remote device. + :kwargs: + Keywords supported + :command: the command string to execute + :source: Datastore from which configuration should be retrieved + :format: Output format in which configuration is retrieved + :returns: Returns output received from remote device as byte string + """ + pass + + @abstractmethod + def edit_config(self, commands): + """Loads the specified commands into the remote device + This method will load the commands into the remote device. This + method will make sure the device is in the proper context before + send the commands (eg config mode) + :args: + arg[0] command: List of configuration commands + :kwargs: + Keywords supported + :command: the command string to execute + :returns: Returns output received from remote device as byte string + """ + pass + + @abstractmethod + def get(self, *args, **kwargs): + """Execute specified command on remote device + This method will retrieve the specified data and + return it to the caller as a string. + :args: + arg[0] command: command in string format to be executed on remote device + arg[1] prompt: the expected prompt generated by executing command. + This can be a string or a list of strings (optional) + arg[2] answer: the string to respond to the prompt with (optional) + arg[3] sendonly: bool to disable waiting for response, default is false (optional) + :kwargs: + :command: the command string to execute + :prompt: the expected prompt generated by executing command. + This can be a string or a list of strings + :answer: the string to respond to the prompt with + :sendonly: bool to disable waiting for response + :returns: Returns output received from remote device as byte string + """ + pass + + @abstractmethod + def get_capabilities(self): + """Retrieves device information and supported + rpc methods by device platform and return result + as a string + :returns: Returns output received from remote device as byte string + """ + pass + + def commit(self, comment=None): + """Commit configuration changes""" + return self._connection.method_not_found("commit is not supported by network_os %s" % self._play_context.network_os) + + def discard_changes(self): + "Discard changes in candidate datastore" + return self._connection.method_not_found("discard_changes is not supported by network_os %s" % self._play_context.network_os) + + def put_file(self, source, destination): + """Copies file over scp to remote device""" + pass + + def fetch_file(self, source, destination): + """Fetch file over scp from remote device""" + pass diff --git a/lib/ansible/plugins/cliconf/eos.py b/lib/ansible/plugins/cliconf/eos.py new file mode 100644 index 00000000000..400d55fbbb3 --- /dev/null +++ b/lib/ansible/plugins/cliconf/eos.py @@ -0,0 +1,73 @@ +# +# (c) 2017 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . +# +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import json + +from itertools import chain + +from ansible.module_utils.network_common import to_list +from ansible.plugins.cliconf import CliconfBase, enable_mode + + +class Cliconf(CliconfBase): + + def get_device_info(self): + device_info = {} + + device_info['network_os'] = 'eos' + reply = self.get(b'show version | json') + data = json.loads(reply) + + device_info['network_os_version'] = data['version'] + device_info['network_os_model'] = data['modelName'] + + reply = self.get(b'show hostname | json') + data = json.loads(reply) + + device_info['network_os_hostname'] = data['hostname'] + + return device_info + + @enable_mode + def get_config(self, source='running', format='text'): + lookup = {'running': 'running-config', 'startup': 'startup-config'} + if source not in lookup: + return self.invalid_params("fetching configuration from %s is not supported" % source) + if format == 'text': + cmd = b'show %s' % lookup[source] + else: + cmd = b'show %s | %s' % (lookup[source], format) + return self.send_command(cmd) + + @enable_mode + def edit_config(self, command): + for cmd in chain([b'configure'], to_list(command), [b'end']): + self.send_command(cmd) + + def get(self, *args, **kwargs): + return self.send_command(*args, **kwargs) + + def get_capabilities(self): + result = {} + result['rpc'] = self.get_base_rpc() + result['network_api'] = 'cliconf' + result['device_info'] = self.get_device_info() + return json.dumps(result) diff --git a/lib/ansible/plugins/cliconf/ios.py b/lib/ansible/plugins/cliconf/ios.py new file mode 100644 index 00000000000..68a3ce413d1 --- /dev/null +++ b/lib/ansible/plugins/cliconf/ios.py @@ -0,0 +1,78 @@ +# +# (c) 2017 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . +# +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import re +import json + +from itertools import chain + +from ansible.module_utils._text import to_bytes, to_text +from ansible.module_utils.network_common import to_list +from ansible.plugins.cliconf import CliconfBase, enable_mode + + +class Cliconf(CliconfBase): + + def get_device_info(self): + device_info = {} + + device_info['network_os'] = 'ios' + reply = self.get(b'show version') + data = to_text(reply, errors='surrogate_or_strict').strip() + + match = re.search(r'Version (\S+),', data) + if match: + device_info['network_os_version'] = match.group(1) + + match = re.search(r'^Cisco (.+) \(revision', data, re.M) + if match: + device_info['network_os_model'] = match.group(1) + + match = re.search(r'^(.+) uptime', data, re.M) + if match: + device_info['network_os_hostname'] = match.group(1) + + return device_info + + @enable_mode + def get_config(self, source='running'): + if source not in ('running', 'startup'): + return self.invalid_params("fetching configuration from %s is not supported" % source) + if source == 'running': + cmd = b'show running-config all' + else: + cmd = b'show startup-config' + return self.send_command(cmd) + + @enable_mode + def edit_config(self, command): + for cmd in chain([b'configure terminal'], to_list(command), [b'end']): + self.send_command(cmd) + + def get(self, *args, **kwargs): + return self.send_command(*args, **kwargs) + + def get_capabilities(self): + result = {} + result['rpc'] = self.get_base_rpc() + result['network_api'] = 'cliconf' + result['device_info'] = self.get_device_info() + return json.dumps(result) diff --git a/lib/ansible/plugins/cliconf/iosxr.py b/lib/ansible/plugins/cliconf/iosxr.py new file mode 100644 index 00000000000..fe890c22cc9 --- /dev/null +++ b/lib/ansible/plugins/cliconf/iosxr.py @@ -0,0 +1,87 @@ +# +# (c) 2017 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . +# +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import re +import json + +from itertools import chain + +from ansible.module_utils._text import to_bytes, to_text +from ansible.module_utils.network_common import to_list +from ansible.plugins.cliconf import CliconfBase + + +class Cliconf(CliconfBase): + + def get_device_info(self): + device_info = {} + + device_info['network_os'] = 'iosxr' + reply = self.get(b'show version brief') + data = to_text(reply, errors='surrogate_or_strict').strip() + + match = re.search(r'Version (\S+)$', data, re.M) + if match: + device_info['network_os_version'] = match.group(1) + + match = re.search(r'image file is "(.+)"', data) + if match: + device_info['network_os_image'] = match.group(1) + + match = re.search(r'^Cisco (.+) \(revision', data, re.M) + if match: + device_info['network_os_model'] = match.group(1) + + match = re.search(r'^(.+) uptime', data, re.M) + if match: + device_info['network_os_hostname'] = match.group(1) + + return device_info + + def get_config(self, source='running'): + lookup = {'running': 'running-config'} + if source not in lookup: + return self.invalid_params("fetching configuration from %s is not supported" % source) + return self.send_command(to_bytes(b'show %s' % lookup[source], errors='surrogate_or_strict')) + + def edit_config(self, command): + for cmd in chain([b'configure'], to_list(command), [b'end']): + self.send_command(cmd) + + def get(self, *args, **kwargs): + return self.send_command(*args, **kwargs) + + def commit(self, comment=None): + if comment: + command = b'commit comment {0}'.format(comment) + else: + command = b'commit' + self.send_command(command) + + def discard_changes(self): + self.send_command(b'abort') + + def get_capabilities(self): + result = {} + result['rpc'] = self.get_base_rpc() + ['commit', 'discard_changes'] + result['network_api'] = 'cliconf' + result['device_info'] = self.get_device_info() + return json.dumps(result) diff --git a/lib/ansible/plugins/cliconf/junos.py b/lib/ansible/plugins/cliconf/junos.py new file mode 100644 index 00000000000..a3cbf4c2dc2 --- /dev/null +++ b/lib/ansible/plugins/cliconf/junos.py @@ -0,0 +1,87 @@ +# +# (c) 2017 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . +# +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import re +import json + +from itertools import chain +from xml.etree.ElementTree import fromstring + +from ansible.module_utils._text import to_bytes, to_text +from ansible.module_utils.network_common import to_list +from ansible.plugins.cliconf import CliconfBase, enable_mode + + +class Cliconf(CliconfBase): + + def get_text(self, ele, tag): + try: + return to_text(ele.find(tag).text, errors='surrogate_then_replace').strip() + except AttributeError: + pass + + def get_device_info(self): + device_info = {} + + device_info['network_os'] = 'junos' + reply = self.get(b'show version | display xml') + data = fromstring(to_text(reply, errors='surrogate_then_replace').strip()) + + sw_info = data.find('.//software-information') + + device_info['network_os_version'] = self.get_text(sw_info, 'junos-version') + device_info['network_os_hostname'] = self.get_text(sw_info, 'host-name') + device_info['network_os_model'] = self.get_text(sw_info, 'product-model') + + return device_info + + def get_config(self, source='running', format='text'): + if source != 'running': + return self.invalid_params("fetching configuration from %s is not supported" % source) + if format == 'text': + cmd = b'show configuration' + else: + cmd = b'show configuration | display %s' % format + return self.send_command(to_bytes(cmd), errors='surrogate_or_strict') + + def edit_config(self, command): + for cmd in chain([b'configure'], to_list(command)): + self.send_command(cmd) + + def get(self, *args, **kwargs): + return self.send_command(*args, **kwargs) + + def commit(self, comment=None): + if comment: + command = b'commit comment {0}'.format(comment) + else: + command = b'commit' + self.send_command(command) + + def discard_changes(self): + self.send_command(b'rollback') + + def get_capabilities(self): + result = {} + result['rpc'] = self.get_base_rpc() + ['commit', 'discard_changes'] + result['network_api'] = 'cliconf' + result['device_info'] = self.get_device_info() + return json.dumps(result) diff --git a/lib/ansible/plugins/cliconf/nxos.py b/lib/ansible/plugins/cliconf/nxos.py new file mode 100644 index 00000000000..2f58aec7b5b --- /dev/null +++ b/lib/ansible/plugins/cliconf/nxos.py @@ -0,0 +1,62 @@ +# +# (c) 2017 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . +# +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import json + +from itertools import chain + +from ansible.module_utils.network_common import to_list +from ansible.plugins.cliconf import CliconfBase + + +class Cliconf(CliconfBase): + + def get_device_info(self): + device_info = {} + + device_info['network_os'] = 'nxos' + reply = self.get(b'show version | json') + data = json.loads(reply) + + device_info['network_os_version'] = data['sys_ver_str'] + device_info['network_os_model'] = data['chassis_id'] + device_info['network_os_hostname'] = data['host_name'] + device_info['network_os_image'] = data['isan_file_name'] + + return device_info + + def get_config(self, source='running'): + lookup = {'running': 'running-config', 'startup': 'startup-config'} + return self.send_command(b'show %s' % lookup[source]) + + def edit_config(self, command): + for cmd in chain([b'configure'], to_list(command), [b'end']): + self.send_command(cmd) + + def get(self, *args, **kwargs): + return self.send_command(*args, **kwargs) + + def get_capabilities(self): + result = {} + result['rpc'] = self.get_base_rpc() + result['network_api'] = 'cliconf' + result['device_info'] = self.get_device_info() + return json.dumps(result) diff --git a/lib/ansible/plugins/cliconf/vyos.py b/lib/ansible/plugins/cliconf/vyos.py new file mode 100644 index 00000000000..c8e5d819864 --- /dev/null +++ b/lib/ansible/plugins/cliconf/vyos.py @@ -0,0 +1,79 @@ +# +# (c) 2017 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . +# +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import re +import json + +from itertools import chain + +from ansible.module_utils._text import to_bytes, to_text +from ansible.module_utils.network_common import to_list +from ansible.plugins.cliconf import CliconfBase, enable_mode + + +class Cliconf(CliconfBase): + + def get_device_info(self): + device_info = {} + + device_info['network_os'] = 'vyos' + reply = self.get(b'show version') + data = to_text(reply, errors='surrogate_or_strict').strip() + + match = re.search(r'Version:\s*(\S+)', data) + if match: + device_info['network_os_version'] = match.group(1) + + match = re.search(r'HW model:\s*(\S+)', data) + if match: + device_info['network_os_model'] = match.group(1) + + reply = self.get(b'show host name') + device_info['network_os_hostname'] = to_text(reply, errors='surrogate_or_strict').strip() + + return device_info + + def get_config(self): + return self.send_command(b'show configuration all') + + def edit_config(self, command): + for cmd in chain([b'configure'], to_list(command)): + self.send_command(cmd) + + def get(self, *args, **kwargs): + return self.send_command(*args, **kwargs) + + def commit(self, comment=None): + if comment: + command = b'commit comment {0}'.format(comment) + else: + command = b'commit' + self.send_command(command) + + def discard_changes(self, *args, **kwargs): + self.send_command(b'discard') + + def get_capabilities(self): + result = {} + result['rpc'] = self.get_base_rpc() + ['commit', 'discard_changes'] + result['network_api'] = 'cliconf' + result['device_info'] = self.get_device_info() + return json.dumps(result) diff --git a/lib/ansible/plugins/connection/netconf.py b/lib/ansible/plugins/connection/netconf.py index 8143bea696a..0c9f4259c87 100644 --- a/lib/ansible/plugins/connection/netconf.py +++ b/lib/ansible/plugins/connection/netconf.py @@ -20,10 +20,14 @@ __metaclass__ = type import os import logging +import json from ansible import constants as C from ansible.errors import AnsibleConnectionFailure, AnsibleError +from ansible.module_utils._text import to_bytes, to_native, to_text +from ansible.plugins import netconf_loader from ansible.plugins.connection import ConnectionBase, ensure_connect +from ansible.utils.jsonrpc import Rpc try: from ncclient import manager @@ -42,8 +46,8 @@ except ImportError: logging.getLogger('ncclient').setLevel(logging.INFO) -class Connection(ConnectionBase): - ''' NetConf connections ''' +class Connection(Rpc, ConnectionBase): + """NetConf connections""" transport = 'netconf' has_pipelining = False @@ -90,12 +94,20 @@ class Connection(ConnectionBase): raise AnsibleConnectionFailure(str(exc)) if not self._manager.connected: - return (1, '', 'not connected') + return 1, b'', b'not connected' display.display('ncclient manager object created successfully', log_only=True) self._connected = True - return (0, self._manager.session_id, '') + + self._netconf = netconf_loader.get(self._network_os, self) + if self._netconf: + self._rpc.add(self._netconf) + display.display('loaded netconf plugin for network_os %s' % self._network_os, log_only=True) + else: + display.display('unable to load netconf for network_os %s' % self._network_os) + + return 0, to_bytes(self._manager.session_id, errors='surrogate_or_strict'), b'' def close(self): if self._manager: @@ -106,20 +118,37 @@ class Connection(ConnectionBase): @ensure_connect def exec_command(self, request): """Sends the request to the node and returns the reply + The method accepts two forms of request. The first form is as a byte + string that represents xml string be send over netconf session. + The second form is a json-rpc (2.0) byte string. """ - if request == 'open_session()': - return (0, 'ok', '') + try: + obj = json.loads(to_text(request, errors='surrogate_or_strict')) + + if 'jsonrpc' in obj: + if self._netconf: + out = self._exec_rpc(obj) + else: + out = self.internal_error("netconf plugin is not supported for network_os %s" % self._play_context.network_os) + return 0, to_bytes(out, errors='surrogate_or_strict'), b'' + else: + err = self.invalid_request(obj) + return 1, b'', to_bytes(err, errors='surrogate_or_strict') + + except (ValueError, TypeError): + # to_ele operates on native strings + request = to_native(request, errors='surrogate_or_strict') req = to_ele(request) if req is None: - return (1, '', 'unable to parse request') + return 1, b'', b'unable to parse request' try: reply = self._manager.rpc(req) except RPCError as exc: - return (1, '', to_xml(exc.xml)) + return 1, b'', to_bytes(to_xml(exc.xml), errors='surrogate_or_strict') - return (0, reply.data_xml, '') + return 0, to_bytes(reply.data_xml, errors='surrogate_or_strict'), b'' def put_file(self, in_path, out_path): """Transfer a file from local to remote""" diff --git a/lib/ansible/plugins/connection/network_cli.py b/lib/ansible/plugins/connection/network_cli.py index dd2b1800135..6814272fc2f 100644 --- a/lib/ansible/plugins/connection/network_cli.py +++ b/lib/ansible/plugins/connection/network_cli.py @@ -24,15 +24,17 @@ import re import signal import socket import traceback + from collections import Sequence from ansible import constants as C from ansible.errors import AnsibleConnectionFailure from ansible.module_utils.six import BytesIO, binary_type from ansible.module_utils._text import to_bytes, to_text +from ansible.plugins import cliconf_loader from ansible.plugins import terminal_loader -from ansible.plugins.connection import ensure_connect from ansible.plugins.connection.paramiko_ssh import Connection as _Connection +from ansible.utils.jsonrpc import Rpc try: from __main__ import display @@ -41,7 +43,7 @@ except ImportError: display = Display() -class Connection(_Connection): +class Connection(Rpc, _Connection): ''' CLI (shell) SSH connections on Paramiko ''' transport = 'network_cli' @@ -51,11 +53,13 @@ class Connection(_Connection): super(Connection, self).__init__(play_context, new_stdin, *args, **kwargs) self._terminal = None + self._cliconf = None self._shell = None self._matched_prompt = None self._matched_pattern = None self._last_response = None self._history = list() + self._play_context = play_context if play_context.verbosity > 3: logging.getLogger('paramiko').setLevel(logging.DEBUG) @@ -84,6 +88,9 @@ class Connection(_Connection): display.display('ssh connection done, setting terminal', log_only=True) + self._shell = self.ssh.invoke_shell() + self._shell.settimeout(self._play_context.timeout) + network_os = self._play_context.network_os if not network_os: raise AnsibleConnectionFailure( @@ -95,46 +102,45 @@ class Connection(_Connection): if not self._terminal: raise AnsibleConnectionFailure('network os %s is not supported' % network_os) - self._connected = True - display.display('ssh connection has completed successfully', log_only=True) + display.display('loaded terminal plugin for network_os %s' % network_os, log_only=True) - @ensure_connect - def open_shell(self): - display.display('attempting to open shell to device', log_only=True) - self._shell = self.ssh.invoke_shell() - self._shell.settimeout(self._play_context.timeout) + self._cliconf = cliconf_loader.get(network_os, self) + if self._cliconf: + self._rpc.add(self._cliconf) + display.display('loaded cliconf plugin for network_os %s' % network_os, log_only=True) + else: + display.display('unable to load cliconf for network_os %s' % network_os) self.receive() - if self._shell: - self._terminal.on_open_shell() + display.display('firing event: on_open_shell()', log_only=True) + self._terminal.on_open_shell() if getattr(self._play_context, 'become', None): + display.display('firing event: on_authorize', log_only=True) auth_pass = self._play_context.become_pass self._terminal.on_authorize(passwd=auth_pass) - display.display('shell successfully opened', log_only=True) - return (0, b'ok', b'') + self._connected = True + display.display('ssh connection has completed successfully', log_only=True) def close(self): - display.display('closing connection', log_only=True) - self.close_shell() - super(Connection, self).close() - self._connected = False - - def close_shell(self): - """Closes the vty shell if the device supports multiplexing""" - display.display('closing shell on device', log_only=True) + """Close the active connection to the device + """ + display.display("closing ssh connection to device", log_only=True) if self._shell: + display.display("firing event: on_close_shell()", log_only=True) self._terminal.on_close_shell() - - if self._shell: self._shell.close() self._shell = None + display.display("cli session is now closed", log_only=True) - return (0, b'ok', b'') + super(Connection, self).close() - def receive(self, obj=None): + self._connected = False + display.display("ssh connection has been closed successfully", log_only=True) + + def receive(self, command=None, prompts=None, answer=None): """Handles receiving of output from command""" recv = BytesIO() handled = False @@ -150,23 +156,22 @@ class Connection(_Connection): window = self._strip(recv.read()) - if obj and (obj.get('prompt') and not handled): - handled = self._handle_prompt(window, obj['prompt'], obj['answer']) + if prompts and not handled: + handled = self._handle_prompt(window, prompts, answer) if self._find_prompt(window): self._last_response = recv.getvalue() resp = self._strip(self._last_response) - return self._sanitize(resp, obj) + return self._sanitize(resp, command) - def send(self, obj): + def send(self, command, prompts=None, answer=None, send_only=False): """Sends the command to the device in the opened shell""" try: - command = obj['command'] self._history.append(command) self._shell.sendall(b'%s\r' % command) - if obj.get('sendonly'): + if send_only: return - return self.receive(obj) + return self.receive(command, prompts, answer) except (socket.timeout, AttributeError): display.display(traceback.format_exc(), log_only=True) raise AnsibleConnectionFailure("timeout trying to send command: %s" % command.strip()) @@ -195,10 +200,9 @@ class Connection(_Connection): return True return False - def _sanitize(self, resp, obj=None): + def _sanitize(self, resp, command=None): """Removes elements from the response before returning to the caller""" cleaned = [] - command = obj.get('command') if obj else None for line in resp.splitlines(): if (command and line.startswith(command.strip())) or self._matched_prompt.strip() in line: continue @@ -243,10 +247,10 @@ class Connection(_Connection): def exec_command(self, cmd): """Executes the cmd on in the shell and returns the output - The method accepts two forms of cmd. The first form is as a byte + The method accepts three forms of cmd. The first form is as a byte string that represents the command to be executed in the shell. The second form is as a utf8 JSON byte string with additional keywords. - + The third form is a json-rpc (2.0) Keywords supported for cmd: :command: the command string to execute :prompt: the expected prompt generated by executing command. @@ -275,27 +279,23 @@ class Connection(_Connection): else: # Prompt was a Sequence of strings. Make sure they're byte strings obj['prompt'] = [to_bytes(p, errors='surrogate_or_strict') for p in obj['prompt'] if p is not None] - if obj['command'] == b'close_shell()': - return self.close_shell() - elif obj['command'] == b'open_shell()': - return self.open_shell() - elif obj['command'] == b'prompt()': - return (0, self._matched_prompt, b'') - try: - if self._shell is None: - self.open_shell() - except AnsibleConnectionFailure as exc: - # FIXME: Feels like we should raise this rather than return it - return (1, b'', to_bytes(exc)) + if 'jsonrpc' in obj: + if self._cliconf: + out = self._exec_rpc(obj) + else: + out = self.internal_error("cliconf is not supported for network_os %s" % self._play_context.network_os) + return 0, to_bytes(out, errors='surrogate_or_strict'), b'' + + if obj['command'] == b'prompt()': + return 0, self._matched_prompt, b'' try: if not signal.getsignal(signal.SIGALRM): signal.signal(signal.SIGALRM, self.alarm_handler) signal.alarm(self._play_context.timeout) - out = self.send(obj) + out = self.send(obj['command'], obj.get('prompt'), obj.get('answer'), obj.get('sendonly')) signal.alarm(0) - return (0, out, b'') + return 0, out, b'' except (AnsibleConnectionFailure, ValueError) as exc: - # FIXME: Feels like we should raise this rather than return it - return (1, b'', to_bytes(exc)) + return 1, b'', to_bytes(exc) diff --git a/lib/ansible/plugins/connection/persistent.py b/lib/ansible/plugins/connection/persistent.py index fc210a9766d..6b3adff3c17 100644 --- a/lib/ansible/plugins/connection/persistent.py +++ b/lib/ansible/plugins/connection/persistent.py @@ -1,4 +1,4 @@ -# (c) 2016 Red Hat Inc. +# (c) 2017 Red Hat Inc. # # This file is part of Ansible # @@ -41,7 +41,6 @@ class Connection(ConnectionBase): has_pipelining = False def _connect(self): - self._connected = True return self @@ -83,3 +82,7 @@ class Connection(ConnectionBase): def close(self): self._connected = False + + def run(self): + rc, out, err = self._do_it('RUN:') + return out diff --git a/lib/ansible/plugins/netconf/__init__.py b/lib/ansible/plugins/netconf/__init__.py new file mode 100644 index 00000000000..ce7680f44a0 --- /dev/null +++ b/lib/ansible/plugins/netconf/__init__.py @@ -0,0 +1,189 @@ +# +# (c) 2017 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . +# +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from abc import ABCMeta, abstractmethod +from functools import wraps + +from ansible.module_utils.six import with_metaclass + + +def ensure_connected(func): + @wraps(func) + def wrapped(self, *args, **kwargs): + if not self._connection._connected: + self._connection._connect() + return func(self, *args, **kwargs) + return wrapped + + +class NetconfBase(with_metaclass(ABCMeta, object)): + """ + A base class for implementing Netconf connections + + .. note:: Unlike most of Ansible, nearly all strings in + :class:`TerminalBase` plugins are byte strings. This is because of + how close to the underlying platform these plugins operate. Remember + to mark literal strings as byte string (``b"string"``) and to use + :func:`~ansible.module_utils._text.to_bytes` and + :func:`~ansible.module_utils._text.to_text` to avoid unexpected + problems. + + List of supported rpc's: + :get_config: Retrieves the specified configuration from the device + :edit_config: Loads the specified commands into the remote device + :get: Execute specified command on remote device + :get_capabilities: Retrieves device information and supported rpc methods + :commit: Load configuration from candidate to running + :discard_changes: Discard changes to candidate datastore + :validate: Validate the contents of the specified configuration. + :lock: Allows the client to lock the configuration system of a device. + :unlock: Release a configuration lock, previously obtained with the lock operation. + :copy_config: create or replace an entire configuration datastore with the contents of another complete + configuration datastore. + For JUNOS: + :execute_rpc: RPC to be execute on remote device + :load_configuration: Loads given configuration on device + + Note: rpc support depends on the capabilites of remote device. + + :returns: Returns output received from remote device as byte string + Note: the 'result' or 'error' from response should to be converted to object + of ElementTree using 'fromstring' to parse output as xml doc + + 'get_capabilities()' returns 'result' as a json string. + + Usage: + from ansible.module_utils.connection import Connection + + conn = Connection() + data = conn.execute_rpc(rpc) + reply = fromstring(reply) + + data = conn.get_capabilities() + json.loads(data) + + conn.load_configuration(config=[''set system ntp server 1.1.1.1''], action='set', format='text') + """ + + def __init__(self, connection): + self._connection = connection + self.m = self._connection._manager + + @ensure_connected + def get_config(self, *args, **kwargs): + """Retrieve all or part of a specified configuration. + :source: name of the configuration datastore being queried + :filter: specifies the portion of the configuration to retrieve + (by default entire configuration is retrieved)""" + return self.m.get_config(*args, **kwargs).data_xml + + @ensure_connected + def get(self, *args, **kwargs): + """Retrieve running configuration and device state information. + *filter* specifies the portion of the configuration to retrieve + (by default entire configuration is retrieved) + """ + return self.m.get(*args, **kwargs).data_xml + + @ensure_connected + def edit_config(self, *args, **kwargs): + """Loads all or part of the specified *config* to the *target* configuration datastore. + + :target: is the name of the configuration datastore being edited + :config: is the configuration, which must be rooted in the `config` element. + It can be specified either as a string or an :class:`~xml.etree.ElementTree.Element`. + :default_operation: if specified must be one of { `"merge"`, `"replace"`, or `"none"` } + :test_option: if specified must be one of { `"test_then_set"`, `"set"` } + :error_option: if specified must be one of { `"stop-on-error"`, `"continue-on-error"`, `"rollback-on-error"` } + The `"rollback-on-error"` *error_option* depends on the `:rollback-on-error` capability. + """ + return self.m.get_config(*args, **kwargs).data_xml + + @ensure_connected + def validate(self, *args, **kwargs): + """Validate the contents of the specified configuration. + :source: is the name of the configuration datastore being validated or `config` + element containing the configuration subtree to be validated + """ + return self.m.validate(*args, **kwargs).data_xml + + @ensure_connected + def copy_config(self, *args, **kwargs): + """Create or replace an entire configuration datastore with the contents of another complete + configuration datastore. + :source: is the name of the configuration datastore to use as the source of the + copy operation or `config` element containing the configuration subtree to copy + :target: is the name of the configuration datastore to use as the destination of the copy operation""" + return self.m.copy_config(*args, **kwargs).data_xml + + @ensure_connected + def lock(self, *args, **kwargs): + """Allows the client to lock the configuration system of a device. + *target* is the name of the configuration datastore to lock + """ + return self.m.lock(*args, **kwargs).data_xml + + @ensure_connected + def unlock(self, *args, **kwargs): + """Release a configuration lock, previously obtained with the lock operation. + :target: is the name of the configuration datastore to unlock + """ + return self.m.lock(*args, **kwargs).data_xml + + @ensure_connected + def discard_changes(self, *args, **kwargs): + """Revert the candidate configuration to the currently running configuration. + Any uncommitted changes are discarded.""" + return self.m.discard_changes(*args, **kwargs).data_xml + + @ensure_connected + def commit(self, *args, **kwargs): + """Commit the candidate configuration as the device's new current configuration. + Depends on the `:candidate` capability. + A confirmed commit (i.e. if *confirmed* is `True`) is reverted if there is no + followup commit within the *timeout* interval. If no timeout is specified the + confirm timeout defaults to 600 seconds (10 minutes). + A confirming commit may have the *confirmed* parameter but this is not required. + Depends on the `:confirmed-commit` capability. + :confirmed: whether this is a confirmed commit + :timeout: specifies the confirm timeout in seconds + """ + return self.m.commit(*args, **kwargs).data_xml + + @abstractmethod + def get_capabilities(self, commands): + """Retrieves device information and supported + rpc methods by device platform and return result + as a string + """ + pass + + def get_base_rpc(self): + """Returns list of base rpc method supported by remote device""" + return ['get_config', 'edit_config', 'get_capabilities', 'get'] + + def put_file(self, source, destination): + """Copies file over scp to remote device""" + pass + + def fetch_file(self, source, destination): + """Fetch file over scp from remote device""" + pass diff --git a/lib/ansible/plugins/netconf/junos.py b/lib/ansible/plugins/netconf/junos.py new file mode 100644 index 00000000000..54c020c6cc9 --- /dev/null +++ b/lib/ansible/plugins/netconf/junos.py @@ -0,0 +1,79 @@ +# +# (c) 2017 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . +# +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import json + +from xml.etree.ElementTree import fromstring + +from ansible.module_utils._text import to_bytes, to_text +from ansible.plugins.netconf import NetconfBase +from ansible.plugins.netconf import ensure_connected + +from ncclient.xml_ import new_ele + + +class Netconf(NetconfBase): + + def get_text(self, ele, tag): + try: + return to_text(ele.find(tag).text, errors='surrogate_then_replace').strip() + except AttributeError: + pass + + def get_device_info(self): + device_info = {} + + device_info['network_os'] = 'junos' + data = self.execute_rpc('get-software-information') + reply = fromstring(data) + sw_info = reply.find('.//software-information') + + device_info['network_os_version'] = self.get_text(sw_info, 'junos-version') + device_info['network_os_hostname'] = self.get_text(sw_info, 'host-name') + device_info['network_os_model'] = self.get_text(sw_info, 'product-model') + + return device_info + + @ensure_connected + def execute_rpc(self, rpc): + """RPC to be execute on remote device + :rpc: Name of rpc in string format""" + name = new_ele(rpc) + return self.m.rpc(name).data_xml + + @ensure_connected + def load_configuration(self, *args, **kwargs): + """Loads given configuration on device + :format: Format of configuration (xml, text, set) + :action: Action to be performed (merge, replace, override, update) + :target: is the name of the configuration datastore being edited + :config: is the configuration in string format.""" + return self.m.load_configuration(*args, **kwargs).data_xml + + def get_capabilities(self): + result = {} + result['rpc'] = self.get_base_rpc() + ['commit', 'discard_changes', 'validate', 'lock', 'unlock', 'copy_copy'] + result['network_api'] = 'netconf' + result['device_info'] = self.get_device_info() + result['server_capabilities'] = [c for c in self.m.server_capabilities] + result['client_capabilities'] = [c for c in self.m.client_capabilities] + result['session_id'] = self.m.session_id + return json.dumps(result) diff --git a/lib/ansible/utils/jsonrpc.py b/lib/ansible/utils/jsonrpc.py new file mode 100644 index 00000000000..70c35427a4f --- /dev/null +++ b/lib/ansible/utils/jsonrpc.py @@ -0,0 +1,115 @@ +# +# (c) 2016 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import json +import traceback + +from ansible.module_utils._text import to_text + +try: + from __main__ import display +except ImportError: + from ansible.utils.display import Display + display = Display() + + +class Rpc: + + def __init__(self, *args, **kwargs): + self._rpc = set() + super(Rpc, self).__init__(*args, **kwargs) + + def _exec_rpc(self, request): + method = request.get('method') + + if method.startswith('rpc.') or method.startswith('_'): + error = self.invalid_request() + return json.dumps(error) + + params = request.get('params') + setattr(self, '_identifier', request.get('id')) + args = [] + kwargs = {} + + if all((params, isinstance(params, list))): + args = params + elif all((params, isinstance(params, dict))): + kwargs = params + + rpc_method = None + for obj in self._rpc: + rpc_method = getattr(obj, method, None) + if rpc_method: + break + + if not rpc_method: + error = self.method_not_found() + response = json.dumps(error) + else: + try: + result = rpc_method(*args, **kwargs) + display.display(" -- result -- %s" % result, log_only=True) + except Exception as exc: + display.display(traceback.format_exc(), log_only=True) + error = self.internal_error(data=to_text(exc, errors='surrogate_then_replace')) + response = json.dumps(error) + else: + if isinstance(result, dict) and 'jsonrpc' in result: + response = result + else: + response = self.response(result) + + response = json.dumps(response) + + display.display(" -- response -- %s" % response, log_only=True) + delattr(self, '_identifier') + return response + + def header(self): + return {'jsonrpc': '2.0', 'id': self._identifier} + + def response(self, result=None): + response = self.header() + response['result'] = result or 'ok' + return response + + def error(self, code, message, data=None): + response = self.header() + error = {'code': code, 'message': message} + if data: + error['data'] = data + response['error'] = error + return response + + # json-rpc standard errors (-32768 .. -32000) + def parse_error(self, data=None): + return self.error(-32700, 'Parse error', data) + + def method_not_found(self, data=None): + return self.error(-32601, 'Method not found', data) + + def invalid_request(self, data=None): + return self.error(-32600, 'Invalid request', data) + + def invalid_params(self, data=None): + return self.error(-32602, 'Invalid params', data) + + def internal_error(self, data=None): + return self.error(-32603, 'Internal error', data) diff --git a/test/units/plugins/connection/test_netconf.py b/test/units/plugins/connection/test_netconf.py index 64d7554ea22..4e924bd92f6 100644 --- a/test/units/plugins/connection/test_netconf.py +++ b/test/units/plugins/connection/test_netconf.py @@ -76,8 +76,8 @@ class TestNetconfConnectionClass(unittest.TestCase): rc, out, err = conn._connect() self.assertEqual(0, rc) - self.assertEqual('123456789', out) - self.assertEqual('', err) + self.assertEqual(b'123456789', out) + self.assertEqual(b'', err) self.assertTrue(conn._connected) def test_netconf_exec_command(self): @@ -101,8 +101,8 @@ class TestNetconfConnectionClass(unittest.TestCase): netconf.to_ele.assert_called_with('') self.assertEqual(0, rc) - self.assertEqual('', out) - self.assertEqual('', err) + self.assertEqual(b'', out) + self.assertEqual(b'', err) def test_netconf_exec_command_invalid_request(self): pc = PlayContext() @@ -116,5 +116,5 @@ class TestNetconfConnectionClass(unittest.TestCase): rc, out, err = conn.exec_command('test string') self.assertEqual(1, rc) - self.assertEqual('', out) - self.assertEqual('unable to parse request', err) + self.assertEqual(b'', out) + self.assertEqual(b'unable to parse request', err) diff --git a/test/units/plugins/connection/test_network_cli.py b/test/units/plugins/connection/test_network_cli.py index 0f5a8efe1f1..cb93aceb9a0 100644 --- a/test/units/plugins/connection/test_network_cli.py +++ b/test/units/plugins/connection/test_network_cli.py @@ -35,34 +35,38 @@ from ansible.plugins.connection import network_cli class TestConnectionClass(unittest.TestCase): - @patch("ansible.plugins.connection.network_cli.terminal_loader") @patch("ansible.plugins.connection.network_cli._Connection._connect") - def test_network_cli__connect(self, mocked_super, mocked_terminal_loader): + def test_network_cli__connect_error(self, mocked_super): pc = PlayContext() new_stdin = StringIO() conn = network_cli.Connection(pc, new_stdin) - conn.ssh = None - + conn.ssh = MagicMock() + conn.receive = MagicMock() + conn._terminal = MagicMock() + pc.network_os = None self.assertRaises(AnsibleConnectionFailure, conn._connect) - mocked_terminal_loader.reset_mock() - mocked_terminal_loader.get.return_value = None + @patch("ansible.plugins.connection.network_cli._Connection._connect") + def test_network_cli__invalid_os(self, mocked_super): + pc = PlayContext() + new_stdin = StringIO() - pc.network_os = 'invalid' + conn = network_cli.Connection(pc, new_stdin) + conn.ssh = MagicMock() + conn.receive = MagicMock() + conn._terminal = MagicMock() + pc.network_os = None self.assertRaises(AnsibleConnectionFailure, conn._connect) - self.assertFalse(mocked_terminal_loader.all.called) - - mocked_terminal_loader.reset_mock() - mocked_terminal_loader.get.return_value = 'valid' - - conn._connect() - self.assertEqual(conn._terminal, 'valid') - def test_network_cli_open_shell(self): + @patch("ansible.plugins.connection.network_cli.terminal_loader") + @patch("ansible.plugins.connection.network_cli._Connection._connect") + def test_network_cli__connect(self, mocked_super, mocked_terminal_loader): pc = PlayContext() new_stdin = StringIO() + conn = network_cli.Connection(pc, new_stdin) + pc.network_os = 'ios' conn.ssh = MagicMock() conn.receive = MagicMock() @@ -70,26 +74,19 @@ class TestConnectionClass(unittest.TestCase): mock_terminal = MagicMock() conn._terminal = mock_terminal - mock__connect = MagicMock() - conn._connect = mock__connect - - conn.open_shell() - - self.assertTrue(mock__connect.called) - self.assertTrue(mock_terminal.on_open_shell.called) - self.assertFalse(mock_terminal.on_authorize.called) - - mock_terminal.reset_mock() + conn._connect() + self.assertTrue(conn._terminal.on_open_shell.called) + self.assertFalse(conn._terminal.on_authorize.called) conn._play_context.become = True conn._play_context.become_pass = 'password' - conn.open_shell() + conn._connect() - self.assertTrue(mock__connect.called) - mock_terminal.on_authorize.assert_called_with(passwd='password') + conn._terminal.on_authorize.assert_called_with(passwd='password') - def test_network_cli_close_shell(self): + @patch("ansible.plugins.connection.network_cli._Connection.close") + def test_network_cli_close(self, mocked_super): pc = PlayContext() new_stdin = StringIO() conn = network_cli.Connection(pc, new_stdin) @@ -97,51 +94,43 @@ class TestConnectionClass(unittest.TestCase): terminal = MagicMock(supports_multiplexing=False) conn._terminal = terminal - conn.close_shell() + conn.close() conn._shell = MagicMock() - conn.close_shell() + conn.close() self.assertTrue(terminal.on_close_shell.called) terminal.supports_multiplexing = True - conn.close_shell() + conn.close() self.assertIsNone(conn._shell) - def test_network_cli_exec_command(self): + @patch("ansible.plugins.connection.network_cli._Connection._connect") + def test_network_cli_exec_command(self, mocked_super): pc = PlayContext() new_stdin = StringIO() conn = network_cli.Connection(pc, new_stdin) - mock_open_shell = MagicMock() - conn.open_shell = mock_open_shell - mock_send = MagicMock(return_value=b'command response') conn.send = mock_send # test sending a single command and converting to dict rc, out, err = conn.exec_command('command') self.assertEqual(out, b'command response') - self.assertTrue(mock_open_shell.called) - mock_send.assert_called_with({'command': b'command'}) - - mock_open_shell.reset_mock() + mock_send.assert_called_with(b'command', None, None, None) # test sending a json string rc, out, err = conn.exec_command(json.dumps({'command': 'command'})) self.assertEqual(out, b'command response') - mock_send.assert_called_with({'command': b'command'}) - self.assertTrue(mock_open_shell.called) + mock_send.assert_called_with(b'command', None, None, None) - mock_open_shell.reset_mock() conn._shell = MagicMock() # test _shell already open rc, out, err = conn.exec_command('command') self.assertEqual(out, b'command response') - self.assertFalse(mock_open_shell.called) - mock_send.assert_called_with({'command': b'command'}) + mock_send.assert_called_with(b'command', None, None, None) def test_network_cli_send(self): pc = PlayContext() @@ -163,7 +152,7 @@ class TestConnectionClass(unittest.TestCase): mock__shell.recv.return_value = response - output = conn.send({'command': b'command'}) + output = conn.send(b'command', None, None, None) mock__shell.sendall.assert_called_with(b'command\r') self.assertEqual(output, b'command response') @@ -172,5 +161,5 @@ class TestConnectionClass(unittest.TestCase): mock__shell.recv.return_value = b"ERROR: error message device#" with self.assertRaises(AnsibleConnectionFailure) as exc: - conn.send({'command': b'command'}) + conn.send(b'command', None, None, None) self.assertEqual(str(exc.exception), 'ERROR: error message device#')