diff --git a/bin/ansible-connection b/bin/ansible-connection index dd727743da4..7fbceedbd94 100755 --- a/bin/ansible-connection +++ b/bin/ansible-connection @@ -1,23 +1,6 @@ #!/usr/bin/env python - -# (c) 2017, Ansible, Inc. -# -# This file is part of Ansible -# -# Ansible is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# Ansible is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with Ansible. If not, see . - -######################################################## +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import (absolute_import, division, print_function) __metaclass__ = type @@ -36,91 +19,68 @@ import socket import sys import time import traceback -import datetime import errno +import json from ansible import constants as C from ansible.module_utils._text import to_bytes, to_native, to_text from ansible.module_utils.six import PY3 from ansible.module_utils.six.moves import cPickle from ansible.module_utils.connection import send_data, recv_data +from ansible.module_utils.service import fork_process from ansible.playbook.play_context import PlayContext from ansible.plugins.loader import connection_loader from ansible.utils.path import unfrackpath, makedirs_safe -from ansible.errors import AnsibleConnectionFailure +from ansible.errors import AnsibleError from ansible.utils.display import Display +from ansible.utils.jsonrpc import JsonRpcServer -def do_fork(): +class ConnectionProcess(object): ''' - Does the required double fork for a daemon process. Based on - http://code.activestate.com/recipes/66012-fork-a-daemon-process-on-unix/ + The connection process wraps around a Connection object that manages + the connection to a remote device that persists over the playbook ''' - try: - pid = os.fork() - if pid > 0: - return pid - # This is done as a 'good practice' for daemons, but we need to keep the cwd - # leaving it here as a note that we KNOW its good practice but are not doing it on purpose. - # os.chdir("/") - os.setsid() - os.umask(0) - - try: - pid = os.fork() - if pid > 0: - sys.exit(0) - - if C.DEFAULT_LOG_PATH != '': - out_file = open(C.DEFAULT_LOG_PATH, 'ab+') - err_file = open(C.DEFAULT_LOG_PATH, 'ab+', 0) - else: - out_file = open('/dev/null', 'ab+') - err_file = open('/dev/null', 'ab+', 0) - - os.dup2(out_file.fileno(), sys.stdout.fileno()) - os.dup2(err_file.fileno(), sys.stderr.fileno()) - os.close(sys.stdin.fileno()) - - return pid - except OSError as e: - sys.exit(1) - except OSError as e: - sys.exit(1) - - -class Server(): - - def __init__(self, socket_path, play_context): - self.socket_path = socket_path + def __init__(self, fd, play_context, socket_path, original_path): self.play_context = play_context + self.socket_path = socket_path + self.original_path = original_path - display.display( - 'creating new control socket for host %s:%s as user %s' % - (play_context.remote_addr, play_context.port, play_context.remote_user), - log_only=True - ) - - display.display('control socket path is %s' % socket_path, log_only=True) - display.display('current working directory is %s' % os.getcwd(), log_only=True) - - self._start_time = datetime.datetime.now() - - display.display("using connection plugin %s" % self.play_context.connection, log_only=True) - - self.connection = connection_loader.get(play_context.connection, play_context, sys.stdin) - self.connection._connect() - - if not self.connection.connected: - raise AnsibleConnectionFailure('unable to connect to remote host %s' % self._play_context.remote_addr) + self.fd = fd + self.exception = None - connection_time = datetime.datetime.now() - self._start_time - display.display('connection established to %s in %s' % (play_context.remote_addr, connection_time), log_only=True) + self.srv = JsonRpcServer() + self.sock = None - self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - self.socket.bind(self.socket_path) - self.socket.listen(1) - display.display('local socket is set to listening', log_only=True) + def start(self): + try: + messages = list() + result = {} + + messages.append('control socket path is %s' % self.socket_path) + + # If this is a relative path (~ gets expanded later) then plug the + # key's path on to the directory we originally came from, so we can + # find it now that our cwd is / + if self.play_context.private_key_file and self.play_context.private_key_file[0] not in '~/': + self.play_context.private_key_file = os.path.join(self.original_path, self.play_context.private_key_file) + + self.connection = connection_loader.get(self.play_context.connection, self.play_context, '/dev/null') + self.connection._connect() + self.srv.register(self.connection) + messages.append('connection to remote device started successfully') + + self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.sock.bind(self.socket_path) + self.sock.listen(1) + messages.append('local domain socket listeners started successfully') + except Exception as exc: + result['error'] = to_text(exc) + result['exception'] = traceback.format_exc() + finally: + result['messages'] = messages + self.fd.write(json.dumps(result)) + self.fd.close() def run(self): try: @@ -129,53 +89,36 @@ class Server(): signal.signal(signal.SIGTERM, self.handler) signal.alarm(C.PERSISTENT_CONNECT_TIMEOUT) - (s, addr) = self.socket.accept() - display.display('incoming request accepted on persistent socket', log_only=True) + self.exception = None + (s, addr) = self.sock.accept() signal.alarm(0) + signal.signal(signal.SIGALRM, self.command_timeout) while True: data = recv_data(s) if not data: break - signal.signal(signal.SIGALRM, self.command_timeout) - signal.alarm(self.play_context.timeout) - - op = to_text(data.split(b':')[0]) - display.display('socket operation is %s' % op, log_only=True) - - method = getattr(self, 'do_%s' % op, None) - - rc = 255 - stdout = stderr = '' - - if not method: - stderr = 'Invalid action specified' - else: - rc, stdout, stderr = method(data) - + signal.alarm(self.connection._play_context.timeout) + resp = self.srv.handle_request(data) signal.alarm(0) - display.display('socket operation completed with rc %s' % rc, log_only=True) - - send_data(s, to_bytes(rc)) - send_data(s, to_bytes(stdout)) - send_data(s, to_bytes(stderr)) + send_data(s, to_bytes(resp)) s.close() except Exception as e: # socket.accept() will raise EINTR if the socket.close() is called - if e.errno != errno.EINTR: - display.display(traceback.format_exc(), log_only=True) + if hasattr(e, 'errno'): + if e.errno != errno.EINTR: + self.exception = traceback.format_exc() + else: + self.exception = traceback.format_exc() finally: # when done, close the connection properly and cleanup # the socket file so it can be recreated self.shutdown() - end_time = datetime.datetime.now() - delta = end_time - self._start_time - display.display('shutdown local socket, connection was active for %s secs' % delta, log_only=True) def connect_timeout(self, signum, frame): display.display('persistent connection idle timeout triggered, timeout value is %s secs' % C.PERSISTENT_CONNECT_TIMEOUT, log_only=True) @@ -190,25 +133,25 @@ class Server(): self.shutdown() def shutdown(self): - display.display('shutdown persistent connection requested', log_only=True) - + """ Shuts down the local domain socket + """ if not os.path.exists(self.socket_path): - display.display('persistent connection is not active', log_only=True) return try: - if self.socket: - display.display('closing local listener', log_only=True) - self.socket.close() + if self.sock: + self.sock.close() if self.connection: - display.display('closing the connection', log_only=True) self.connection.close() + except: pass + finally: if os.path.exists(self.socket_path): - display.display('removing the local control socket', log_only=True) os.remove(self.socket_path) + setattr(self.connection, '_socket_path', None) + setattr(self.connection, '_connected', False) display.display('shutdown complete', log_only=True) @@ -262,6 +205,13 @@ def communicate(sock, data): def main(): + """ Called to initiate the connect to the remote device + """ + rc = 0 + result = {} + messages = list() + socket_path = None + # Need stdin as a byte stream if PY3: stdin = sys.stdin.buffer @@ -270,116 +220,91 @@ def main(): try: # read the play context data via stdin, which means depickling it - # FIXME: as noted above, we will probably need to deserialize the - # connection loader here as well at some point, otherwise this - # won't find role- or playbook-based connection plugins cur_line = stdin.readline() init_data = b'' + while cur_line.strip() != b'#END_INIT#': if cur_line == b'': raise Exception("EOF found before init data was complete") init_data += cur_line cur_line = stdin.readline() + if PY3: pc_data = cPickle.loads(init_data, encoding='bytes') else: pc_data = cPickle.loads(init_data) - pc = PlayContext() - pc.deserialize(pc_data) + play_context = PlayContext() + play_context.deserialize(pc_data) except Exception as e: - # FIXME: better error message/handling/logging - sys.stderr.write(traceback.format_exc()) - sys.exit("FAIL: %s" % e) - - ssh = connection_loader.get('ssh', class_only=True) - cp = ssh._create_control_path(pc.remote_addr, pc.port, pc.remote_user, pc.connection) - - # create the persistent connection dir if need be and create the paths - # which we will be using later - tmp_path = unfrackpath(C.PERSISTENT_CONTROL_PATH_DIR) - makedirs_safe(tmp_path) - lock_path = unfrackpath("%s/.ansible_pc_lock" % tmp_path) - socket_path = unfrackpath(cp % dict(directory=tmp_path)) - - # if the socket file doesn't exist, spin up the daemon process - lock_fd = os.open(lock_path, os.O_RDWR | os.O_CREAT, 0o600) - fcntl.lockf(lock_fd, fcntl.LOCK_EX) - - if not os.path.exists(socket_path): - pid = do_fork() - if pid == 0: - rc = 0 - try: - server = Server(socket_path, pc) - except AnsibleConnectionFailure as exc: - display.display('connecting to host %s returned an error' % pc.remote_addr, log_only=True) - display.display(str(exc), log_only=True) - rc = 1 - except Exception as exc: - display.display('failed to create control socket for host %s' % pc.remote_addr, log_only=True) - display.display(traceback.format_exc(), log_only=True) - rc = 1 - fcntl.lockf(lock_fd, fcntl.LOCK_UN) - os.close(lock_fd) - if rc == 0: - server.run() - sys.exit(rc) - else: - display.display('re-using existing socket for %s@%s:%s' % (pc.remote_user, pc.remote_addr, pc.port), log_only=True) - - fcntl.lockf(lock_fd, fcntl.LOCK_UN) - os.close(lock_fd) - - timeout = pc.timeout - while bool(timeout): - if os.path.exists(socket_path): - display.vvvv('connected to local socket in %s' % (pc.timeout - timeout), pc.remote_addr) - break - time.sleep(1) - timeout -= 1 - else: - raise AnsibleConnectionFailure('timeout waiting for local socket', pc.remote_addr) - - # now connect to the daemon process - # FIXME: if the socket file existed but the daemonized process was killed, - # the connection will timeout here. Need to make this more resilient. - while True: - data = stdin.readline() - if data == b'': - break - if data.strip() == b'': - continue - - sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - - connect_retry_timeout = C.PERSISTENT_CONNECT_RETRY_TIMEOUT - while bool(connect_retry_timeout): - try: - sock.connect(socket_path) - break - except socket.error: - time.sleep(1) - connect_retry_timeout -= 1 - else: - display.display('connect retry timeout expired, unable to connect to control socket', pc.remote_addr, pc.remote_user, log_only=True) - display.display('persistent_connect_retry_timeout is %s secs' % (C.PERSISTENT_CONNECT_RETRY_TIMEOUT), pc.remote_addr, pc.remote_user, log_only=True) - sys.stderr.write('failed to connect to control socket') - sys.exit(255) + rc = 1 + result.update({ + 'error': to_text(e), + 'exception': traceback.format_exc() + }) + + if rc == 0: + ssh = connection_loader.get('ssh', class_only=True) + cp = ssh._create_control_path(play_context.remote_addr, play_context.port, play_context.remote_user, play_context.connection) + + # create the persistent connection dir if need be and create the paths + # which we will be using later + tmp_path = unfrackpath(C.PERSISTENT_CONTROL_PATH_DIR) + makedirs_safe(tmp_path) + + lock_path = unfrackpath("%s/.ansible_pc_lock" % tmp_path) + socket_path = unfrackpath(cp % dict(directory=tmp_path)) + + # if the socket file doesn't exist, spin up the daemon process + lock_fd = os.open(lock_path, os.O_RDWR | os.O_CREAT, 0o600) + fcntl.lockf(lock_fd, fcntl.LOCK_EX) + + if not os.path.exists(socket_path): + messages.append('local domain socket does not exist, starting it') + original_path = os.getcwd() + r, w = os.pipe() + pid = fork_process() + + if pid == 0: + try: + os.close(r) + wfd = os.fdopen(w, 'w') + process = ConnectionProcess(wfd, play_context, socket_path, original_path) + process.start() + except Exception as exc: + messages.append(traceback.format_exc()) + rc = 1 + + fcntl.lockf(lock_fd, fcntl.LOCK_UN) + os.close(lock_fd) + + if rc == 0: + process.run() + + sys.exit(rc) - # send the play_context back into the connection so the connection - # can handle any privilege escalation activities - pc_data = b'CONTEXT: %s' % init_data - communicate(sock, pc_data) + else: + os.close(w) + rfd = os.fdopen(r, 'r') + data = json.loads(rfd.read()) + messages.extend(data.pop('messages')) + result.update(data) - rc, stdout, stderr = communicate(sock, data.strip()) + else: + messages.append('found existing local domain socket, using it!') - sys.stdout.write(to_native(stdout)) - sys.stderr.write(to_native(stderr)) + result.update({ + 'messages': messages, + 'socket_path': socket_path + }) - sock.close() - break + if 'exception' in result: + rc = 1 + sys.stderr.write(json.dumps(result)) + else: + rc = 0 + sys.stdout.write(json.dumps(result)) sys.exit(rc) diff --git a/lib/ansible/executor/task_executor.py b/lib/ansible/executor/task_executor.py index 78a1902bf65..efc1f828b78 100644 --- a/lib/ansible/executor/task_executor.py +++ b/lib/ansible/executor/task_executor.py @@ -1,31 +1,21 @@ # (c) 2012-2014, Michael DeHaan -# -# This file is part of Ansible -# -# Ansible is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# Ansible is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with Ansible. If not, see . - -# Make coding more python3-ish +# (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import (absolute_import, division, print_function) __metaclass__ = type +import os +import pty import time +import json +import subprocess import traceback from ansible import constants as C from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleConnectionFailure, AnsibleActionFail, AnsibleActionSkip from ansible.executor.task_result import TaskResult from ansible.module_utils.six import iteritems, string_types, binary_type +from ansible.module_utils.six.moves import cPickle from ansible.module_utils._text import to_text from ansible.playbook.conditional import Conditional from ansible.playbook.task import Task @@ -490,6 +480,8 @@ class TaskExecutor: not getattr(self._connection, 'connected', False) or self._play_context.remote_addr != self._connection._play_context.remote_addr): self._connection = self._get_connection(variables=variables, templar=templar) + if getattr(self._connection, '_socket_path'): + variables['ansible_socket'] = self._connection._socket_path # only template the vars if the connection actually implements set_host_overrides # NB: this is expensive, and should be removed once connection-specific vars are being handled by play_context sho_impl = getattr(type(self._connection), 'set_host_overrides', None) @@ -736,12 +728,7 @@ class TaskExecutor: if isinstance(i, string_types) and i.startswith("ansible_") and i.endswith("_interpreter"): variables[i] = delegated_vars[i] - # if using persistent paramiko connections (or the action has set the FORCE_PERSISTENT_CONNECTION attribute to True), - # then we use the persistent connection plugion. Otherwise load the requested connection plugin - if C.USE_PERSISTENT_CONNECTIONS or getattr(self, 'FORCE_PERSISTENT_CONNECTION', False): - conn_type = 'persistent' - else: - conn_type = self._play_context.connection + conn_type = self._play_context.connection connection = self._shared_loader_obj.connection_loader.get(conn_type, self._play_context, self._new_stdin) if not connection: @@ -749,6 +736,13 @@ class TaskExecutor: self._play_context.set_options_from_plugin(connection) + if any(((connection.supports_persistence and C.USE_PERSISTENT_CONNECTIONS), connection.force_persistence)): + display.vvvv('attempting to start connection', host=self._play_context.remote_addr) + display.vvvv('using connection plugin %s' % connection.transport, host=self._play_context.remote_addr) + socket_path = self._start_connection() + display.vvvv('local domain socket path is %s' % socket_path, host=self._play_context.remote_addr) + setattr(connection, '_socket_path', socket_path) + return connection def _get_action_handler(self, connection, templar): @@ -780,3 +774,42 @@ class TaskExecutor: raise AnsibleError("the handler '%s' was not found" % handler_name) return handler + + def _start_connection(self): + ''' + Starts the persistent connection + ''' + master, slave = pty.openpty() + p = subprocess.Popen(["ansible-connection"], stdin=slave, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdin = os.fdopen(master, 'wb', 0) + os.close(slave) + + # Need to force a protocol that is compatible with both py2 and py3. + # That would be protocol=2 or less. + # Also need to force a protocol that excludes certain control chars as + # stdin in this case is a pty and control chars will cause problems. + # that means only protocol=0 will work. + src = cPickle.dumps(self._play_context.serialize(), protocol=0) + stdin.write(src) + + stdin.write(b'\n#END_INIT#\n') + + (stdout, stderr) = p.communicate() + stdin.close() + + if p.returncode == 0: + result = json.loads(stdout) + else: + result = json.loads(stderr) + + if 'messages' in result: + for msg in result.get('messages'): + display.vvvv('%s' % msg, host=self._play_context.remote_addr) + + if 'error' in result: + if self._play_context.verbosity > 2: + msg = "The full traceback is:\n" + result['exception'] + display.display(result['exception'], color=C.COLOR_ERROR) + raise AnsibleError(result['error']) + + return result['socket_path'] diff --git a/lib/ansible/module_utils/connection.py b/lib/ansible/module_utils/connection.py index 676fcc42604..b43c2b8ba7a 100644 --- a/lib/ansible/module_utils/connection.py +++ b/lib/ansible/module_utils/connection.py @@ -27,6 +27,7 @@ # USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import os +import json import socket import struct import traceback @@ -35,6 +36,7 @@ import uuid from functools import partial from ansible.module_utils._text import to_bytes, to_native, to_text +from ansible.module_utils.six import iteritems def send_data(s, data): @@ -61,23 +63,14 @@ def recv_data(s): def exec_command(module, command): + connection = Connection(module._socket_path) try: - sf = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - sf.connect(module._socket_path) - - data = "EXEC: %s" % command - send_data(sf, to_bytes(data.strip())) - - rc = int(recv_data(sf), 10) - stdout = recv_data(sf) - stderr = recv_data(sf) - except socket.error as e: - sf.close() - module.fail_json(msg='unable to connect to socket', err=to_native(e), exception=traceback.format_exc()) - - sf.close() - - return rc, to_native(stdout, errors='surrogate_or_strict'), to_native(stderr, errors='surrogate_or_strict') + out = connection.exec_command(command) + except ConnectionError as exc: + code = getattr(exc, 'code', 1) + message = getattr(exc, 'err', exc) + return code, '', to_text(message, errors='surrogate_then_replace') + return 0, out, '' def request_builder(method, *args, **kwargs): @@ -91,10 +84,19 @@ def request_builder(method, *args, **kwargs): return req +class ConnectionError(Exception): + + def __init__(self, message, *args, **kwargs): + super(ConnectionError, self).__init__(message) + for k, v in iteritems(kwargs): + setattr(self, k, v) + + class Connection: - def __init__(self, module): - self._module = module + def __init__(self, socket_path): + assert socket_path is not None, 'socket_path must be a value' + self.socket_path = socket_path def __getattr__(self, name): try: @@ -116,30 +118,40 @@ class Connection: req = request_builder(name, *args, **kwargs) reqid = req['id'] - if not self._module._socket_path: - self._module.fail_json(msg='provider support not available for this host') - - if not os.path.exists(self._module._socket_path): - self._module.fail_json(msg='provider socket does not exist, is the provider running?') + if not os.path.exists(self.socket_path): + raise ConnectionError('socket_path does not exist or cannot be found') try: - data = self._module.jsonify(req) - rc, out, err = exec_command(self._module, data) + data = json.dumps(req) + out = self.send(data) + response = json.loads(out) except socket.error as e: - self._module.fail_json(msg='unable to connect to socket', err=to_native(e), - exception=traceback.format_exc()) - - try: - response = self._module.from_json(to_text(out, errors='surrogate_then_replace')) - except ValueError as exc: - self._module.fail_json(msg=to_text(exc, errors='surrogate_then_replace')) + raise ConnectionError('unable to connect to socket', err=to_text(e, errors='surrogate_then_replace'), exception=traceback.format_exc()) if response['id'] != reqid: - self._module.fail_json(msg='invalid id received') + raise ConnectionError('invalid json-rpc id received') if 'error' in response: - msg = response['error'].get('data') or response['error']['message'] - self._module.fail_json(msg=to_text(msg, errors='surrogate_then_replace')) + err = response.get('error') + msg = err.get('data') or err['message'] + code = err['code'] + raise ConnectionError(to_text(msg, errors='surrogate_then_replace'), code=code) return response['result'] + + def send(self, data): + try: + sf = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sf.connect(self.socket_path) + + send_data(sf, to_bytes(data)) + response = recv_data(sf) + + except socket.error as e: + sf.close() + raise ConnectionError('unable to connect to socket', err=to_text(e, errors='surrogate_then_replace'), exception=traceback.format_exc()) + + sf.close() + + return to_text(response, errors='surrogate_or_strict') diff --git a/lib/ansible/module_utils/netconf.py b/lib/ansible/module_utils/netconf.py index d42e8b746aa..bc056007b00 100644 --- a/lib/ansible/module_utils/netconf.py +++ b/lib/ansible/module_utils/netconf.py @@ -27,6 +27,7 @@ # from contextlib import contextmanager +from ansible.module_utils._text import to_bytes, to_text from ansible.module_utils.connection import exec_command try: @@ -38,7 +39,7 @@ NS_MAP = {'nc': "urn:ietf:params:xml:ns:netconf:base:1.0"} def send_request(module, obj, check_rc=True, ignore_warning=True): - request = tostring(obj) + request = to_text(tostring(obj), errors='surrogate_or_strict') rc, out, err = exec_command(module, request) if rc != 0 and check_rc: error_root = fromstring(err) @@ -59,7 +60,7 @@ def send_request(module, obj, check_rc=True, ignore_warning=True): else: module.fail_json(msg=str(err)) return warnings - return fromstring(out) + return fromstring(to_bytes(out, errors='surrogate_or_strict')) def children(root, iterable): diff --git a/lib/ansible/module_utils/service.py b/lib/ansible/module_utils/service.py index 918d55da505..271b5d60645 100644 --- a/lib/ansible/module_utils/service.py +++ b/lib/ansible/module_utils/service.py @@ -91,33 +91,14 @@ def fail_if_missing(module, found, service, msg=''): module.fail_json(msg='Could not find the requested service %s: %s' % (service, msg)) -def daemonize(module, cmd): +def fork_process(): ''' - Execute a command while detaching as a daemon, returns rc, stdout, and stderr. - - :arg module: is an AnsibleModule object, used for it's utility methods - :arg cmd: is a list or string representing the command and options to run - - This is complex because daemonization is hard for people. - What we do is daemonize a part of this module, the daemon runs the command, - picks up the return code and output, and returns it to the main process. + This function performs the double fork process to detach from the + parent process and execute. ''' + pid = os.fork() - # init some vars - chunk = 4096 # FIXME: pass in as arg? - errors = 'surrogate_or_strict' - - # start it! - try: - pipe = os.pipe() - pid = os.fork() - except OSError: - module.fail_json(msg="Error while attempting to fork: %s", exception=traceback.format_exc()) - - # we don't do any locking as this should be a unique module/process if pid == 0: - - os.close(pipe[0]) # Set stdin/stdout/stderr to /dev/null fd = os.open(os.devnull, os.O_RDWR) @@ -140,7 +121,7 @@ def daemonize(module, cmd): # get new process session and detach sid = os.setsid() if sid == -1: - module.fail_json(msg="Unable to detach session while daemonizing") + raise Exception("Unable to detach session while daemonizing") # avoid possible problems with cwd being removed os.chdir("/") @@ -149,6 +130,38 @@ def daemonize(module, cmd): if pid > 0: os._exit(0) + return pid + + +def daemonize(module, cmd): + ''' + Execute a command while detaching as a daemon, returns rc, stdout, and stderr. + + :arg module: is an AnsibleModule object, used for it's utility methods + :arg cmd: is a list or string representing the command and options to run + + This is complex because daemonization is hard for people. + What we do is daemonize a part of this module, the daemon runs the command, + picks up the return code and output, and returns it to the main process. + ''' + + # init some vars + chunk = 4096 # FIXME: pass in as arg? + errors = 'surrogate_or_strict' + + # start it! + try: + pipe = os.pipe() + pid = fork_process() + except OSError: + module.fail_json(msg="Error while attempting to fork: %s", exception=traceback.format_exc()) + except Exception as exc: + module.fail_json(msg=to_text(exc), exception=traceback.format_exc()) + + # we don't do any locking as this should be a unique module/process + if pid == 0: + os.close(pipe[0]) + # if command is string deal with py2 vs py3 conversions for shlex if not isinstance(cmd, list): if PY2: diff --git a/lib/ansible/playbook/play_context.py b/lib/ansible/playbook/play_context.py index addb037dbd4..1d37166fcdf 100644 --- a/lib/ansible/playbook/play_context.py +++ b/lib/ansible/playbook/play_context.py @@ -427,6 +427,8 @@ class PlayContext(Base): # if the final connection type is local, reset the remote_user value to that of the currently logged in user # this ensures any become settings are obeyed correctly # we store original in 'connection_user' for use of network/other modules that fallback to it as login user + # connection_user to be deprecated once connection=local is removed for + # network modules if new_info.connection == 'local': if not new_info.connection_user: new_info.connection_user = new_info.remote_user diff --git a/lib/ansible/plugins/action/__init__.py b/lib/ansible/plugins/action/__init__.py index 6324d009ac2..6b5ff9dcbc6 100644 --- a/lib/ansible/plugins/action/__init__.py +++ b/lib/ansible/plugins/action/__init__.py @@ -36,6 +36,7 @@ from ansible.module_utils.json_utils import _filter_non_json_lines from ansible.module_utils.six import binary_type, string_types, text_type, iteritems, with_metaclass from ansible.module_utils.six.moves import shlex_quote from ansible.module_utils._text import to_bytes, to_native, to_text +from ansible.module_utils.connection import Connection from ansible.parsing.utils.jsonify import jsonify from ansible.release import __version__ from ansible.utils.unsafe_proxy import wrap_var @@ -604,7 +605,9 @@ class ActionBase(with_metaclass(ABCMeta, object)): module_args['_ansible_selinux_special_fs'] = C.DEFAULT_SELINUX_SPECIAL_FS # give the module the socket for persistent connections - module_args['_ansible_socket'] = task_vars.get('ansible_socket') + module_args['_ansible_socket'] = getattr(self._connection, 'socket_path') + if not module_args['_ansible_socket']: + module_args['_ansible_socket'] = task_vars.get('ansible_socket') # make sure all commands use the designated shell executable module_args['_ansible_shell_executable'] = self._play_context.executable @@ -818,7 +821,8 @@ class ActionBase(with_metaclass(ABCMeta, object)): same_user = self._play_context.become_user == self._play_context.remote_user if sudoable and self._play_context.become and (allow_same_user or not same_user): display.debug("_low_level_execute_command(): using become for this command") - cmd = self._play_context.make_become_cmd(cmd, executable=executable) + if self._connection.transport != 'network_cli' and self._play_context.become_method != 'enable': + cmd = self._play_context.make_become_cmd(cmd, executable=executable) if self._connection.allow_executable: if executable is None: diff --git a/lib/ansible/plugins/action/eos.py b/lib/ansible/plugins/action/eos.py index a9cc4a359af..f96d81871cb 100644 --- a/lib/ansible/plugins/action/eos.py +++ b/lib/ansible/plugins/action/eos.py @@ -40,47 +40,35 @@ class ActionModule(_ActionModule): provider = load_provider(eos_provider_spec, self._task.args) transport = provider['transport'] or 'cli' - if self._play_context.connection != 'local' and transport == 'cli': - return dict( - failed=True, - msg='invalid connection specified, expected connection=local, ' - 'got %s' % self._play_context.connection - ) - display.vvvv('connection transport is %s' % transport, self._play_context.remote_addr) if transport == 'cli': - pc = copy.deepcopy(self._play_context) - pc.connection = 'network_cli' - pc.network_os = 'eos' - pc.remote_addr = provider['host'] or self._play_context.remote_addr - pc.port = int(provider['port'] or self._play_context.port or 22) - pc.remote_user = provider['username'] or self._play_context.connection_user - pc.password = provider['password'] or self._play_context.password - pc.private_key_file = provider['ssh_keyfile'] or self._play_context.private_key_file - pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT) - pc.become = provider['authorize'] or False - pc.become_pass = provider['auth_pass'] - - display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) - connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) - - socket_path = connection.run() - display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) - if not socket_path: - return {'failed': True, - 'msg': 'unable to open shell. Please see: ' + - 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} - - # make sure we are in the right cli context which should be - # enable mode and not config module - rc, out, err = connection.exec_command('prompt()') - while '(config' in str(out): - display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) - connection.exec_command('exit') - rc, out, err = connection.exec_command('prompt()') - - task_vars['ansible_socket'] = socket_path + if self._play_context.connection == 'local': + pc = copy.deepcopy(self._play_context) + pc.connection = 'network_cli' + pc.network_os = 'eos' + pc.remote_addr = provider['host'] or self._play_context.remote_addr + pc.port = int(provider['port'] or self._play_context.port or 22) + pc.remote_user = provider['username'] or self._play_context.connection_user + pc.password = provider['password'] or self._play_context.password + pc.private_key_file = provider['ssh_keyfile'] or self._play_context.private_key_file + pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT) + pc.become = provider['authorize'] or False + if pc.become: + pc.become_method = 'enable' + pc.become_pass = provider['auth_pass'] + + display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) + connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) + + socket_path = connection.run() + display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) + if not socket_path: + return {'failed': True, + 'msg': 'unable to open shell. Please see: ' + + 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} + + task_vars['ansible_socket'] = socket_path else: provider['transport'] = 'eapi' diff --git a/lib/ansible/plugins/action/ios.py b/lib/ansible/plugins/action/ios.py index d17a0097501..b9a040d830a 100644 --- a/lib/ansible/plugins/action/ios.py +++ b/lib/ansible/plugins/action/ios.py @@ -38,50 +38,38 @@ class ActionModule(_ActionModule): def run(self, tmp=None, task_vars=None): - if self._play_context.connection != 'local': - return dict( - failed=True, - msg='invalid connection specified, expected connection=local, ' - 'got %s' % self._play_context.connection - ) + if self._play_context.connection == 'local': + provider = load_provider(ios_provider_spec, self._task.args) - provider = load_provider(ios_provider_spec, self._task.args) + pc = copy.deepcopy(self._play_context) + pc.connection = 'network_cli' + pc.network_os = 'ios' + pc.remote_addr = provider['host'] or self._play_context.remote_addr + pc.port = int(provider['port'] or self._play_context.port or 22) + pc.remote_user = provider['username'] or self._play_context.connection_user + pc.password = provider['password'] or self._play_context.password + pc.private_key_file = provider['ssh_keyfile'] or self._play_context.private_key_file + pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT) + pc.become = provider['authorize'] or False + if pc.become: + pc.become_method = 'enable' + pc.become_pass = provider['auth_pass'] - pc = copy.deepcopy(self._play_context) - pc.connection = 'network_cli' - pc.network_os = 'ios' - pc.remote_addr = provider['host'] or self._play_context.remote_addr - pc.port = int(provider['port'] or self._play_context.port or 22) - pc.remote_user = provider['username'] or self._play_context.connection_user - pc.password = provider['password'] or self._play_context.password - pc.private_key_file = provider['ssh_keyfile'] or self._play_context.private_key_file - pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT) - pc.become = provider['authorize'] or False - pc.become_pass = provider['auth_pass'] + display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) + connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) - display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) - connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) + socket_path = connection.run() + display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) + if not socket_path: + return {'failed': True, + 'msg': 'unable to open shell. Please see: ' + + 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} - socket_path = connection.run() - display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) - if not socket_path: - return {'failed': True, - 'msg': 'unable to open shell. Please see: ' + - 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} + task_vars['ansible_socket'] = socket_path - # make sure we are in the right cli context which should be - # enable mode and not config module - rc, out, err = connection.exec_command('prompt()') - while str(out).strip().endswith(')#'): - display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) - connection.exec_command('exit') - rc, out, err = connection.exec_command('prompt()') - - task_vars['ansible_socket'] = socket_path - - if self._play_context.become_method == 'enable': - self._play_context.become = False - self._play_context.become_method = None + if self._play_context.become_method == 'enable': + self._play_context.become = False + self._play_context.become_method = None result = super(ActionModule, self).run(tmp, task_vars) return result diff --git a/lib/ansible/plugins/action/iosxr.py b/lib/ansible/plugins/action/iosxr.py index 9e068baa729..c451436f605 100644 --- a/lib/ansible/plugins/action/iosxr.py +++ b/lib/ansible/plugins/action/iosxr.py @@ -38,43 +38,29 @@ class ActionModule(_ActionModule): def run(self, tmp=None, task_vars=None): - if self._play_context.connection != 'local': - return dict( - failed=True, - msg='invalid connection specified, expected connection=local, ' - 'got %s' % self._play_context.connection - ) + if self._play_context.connection == 'local': + provider = load_provider(iosxr_provider_spec, self._task.args) - provider = load_provider(iosxr_provider_spec, self._task.args) + pc = copy.deepcopy(self._play_context) + pc.connection = 'network_cli' + pc.network_os = 'iosxr' + pc.remote_addr = provider['host'] or self._play_context.remote_addr + pc.port = int(provider['port'] or self._play_context.port or 22) + pc.remote_user = provider['username'] or self._play_context.connection_user + pc.password = provider['password'] or self._play_context.password + pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT) - pc = copy.deepcopy(self._play_context) - pc.connection = 'network_cli' - pc.network_os = 'iosxr' - pc.remote_addr = provider['host'] or self._play_context.remote_addr - pc.port = int(provider['port'] or self._play_context.port or 22) - pc.remote_user = provider['username'] or self._play_context.connection_user - pc.password = provider['password'] or self._play_context.password - pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT) + display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) + connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) - display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) - connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) + socket_path = connection.run() + display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) + if not socket_path: + return {'failed': True, + 'msg': 'unable to open shell. Please see: ' + + 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} - socket_path = connection.run() - display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) - if not socket_path: - return {'failed': True, - 'msg': 'unable to open shell. Please see: ' + - 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} - - # make sure we are in the right cli context which should be - # enable mode and not config module - rc, out, err = connection.exec_command('prompt()') - while str(out).strip().endswith(')#'): - display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) - connection.exec_command('exit') - rc, out, err = connection.exec_command('prompt()') - - task_vars['ansible_socket'] = socket_path + task_vars['ansible_socket'] = socket_path result = super(ActionModule, self).run(tmp, task_vars) return result diff --git a/lib/ansible/plugins/action/junos.py b/lib/ansible/plugins/action/junos.py index 54c5c4c0c7f..baa4f41751d 100644 --- a/lib/ansible/plugins/action/junos.py +++ b/lib/ansible/plugins/action/junos.py @@ -38,14 +38,6 @@ except ImportError: class ActionModule(_ActionModule): def run(self, tmp=None, task_vars=None): - - if self._play_context.connection != 'local': - return dict( - failed=True, - msg='invalid connection specified, expected connection=local, ' - 'got %s' % self._play_context.connection - ) - module = module_loader._load_module_source(self._task.action, module_loader.find_plugin(self._task.action)) if not getattr(module, 'USE_PERSISTENT_CONNECTION', False): @@ -72,25 +64,27 @@ class ActionModule(_ActionModule): pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT) display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) - connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) - - socket_path = connection.run() - display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) - if not socket_path: - return {'failed': True, - 'msg': 'unable to open shell. Please see: ' + - 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} - - if pc.connection == 'network_cli': - # make sure we are in the right cli context which should be - # enable mode and not config module - rc, out, err = connection.exec_command('prompt()') - while str(out).strip().endswith(')#'): - display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) - connection.exec_command('exit') + + if self._play_context.connection == 'local': + connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) + + socket_path = connection.run() + display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) + if not socket_path: + return {'failed': True, + 'msg': 'unable to open shell. Please see: ' + + 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} + + if pc.connection == 'network_cli': + # make sure we are in the right cli context which should be + # enable mode and not config module rc, out, err = connection.exec_command('prompt()') + while str(out).strip().endswith(')#'): + display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) + connection.exec_command('exit') + rc, out, err = connection.exec_command('prompt()') - task_vars['ansible_socket'] = socket_path + task_vars['ansible_socket'] = socket_path result = super(ActionModule, self).run(tmp, task_vars) return result diff --git a/lib/ansible/plugins/action/net_base.py b/lib/ansible/plugins/action/net_base.py index b43d0b810f0..bd0fa3a5727 100644 --- a/lib/ansible/plugins/action/net_base.py +++ b/lib/ansible/plugins/action/net_base.py @@ -37,13 +37,6 @@ except ImportError: class ActionModule(ActionBase): def run(self, tmp=None, task_vars=None): - if self._play_context.connection != 'local': - return dict( - failed=True, - msg='invalid connection specified, expected connection=local, ' - 'got %s' % self._play_context.connection - ) - play_context = copy.deepcopy(self._play_context) play_context.network_os = self._get_network_os(task_vars) @@ -74,8 +67,9 @@ class ActionModule(ActionBase): play_context.become = self.provider['authorize'] or False play_context.become_pass = self.provider['auth_pass'] - socket_path = self._start_connection(play_context) - task_vars['ansible_socket'] = socket_path + if self._play_context.connection == 'local': + socket_path = self._start_connection(play_context) + task_vars['ansible_socket'] = socket_path if 'fail_on_missing_module' not in self._task.args: self._task.args['fail_on_missing_module'] = False diff --git a/lib/ansible/plugins/action/nxos.py b/lib/ansible/plugins/action/nxos.py index 10e08e4aa9d..491162ceb53 100644 --- a/lib/ansible/plugins/action/nxos.py +++ b/lib/ansible/plugins/action/nxos.py @@ -40,44 +40,31 @@ class ActionModule(_ActionModule): provider = load_provider(nxos_provider_spec, self._task.args) transport = provider['transport'] or 'cli' - if self._play_context.connection != 'local' and transport == 'cli': - return dict( - failed=True, - msg='invalid connection specified, expected connection=local, ' - 'got %s' % self._play_context.connection - ) - display.vvvv('connection transport is %s' % transport, self._play_context.remote_addr) if transport == 'cli': - pc = copy.deepcopy(self._play_context) - pc.connection = 'network_cli' - pc.network_os = 'nxos' - pc.remote_addr = provider['host'] or self._play_context.remote_addr - pc.port = int(provider['port'] or self._play_context.port or 22) - pc.remote_user = provider['username'] or self._play_context.connection_user - pc.password = provider['password'] or self._play_context.password - pc.private_key_file = provider['ssh_keyfile'] or self._play_context.private_key_file - pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT) - display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) - connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) - - socket_path = connection.run() - display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) - if not socket_path: - return {'failed': True, - 'msg': 'unable to open shell. Please see: ' + - 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} - - # make sure we are in the right cli context which should be - # enable mode and not config module - rc, out, err = connection.exec_command('prompt()') - while str(out).strip().endswith(')#'): - display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) - connection.exec_command('exit') - rc, out, err = connection.exec_command('prompt()') - - task_vars['ansible_socket'] = socket_path + if self._play_context.connection == 'local': + pc = copy.deepcopy(self._play_context) + pc.connection = 'network_cli' + pc.network_os = 'nxos' + pc.remote_addr = provider['host'] or self._play_context.remote_addr + pc.port = int(provider['port'] or self._play_context.port or 22) + pc.remote_user = provider['username'] or self._play_context.connection_user + pc.password = provider['password'] or self._play_context.password + pc.private_key_file = provider['ssh_keyfile'] or self._play_context.private_key_file + pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT) + display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) + + connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) + + socket_path = connection.run() + display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) + if not socket_path: + return {'failed': True, + 'msg': 'unable to open shell. Please see: ' + + 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} + + task_vars['ansible_socket'] = socket_path else: provider['transport'] = 'nxapi' diff --git a/lib/ansible/plugins/action/vyos.py b/lib/ansible/plugins/action/vyos.py index f3aa72d0d3a..330e89a1c3b 100644 --- a/lib/ansible/plugins/action/vyos.py +++ b/lib/ansible/plugins/action/vyos.py @@ -37,13 +37,6 @@ except ImportError: class ActionModule(_ActionModule): def run(self, tmp=None, task_vars=None): - if self._play_context.connection != 'local': - return dict( - failed=True, - msg='invalid connection specified, expected connection=local, ' - 'got %s' % self._play_context.connection - ) - provider = load_provider(vyos_provider_spec, self._task.args) pc = copy.deepcopy(self._play_context) @@ -57,24 +50,18 @@ class ActionModule(_ActionModule): pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT) display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr) - connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) - socket_path = connection.run() - display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) - if not socket_path: - return {'failed': True, - 'msg': 'unable to open shell. Please see: ' + - 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} + if self._play_context.connection == 'local': + connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin) - # make sure we are in the right cli context which should be - # enable mode and not config module - rc, out, err = connection.exec_command('prompt()') - while str(out).strip().endswith('#'): - display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr) - connection.exec_command('exit') - rc, out, err = connection.exec_command('prompt()') + socket_path = connection.run() + display.vvvv('socket_path: %s' % socket_path, pc.remote_addr) + if not socket_path: + return {'failed': True, + 'msg': 'unable to open shell. Please see: ' + + 'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'} - task_vars['ansible_socket'] = socket_path + task_vars['ansible_socket'] = socket_path result = super(ActionModule, self).run(tmp, task_vars) return result diff --git a/lib/ansible/plugins/connection/__init__.py b/lib/ansible/plugins/connection/__init__.py index 0eada4a2284..aac8324c40e 100644 --- a/lib/ansible/plugins/connection/__init__.py +++ b/lib/ansible/plugins/connection/__init__.py @@ -1,21 +1,7 @@ +# (c) 2012-2014, Michael DeHaan # (c) 2015 Toshio Kuratomi -# -# This file is part of Ansible -# -# Ansible is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# Ansible is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with Ansible. If not, see . - -# Make coding more python3-ish +# (c) 2017, Peter Sprygada +# (c) 2017 Ansible Project from __future__ import (absolute_import, division, print_function) __metaclass__ = type @@ -69,6 +55,11 @@ class ConnectionBase(AnsiblePlugin): module_implementation_preferences = ('',) allow_executable = True + # the following control whether or not the connection supports the + # persistent connection framework or not + supports_persistence = False + force_persistence = False + def __init__(self, play_context, new_stdin, *args, **kwargs): super(ConnectionBase, self).__init__() @@ -88,6 +79,8 @@ class ConnectionBase(AnsiblePlugin): self.prompt = None self._connected = False + self._socket_path = None + # load the shell plugin for this action/connection if play_context.shell: shell_type = play_context.shell @@ -110,6 +103,11 @@ class ConnectionBase(AnsiblePlugin): '''Read-only property holding whether the connection to the remote host is active or closed.''' return self._connected + @property + def socket_path(self): + '''Read-only property holding the connection socket path for this remote host''' + return self._socket_path + def _become_method_supported(self): ''' Checks if the current class supports this privilege escalation method ''' diff --git a/lib/ansible/plugins/connection/netconf.py b/lib/ansible/plugins/connection/netconf.py index 5c73c343b41..4b290a457c3 100644 --- a/lib/ansible/plugins/connection/netconf.py +++ b/lib/ansible/plugins/connection/netconf.py @@ -71,15 +71,14 @@ DOCUMENTATION = """ import os import logging -import json from ansible import constants as C from ansible.errors import AnsibleConnectionFailure, AnsibleError -from ansible.module_utils._text import to_bytes, to_native, to_text +from ansible.module_utils._text import to_bytes, to_native from ansible.module_utils.parsing.convert_bool import BOOLEANS_TRUE from ansible.plugins.loader import netconf_loader from ansible.plugins.connection import ConnectionBase, ensure_connect -from ansible.utils.jsonrpc import Rpc +from ansible.plugins.connection.local import Connection as LocalConnection try: from ncclient import manager @@ -98,11 +97,12 @@ except ImportError: logging.getLogger('ncclient').setLevel(logging.INFO) -class Connection(Rpc, ConnectionBase): +class Connection(ConnectionBase): """NetConf connections""" transport = 'netconf' has_pipelining = False + force_persistence = True def __init__(self, play_context, new_stdin, *args, **kwargs): super(Connection, self).__init__(play_context, new_stdin, *args, **kwargs) @@ -113,18 +113,50 @@ class Connection(Rpc, ConnectionBase): self._manager = None self._connected = False + self._local = LocalConnection(play_context, new_stdin, *args, **kwargs) + + def exec_command(self, request, in_data=None, sudoable=True): + """Sends the request to the node and returns the reply + The method accepts two forms of request. The first form is as a byte + string that represents xml string be send over netconf session. + The second form is a json-rpc (2.0) byte string. + """ + if self._manager: + # to_ele operates on native strings + request = to_ele(to_native(request, errors='surrogate_or_strict')) + + if request is None: + return 'unable to parse request' + + try: + reply = self._manager.rpc(request) + except RPCError as exc: + return to_xml(exc.xml) + + return reply.data_xml + else: + return self._local.exec_command(request, in_data, sudoable) + + def put_file(self, in_path, out_path): + """Transfer a file from local to remote""" + return self._local.put_file(in_path, out_path) + + def fetch_file(self, in_path, out_path): + """Fetch a file from remote to local""" + return self._local.fetch_file(in_path, out_path) + def _connect(self): super(Connection, self)._connect() - display.display('ssh connection done, stating ncclient', log_only=True) + display.display('ssh connection done, starting ncclient', log_only=True) - self.allow_agent = True + allow_agent = True if self._play_context.password is not None: - self.allow_agent = False + allow_agent = False - self.key_filename = None + key_filename = None if self._play_context.private_key_file: - self.key_filename = os.path.expanduser(self._play_context.private_key_file) + key_filename = os.path.expanduser(self._play_context.private_key_file) network_os = self._play_context.network_os @@ -149,16 +181,18 @@ class Connection(Rpc, ConnectionBase): port=self._play_context.port or 830, username=self._play_context.remote_user, password=self._play_context.password, - key_filename=str(self.key_filename), + key_filename=str(key_filename), hostkey_verify=C.HOST_KEY_CHECKING, look_for_keys=C.PARAMIKO_LOOK_FOR_KEYS, - allow_agent=self.allow_agent, + allow_agent=allow_agent, timeout=self._play_context.timeout, device_params={'name': network_os}, ssh_config=ssh_config ) except SSHUnknownHostError as exc: raise AnsibleConnectionFailure(str(exc)) + except ImportError as exc: + raise AnsibleError("connection=netconf is not supported on {0}".format(network_os)) if not self._manager.connected: return 1, b'', b'not connected' @@ -169,7 +203,6 @@ class Connection(Rpc, ConnectionBase): self._netconf = netconf_loader.get(network_os, self) if self._netconf: - self._rpc.add(self._netconf) display.display('loaded netconf plugin for network_os %s' % network_os, log_only=True) else: display.display('unable to load netconf for network_os %s' % network_os) @@ -181,46 +214,3 @@ class Connection(Rpc, ConnectionBase): self._manager.close_session() self._connected = False super(Connection, self).close() - - @ensure_connect - def exec_command(self, request): - """Sends the request to the node and returns the reply - The method accepts two forms of request. The first form is as a byte - string that represents xml string be send over netconf session. - The second form is a json-rpc (2.0) byte string. - """ - try: - obj = json.loads(to_text(request, errors='surrogate_or_strict')) - - if 'jsonrpc' in obj: - if self._netconf: - out = self._exec_rpc(obj) - else: - out = self.internal_error("netconf plugin is not supported for network_os %s" % self._play_context.network_os) - return 0, to_bytes(out, errors='surrogate_or_strict'), b'' - else: - err = self.invalid_request(obj) - return 1, b'', to_bytes(err, errors='surrogate_or_strict') - - except (ValueError, TypeError): - # to_ele operates on native strings - request = to_native(request, errors='surrogate_or_strict') - - req = to_ele(request) - if req is None: - return 1, b'', b'unable to parse request' - - try: - reply = self._manager.rpc(req) - except RPCError as exc: - return 1, b'', to_bytes(to_xml(exc.xml), errors='surrogate_or_strict') - - return 0, to_bytes(reply.data_xml, errors='surrogate_or_strict'), b'' - - def put_file(self, in_path, out_path): - """Transfer a file from local to remote""" - pass - - def fetch_file(self, in_path, out_path): - """Fetch a file from remote to local""" - pass diff --git a/lib/ansible/plugins/connection/network_cli.py b/lib/ansible/plugins/connection/network_cli.py index 9ac535558d6..2bc44ef9bc6 100644 --- a/lib/ansible/plugins/connection/network_cli.py +++ b/lib/ansible/plugins/connection/network_cli.py @@ -47,6 +47,7 @@ DOCUMENTATION = """ import json import logging import re +import os import signal import socket import traceback @@ -57,9 +58,11 @@ from ansible import constants as C from ansible.errors import AnsibleConnectionFailure from ansible.module_utils.six import BytesIO, binary_type from ansible.module_utils._text import to_bytes, to_text -from ansible.plugins.loader import cliconf_loader, terminal_loader -from ansible.plugins.connection.paramiko_ssh import Connection as _Connection -from ansible.utils.jsonrpc import Rpc +from ansible.plugins.loader import cliconf_loader, terminal_loader, connection_loader +from ansible.plugins.connection import ConnectionBase +from ansible.plugins.connection.local import Connection as LocalConnection +from ansible.plugins.connection.paramiko_ssh import Connection as ParamikoSshConnection +from ansible.utils.path import unfrackpath, makedirs_safe try: from __main__ import display @@ -68,31 +71,73 @@ except ImportError: display = Display() -class Connection(Rpc, _Connection): +class Connection(ConnectionBase): ''' CLI (shell) SSH connections on Paramiko ''' transport = 'network_cli' has_pipelining = True + force_persistence = True def __init__(self, play_context, new_stdin, *args, **kwargs): super(Connection, self).__init__(play_context, new_stdin, *args, **kwargs) - self._terminal = None - self._cliconf = None - self._shell = None + self.ssh = None + self._ssh_shell = None + self._matched_prompt = None self._matched_pattern = None self._last_response = None self._history = list() - self._play_context = play_context - if play_context.verbosity > 3: + self._local = LocalConnection(play_context, new_stdin, *args, **kwargs) + + self._terminal = None + self._cliconf = None + + if self._play_context.verbosity > 3: logging.getLogger('paramiko').setLevel(logging.DEBUG) + # reconstruct the socket_path and set instance values accordingly + self._update_connection_state() + + def __getattr__(self, name): + try: + return self.__dict__[name] + except KeyError: + if name.startswith('_'): + raise AttributeError("'%s' object has no attribute '%s'" % (self.__class__.__name__, name)) + return getattr(self._cliconf, name) + + def exec_command(self, cmd, in_data=None, sudoable=True): + # this try..except block is just to handle the transition to supporting + # network_cli as a toplevel connection. Once connection=local is gone, + # this block can be removed as well and all calls passed directly to + # the local connection + if self._ssh_shell: + try: + cmd = json.loads(to_text(cmd, errors='surrogate_or_strict')) + kwargs = {'command': to_bytes(cmd['command'], errors='surrogate_or_strict')} + for key in ('prompts', 'answer', 'send_only'): + if key in cmd: + kwargs[key] = to_bytes(cmd[key], errors='surrogate_or_strict') + return self.send(**kwargs) + except ValueError: + cmd = to_bytes(cmd, errors='surrogate_or_strict') + return self.send(command=cmd) + + else: + return self._local.exec_command(cmd, in_data, sudoable) + + def put_file(self, in_path, out_path): + return self._local.put_file(in_path, out_path) + + def fetch_file(self, in_path, out_path): + return self._local.fetch_file(in_path, out_path) + def update_play_context(self, play_context): """Updates the play context information for the connection""" - display.display('updating play_context for connection', log_only=True) + display.vvvv('updating play_context for connection', host=self._play_context.remote_addr) if self._play_context.become is False and play_context.become is True: auth_pass = play_context.become_pass @@ -104,17 +149,22 @@ class Connection(Rpc, _Connection): self._play_context = play_context def _connect(self): - """Connections to the device and sets the terminal type""" + ''' + Connects to the remote device and starts the terminal + ''' + if self.connected: + return if self._play_context.password and not self._play_context.private_key_file: C.PARAMIKO_LOOK_FOR_KEYS = False - super(Connection, self)._connect() + ssh = ParamikoSshConnection(self._play_context, '/dev/null')._connect() + self.ssh = ssh.ssh - display.display('ssh connection done, setting terminal', log_only=True) + display.vvvv('ssh connection done, setting terminal', host=self._play_context.remote_addr) - self._shell = self.ssh.invoke_shell() - self._shell.settimeout(self._play_context.timeout) + self._ssh_shell = self.ssh.invoke_shell() + self._ssh_shell.settimeout(self._play_context.timeout) network_os = self._play_context.network_os if not network_os: @@ -127,53 +177,83 @@ class Connection(Rpc, _Connection): if not self._terminal: raise AnsibleConnectionFailure('network os %s is not supported' % network_os) - display.display('loaded terminal plugin for network_os %s' % network_os, log_only=True) + display.vvvv('loaded terminal plugin for network_os %s' % network_os, host=self._play_context.remote_addr) self._cliconf = cliconf_loader.get(network_os, self) if self._cliconf: - self._rpc.add(self._cliconf) - display.display('loaded cliconf plugin for network_os %s' % network_os, log_only=True) + display.vvvv('loaded cliconf plugin for network_os %s' % network_os, host=self._play_context.remote_addr) else: - display.display('unable to load cliconf for network_os %s' % network_os) + display.vvvv('unable to load cliconf for network_os %s' % network_os) self.receive() - display.display('firing event: on_open_shell()', log_only=True) + display.vvvv('firing event: on_open_shell()', host=self._play_context.remote_addr) self._terminal.on_open_shell() - if getattr(self._play_context, 'become', None): - display.display('firing event: on_authorize', log_only=True) + if self._play_context.become and self._play_context.become_method == 'enable': + display.vvvv('firing event: on_authorize', host=self._play_context.remote_addr) auth_pass = self._play_context.become_pass self._terminal.on_authorize(passwd=auth_pass) + display.vvvv('ssh connection has completed successfully', host=self._play_context.remote_addr) self._connected = True - display.display('ssh connection has completed successfully', log_only=True) - def close(self): - """Close the active connection to the device - """ - display.display("closing ssh connection to device", log_only=True) - if self._shell: - display.display("firing event: on_close_shell()", log_only=True) - self._terminal.on_close_shell() - self._shell.close() - self._shell = None - display.display("cli session is now closed", log_only=True) + return self + + def _update_connection_state(self): + ''' + Reconstruct the connection socket_path and check if it exists + + If the socket path exists then the connection is active and set + both the _socket_path value to the path and the _connected value + to True. If the socket path doesn't exist, leave the socket path + value to None and the _connected value to False + ''' + ssh = connection_loader.get('ssh', class_only=True) + cp = ssh._create_control_path(self._play_context.remote_addr, self._play_context.port, self._play_context.remote_user) + + tmp_path = unfrackpath(C.PERSISTENT_CONTROL_PATH_DIR) + socket_path = unfrackpath(cp % dict(directory=tmp_path)) - super(Connection, self).close() + if os.path.exists(socket_path): + self._connected = True + self._socket_path = socket_path - self._connected = False - display.display("ssh connection has been closed successfully", log_only=True) + def reset(self): + ''' + Reset the connection + ''' + if self._socket_path: + display.vvvv('resetting persistent connection for socket_path %s' % self._socket_path, host=self._play_context.remote_addr) + self.shutdown() + + def close(self): + ''' + Close the active connection to the device + ''' + # only close the connection if its connected. + if self._connected: + display.debug("closing ssh connection to device") + if self._ssh_shell: + display.debug("firing event: on_close_shell()") + self._terminal.on_close_shell() + self._ssh_shell.close() + self._ssh_shell = None + display.debug("cli session is now closed") + self._connected = False + display.debug("ssh connection has been closed successfully") def receive(self, command=None, prompts=None, answer=None): - """Handles receiving of output from command""" + ''' + Handles receiving of output from command + ''' recv = BytesIO() handled = False self._matched_prompt = None while True: - data = self._shell.recv(256) + data = self._ssh_shell.recv(256) recv.write(data) offset = recv.tell() - 256 if recv.tell() > 256 else 0 @@ -190,25 +270,30 @@ class Connection(Rpc, _Connection): return self._sanitize(resp, command) def send(self, command, prompts=None, answer=None, send_only=False): - """Sends the command to the device in the opened shell""" + ''' + Sends the command to the device in the opened shell + ''' try: self._history.append(command) - self._shell.sendall(b'%s\r' % command) + self._ssh_shell.sendall(b'%s\r' % command) if send_only: return - return self.receive(command, prompts, answer) + response = self.receive(command, prompts, answer) + return to_text(response, errors='surrogate_or_strict') except (socket.timeout, AttributeError): - display.display(traceback.format_exc(), log_only=True) + display.vvvv(traceback.format_exc(), host=self._play_context.remote_addr) raise AnsibleConnectionFailure("timeout trying to send command: %s" % command.strip()) def _strip(self, data): - """Removes ANSI codes from device response""" + ''' + Removes ANSI codes from device response + ''' for regex in self._terminal.ansi_re: data = regex.sub(b'', data) return data def _handle_prompt(self, resp, prompts, answer): - """ + ''' Matches the command prompt and responds :arg resp: Byte string containing the raw response from the remote @@ -216,17 +301,19 @@ class Connection(Rpc, _Connection): :arg answer: Byte string to send back to the remote if we find a prompt. A carriage return is automatically appended to this string. :returns: True if a prompt was found in ``resp``. False otherwise - """ + ''' prompts = [re.compile(r, re.I) for r in prompts] for regex in prompts: match = regex.search(resp) if match: - self._shell.sendall(b'%s\r' % answer) + self._ssh_shell.sendall(b'%s\r' % answer) return True return False def _sanitize(self, resp, command=None): - """Removes elements from the response before returning to the caller""" + ''' + Removes elements from the response before returning to the caller + ''' cleaned = [] for line in resp.splitlines(): if (command and line.strip() == command.strip()) or self._matched_prompt.strip() in line: @@ -235,7 +322,8 @@ class Connection(Rpc, _Connection): return b'\n'.join(cleaned).strip() def _find_prompt(self, response): - """Searches the buffered response for a matching command prompt""" + '''Searches the buffered response for a matching command prompt + ''' errored_response = None is_error_message = False for regex in self._terminal.terminal_stderr_re: @@ -264,64 +352,3 @@ class Connection(Rpc, _Connection): raise AnsibleConnectionFailure(errored_response) return False - - def alarm_handler(self, signum, frame): - """Alarm handler raised in case of command timeout """ - display.display('closing shell due to sigalarm', log_only=True) - self.close() - - def exec_command(self, cmd): - """Executes the cmd on in the shell and returns the output - - The method accepts three forms of cmd. The first form is as a byte - string that represents the command to be executed in the shell. The - second form is as a utf8 JSON byte string with additional keywords. - The third form is a json-rpc (2.0) - Keywords supported for cmd: - :command: the command string to execute - :prompt: the expected prompt generated by executing command. - This can be a string or a list of strings - :answer: the string to respond to the prompt with - :sendonly: bool to disable waiting for response - :arg cmd: the byte string that represents the command to be executed - which can be a single command or a json encoded string. - :returns: a tuple of (return code, stdout, stderr). The return - code is an integer and stdout and stderr are byte strings - """ - try: - obj = json.loads(to_text(cmd, errors='surrogate_or_strict')) - except (ValueError, TypeError): - obj = {'command': to_bytes(cmd.strip(), errors='surrogate_or_strict')} - - obj = dict((k, to_bytes(v, errors='surrogate_or_strict', nonstring='passthru')) for k, v in obj.items()) - if 'prompt' in obj: - if isinstance(obj['prompt'], binary_type): - # Prompt was a string - obj['prompt'] = [obj['prompt']] - elif not isinstance(obj['prompt'], Sequence): - # Convert nonstrings into byte strings (to_bytes(5) => b'5') - if obj['prompt'] is not None: - obj['prompt'] = [to_bytes(obj['prompt'], errors='surrogate_or_strict')] - else: - # Prompt was a Sequence of strings. Make sure they're byte strings - obj['prompt'] = [to_bytes(p, errors='surrogate_or_strict') for p in obj['prompt'] if p is not None] - - if 'jsonrpc' in obj: - if self._cliconf: - out = self._exec_rpc(obj) - else: - out = self.internal_error("cliconf is not supported for network_os %s" % self._play_context.network_os) - return 0, to_bytes(out, errors='surrogate_or_strict'), b'' - - if obj['command'] == b'prompt()': - return 0, self._matched_prompt, b'' - - try: - if not signal.getsignal(signal.SIGALRM): - signal.signal(signal.SIGALRM, self.alarm_handler) - signal.alarm(self._play_context.timeout) - out = self.send(obj['command'], obj.get('prompt'), obj.get('answer'), obj.get('sendonly')) - signal.alarm(0) - return 0, out, b'' - except (AnsibleConnectionFailure, ValueError) as exc: - return 1, b'', to_bytes(exc) diff --git a/lib/ansible/plugins/connection/paramiko_ssh.py b/lib/ansible/plugins/connection/paramiko_ssh.py index e6ecaac0556..1ae551c7777 100644 --- a/lib/ansible/plugins/connection/paramiko_ssh.py +++ b/lib/ansible/plugins/connection/paramiko_ssh.py @@ -100,6 +100,7 @@ with warnings.catch_warnings(): class MyAddPolicy(object): """ Based on AutoAddPolicy in paramiko so we can determine when keys are added + and also prompt for input. Policy for automatically adding the hostname and new host key to the @@ -114,8 +115,13 @@ class MyAddPolicy(object): if all((C.HOST_KEY_CHECKING, not C.PARAMIKO_HOST_KEY_AUTO_ADD)): + fingerprint = hexlify(key.get_fingerprint()) + ktype = key.get_name() + if C.USE_PERSISTENT_CONNECTIONS: - raise AnsibleConnectionFailure('rejected %s host key for host %s: %s' % (key.get_name(), hostname, hexlify(key.get_fingerprint()))) + # don't print the prompt string since the user cannot respond + # to the question anyway + raise AnsibleError(AUTHENTICITY_MSG[1:92] % (hostname, ktype, fingerprint)) self.connection.connection_lock() @@ -125,9 +131,6 @@ class MyAddPolicy(object): # clear out any premature input on sys.stdin tcflush(sys.stdin, TCIFLUSH) - fingerprint = hexlify(key.get_fingerprint()) - ktype = key.get_name() - inp = input(AUTHENTICITY_MSG % (hostname, ktype, fingerprint)) sys.stdin = old_stdin diff --git a/lib/ansible/plugins/connection/persistent.py b/lib/ansible/plugins/connection/persistent.py index 1e51eceed29..7ade4f4bc78 100644 --- a/lib/ansible/plugins/connection/persistent.py +++ b/lib/ansible/plugins/connection/persistent.py @@ -1,4 +1,4 @@ -# (c) 2017 Red Hat Inc. +# 2017 Red Hat Inc. # (c) 2017 Ansible Project # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) @@ -13,15 +13,19 @@ DOCUMENTATION = """ - This is a helper plugin to allow making other connections persistent. version_added: "2.3" """ - -import re import os +import sys import pty +import json import subprocess -from ansible.module_utils._text import to_bytes, to_text -from ansible.module_utils.six.moves import cPickle +from ansible import constants as C +from ansible.plugins.loader import connection_loader from ansible.plugins.connection import ConnectionBase +from ansible.module_utils._text import to_text +from ansible.module_utils.six.moves import cPickle +from ansible.module_utils.connection import Connection as SocketConnection +from ansible.errors import AnsibleError try: from __main__ import display @@ -40,8 +44,38 @@ class Connection(ConnectionBase): self._connected = True return self - def _do_it(self, action): + def exec_command(self, cmd, in_data=None, sudoable=True): + display.vvvv('exec_command(), socket_path=%s' % self.socket_path, host=self._play_context.remote_addr) + connection = SocketConnection(self.socket_path) + out = connection.exec_command(cmd, in_data=in_data, sudoable=sudoable) + return 0, out, '' + + def put_file(self, in_path, out_path): + pass + + def fetch_file(self, in_path, out_path): + pass + + def close(self): + self._connected = False + + def run(self): + """Returns the path of the persistent connection socket. + + Attempts to ensure (within playcontext.timeout seconds) that the + socket path exists. If the path exists (or the timeout has expired), + returns the socket path. + """ + display.vvvv('starting connection from persistent connection plugin', host=self._play_context.remote_addr) + socket_path = self._start_connection() + display.vvvv('local domain socket path is %s' % socket_path, host=self._play_context.remote_addr) + setattr(self, '_socket_path', socket_path) + return socket_path + def _start_connection(self): + ''' + Starts the persistent connection + ''' master, slave = pty.openpty() p = subprocess.Popen(["ansible-connection"], stdin=slave, stdout=subprocess.PIPE, stderr=subprocess.PIPE) stdin = os.fdopen(master, 'wb', 0) @@ -56,40 +90,23 @@ class Connection(ConnectionBase): stdin.write(src) stdin.write(b'\n#END_INIT#\n') - stdin.write(to_bytes(action)) - stdin.write(b'\n\n') (stdout, stderr) = p.communicate() stdin.close() - return (p.returncode, stdout, stderr) - - def exec_command(self, cmd, in_data=None, sudoable=True): - super(Connection, self).exec_command(cmd, in_data=in_data, sudoable=sudoable) - return self._do_it('EXEC: ' + cmd) - - def put_file(self, in_path, out_path): - super(Connection, self).put_file(in_path, out_path) - self._do_it('PUT: %s %s' % (in_path, out_path)) - - def fetch_file(self, in_path, out_path): - super(Connection, self).fetch_file(in_path, out_path) - self._do_it('FETCH: %s %s' % (in_path, out_path)) - - def close(self): - self._connected = False + if p.returncode == 0: + result = json.loads(to_text(stdout, errors='surrogate_then_replace')) + else: + result = json.loads(to_text(stderr, errors='surrogate_then_replace')) - def run(self): - """Returns the path of the persistent connection socket. + if 'messages' in result: + for msg in result.get('messages'): + display.vvvv('%s' % msg, host=self._play_context.remote_addr) - Attempts to ensure (within playcontext.timeout seconds) that the - socket path exists. If the path exists (or the timeout has expired), - returns the socket path. - """ - socket_path = None - rc, out, err = self._do_it('RUN:') - match = re.search(br"#SOCKET_PATH#: (\S+)", out) - if match: - socket_path = to_text(match.group(1).strip(), errors='surrogate_or_strict') + if 'error' in result: + if self._play_context.verbosity > 2: + msg = "The full traceback is:\n" + result['exception'] + display.display(result['exception'], color=C.COLOR_ERROR) + raise AnsibleError(result['error']) - return socket_path + return result['socket_path'] diff --git a/lib/ansible/plugins/terminal/__init__.py b/lib/ansible/plugins/terminal/__init__.py index e52ae622737..ea499cc518a 100644 --- a/lib/ansible/plugins/terminal/__init__.py +++ b/lib/ansible/plugins/terminal/__init__.py @@ -56,20 +56,12 @@ class TerminalBase(with_metaclass(ABCMeta, object)): self._connection = connection def _exec_cli_command(self, cmd, check_rc=True): - """ - Executes a CLI command on the device - - :arg cmd: Byte string consisting of the command to execute - :kwarg check_rc: If True, the default, raise an - :exc:`AnsibleConnectionFailure` if the return code from the - command is nonzero - :returns: A tuple of return code, stdout, and stderr from running the - command. stdout and stderr are both byte strings. - """ - rc, out, err = self._connection.exec_command(cmd) - if check_rc and rc != 0: - raise AnsibleConnectionFailure(err) - return rc, out, err + ''' + Executes the CLI command on the remote device and returns the output + + :arg cmd: Byte string command to be executed + ''' + return self._connection.exec_command(cmd) def _get_prompt(self): """ @@ -77,9 +69,8 @@ class TerminalBase(with_metaclass(ABCMeta, object)): :returns: A byte string of the prompt """ - for cmd in (b'\n', b'prompt()'): - rc, out, err = self._exec_cli_command(cmd) - return out + self._exec_cli_command(b'\n') + return self._connection._matched_prompt def on_open_shell(self): """Called after the SSH session is established diff --git a/lib/ansible/plugins/terminal/junos.py b/lib/ansible/plugins/terminal/junos.py index a284220dc42..63d1059025f 100644 --- a/lib/ansible/plugins/terminal/junos.py +++ b/lib/ansible/plugins/terminal/junos.py @@ -36,21 +36,21 @@ except ImportError: class TerminalModule(TerminalBase): terminal_stdout_re = [ - re.compile(r"[\r\n]?[\w+\-\.:\/\[\]]+(?:\([^\)]+\)){,3}(?:>|#) ?$|%"), + re.compile(br"[\r\n]?[\w+\-\.:\/\[\]]+(?:\([^\)]+\)){,3}(?:>|#) ?$|%"), ] terminal_stderr_re = [ - re.compile(r"unknown command"), - re.compile(r"syntax error,") + re.compile(br"unknown command"), + re.compile(br"syntax error,") ] def on_open_shell(self): try: prompt = self._get_prompt() - if prompt.strip().endswith('%'): + if prompt.strip().endswith(b'%'): display.vvv('starting cli', self._connection._play_context.remote_addr) self._exec_cli_command('cli') - for c in ['set cli timestamp disable', 'set cli screen-length 0', 'set cli screen-width 1024']: + for c in (b'set cli timestamp disable', b'set cli screen-length 0', b'set cli screen-width 1024'): self._exec_cli_command(c) except AnsibleConnectionFailure: raise AnsibleConnectionFailure('unable to set terminal parameters') diff --git a/lib/ansible/utils/jsonrpc.py b/lib/ansible/utils/jsonrpc.py index 5a4843e6563..12ef7210cb8 100644 --- a/lib/ansible/utils/jsonrpc.py +++ b/lib/ansible/utils/jsonrpc.py @@ -1,28 +1,16 @@ -# -# (c) 2016 Red Hat Inc. -# -# This file is part of Ansible -# -# Ansible is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# Ansible is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with Ansible. If not, see . +# (c) 2017, Peter Sprygada +# (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import (absolute_import, division, print_function) __metaclass__ = type import json import traceback +from ansible import constants as C from ansible.module_utils._text import to_text + try: from __main__ import display except ImportError: @@ -30,13 +18,13 @@ except ImportError: display = Display() -class Rpc: +class JsonRpcServer(object): + + _objects = set() - def __init__(self, *args, **kwargs): - self._rpc = set() - super(Rpc, self).__init__(*args, **kwargs) + def handle_request(self, request): + request = json.loads(to_text(request, errors='surrogate_then_replace')) - def _exec_rpc(self, request): method = request.get('method') if method.startswith('rpc.') or method.startswith('_'): @@ -45,6 +33,7 @@ class Rpc: params = request.get('params') setattr(self, '_identifier', request.get('id')) + args = [] kwargs = {} @@ -54,10 +43,15 @@ class Rpc: kwargs = params rpc_method = None - for obj in self._rpc: - rpc_method = getattr(obj, method, None) - if rpc_method: - break + + if method in ('shutdown', 'reset'): + rpc_method = getattr(self, 'shutdown') + + else: + for obj in self._objects: + rpc_method = getattr(obj, method, None) + if rpc_method: + break if not rpc_method: error = self.method_not_found() @@ -66,7 +60,7 @@ class Rpc: try: result = rpc_method(*args, **kwargs) except Exception as exc: - display.display(traceback.format_exc(), log_only=True) + display.vvv(traceback.format_exc()) error = self.internal_error(data=to_text(exc, errors='surrogate_then_replace')) response = json.dumps(error) else: @@ -78,8 +72,12 @@ class Rpc: response = json.dumps(response) delattr(self, '_identifier') + return response + def register(self, obj): + self._objects.add(obj) + def header(self): return {'jsonrpc': '2.0', 'id': self._identifier} diff --git a/test/units/plugins/action/test_action.py b/test/units/plugins/action/test_action.py index 6adaff32e2b..9f460b065b5 100644 --- a/test/units/plugins/action/test_action.py +++ b/test/units/plugins/action/test_action.py @@ -405,6 +405,7 @@ class TestActionBase(unittest.TestCase): mock_connection = MagicMock() mock_connection.build_module_command.side_effect = build_module_command + mock_connection.socket_path = None mock_connection._shell.get_remote_filename.return_value = 'copy.py' mock_connection._shell.join_path.side_effect = os.path.join diff --git a/test/units/plugins/connection/test_connection.py b/test/units/plugins/connection/test_connection.py index 99e114e6ac6..0705e495d7c 100644 --- a/test/units/plugins/connection/test_connection.py +++ b/test/units/plugins/connection/test_connection.py @@ -37,6 +37,7 @@ from ansible.plugins.connection.paramiko_ssh import Connection as ParamikoConnec from ansible.plugins.connection.ssh import Connection as SSHConnection from ansible.plugins.connection.docker import Connection as DockerConnection # from ansible.plugins.connection.winrm import Connection as WinRmConnection +from ansible.plugins.connection.netconf import Connection as NetconfConnection from ansible.plugins.connection.network_cli import Connection as NetworkCliConnection @@ -140,7 +141,9 @@ class TestConnectionBaseClass(unittest.TestCase): def test_network_cli_connection_module(self): self.assertIsInstance(NetworkCliConnection(self.play_context, self.in_stream), NetworkCliConnection) - self.assertIsInstance(NetworkCliConnection(self.play_context, self.in_stream), ParamikoConnection) + + def test_netconf_connection_module(self): + self.assertIsInstance(NetconfConnection(self.play_context, self.in_stream), NetconfConnection) def test_check_password_prompt(self): local = ( diff --git a/test/units/plugins/connection/test_netconf.py b/test/units/plugins/connection/test_netconf.py index 9552245a2bc..af767ecf36d 100644 --- a/test/units/plugins/connection/test_netconf.py +++ b/test/units/plugins/connection/test_netconf.py @@ -69,9 +69,9 @@ class TestNetconfConnectionClass(unittest.TestCase): conn = netconf.Connection(pc, new_stdin) - mock_manager = MagicMock(name='self._manager.connect') - type(mock_manager).session_id = PropertyMock(return_value='123456789') - netconf.manager.connect.return_value = mock_manager + mock_manager = MagicMock() + mock_manager.session_id = '123456789' + netconf.manager.connect = MagicMock(return_value=mock_manager) conn._play_context.network_os = 'default' rc, out, err = conn._connect() @@ -88,22 +88,16 @@ class TestNetconfConnectionClass(unittest.TestCase): conn = netconf.Connection(pc, new_stdin) conn._connected = True - mock_manager = MagicMock(name='self._manager') - mock_reply = MagicMock(name='reply') type(mock_reply).data_xml = PropertyMock(return_value='') + mock_manager = MagicMock(name='self._manager') mock_manager.rpc.return_value = mock_reply - conn._manager = mock_manager - rc, out, err = conn.exec_command('') - - netconf.to_ele.assert_called_with('') + out = conn.exec_command('') - self.assertEqual(0, rc) - self.assertEqual(b'', out) - self.assertEqual(b'', err) + self.assertEqual('', out) def test_netconf_exec_command_invalid_request(self): pc = PlayContext() @@ -112,10 +106,11 @@ class TestNetconfConnectionClass(unittest.TestCase): conn = netconf.Connection(pc, new_stdin) conn._connected = True + mock_manager = MagicMock(name='self._manager') + conn._manager = mock_manager + netconf.to_ele.return_value = None - rc, out, err = conn.exec_command('test string') + out = conn.exec_command('test string') - self.assertEqual(1, rc) - self.assertEqual(b'', out) - self.assertEqual(b'unable to parse request', err) + self.assertEqual('unable to parse request', out) diff --git a/test/units/plugins/connection/test_network_cli.py b/test/units/plugins/connection/test_network_cli.py index cb93aceb9a0..899d1dc353c 100644 --- a/test/units/plugins/connection/test_network_cli.py +++ b/test/units/plugins/connection/test_network_cli.py @@ -35,7 +35,7 @@ from ansible.plugins.connection import network_cli class TestConnectionClass(unittest.TestCase): - @patch("ansible.plugins.connection.network_cli._Connection._connect") + @patch("ansible.plugins.connection.network_cli.ParamikoSshConnection._connect") def test_network_cli__connect_error(self, mocked_super): pc = PlayContext() new_stdin = StringIO() @@ -47,7 +47,7 @@ class TestConnectionClass(unittest.TestCase): pc.network_os = None self.assertRaises(AnsibleConnectionFailure, conn._connect) - @patch("ansible.plugins.connection.network_cli._Connection._connect") + @patch("ansible.plugins.connection.network_cli.ParamikoSshConnection._connect") def test_network_cli__invalid_os(self, mocked_super): pc = PlayContext() new_stdin = StringIO() @@ -60,7 +60,7 @@ class TestConnectionClass(unittest.TestCase): self.assertRaises(AnsibleConnectionFailure, conn._connect) @patch("ansible.plugins.connection.network_cli.terminal_loader") - @patch("ansible.plugins.connection.network_cli._Connection._connect") + @patch("ansible.plugins.connection.network_cli.ParamikoSshConnection._connect") def test_network_cli__connect(self, mocked_super, mocked_terminal_loader): pc = PlayContext() new_stdin = StringIO() @@ -70,22 +70,21 @@ class TestConnectionClass(unittest.TestCase): conn.ssh = MagicMock() conn.receive = MagicMock() - - mock_terminal = MagicMock() - conn._terminal = mock_terminal + conn._terminal = MagicMock() conn._connect() self.assertTrue(conn._terminal.on_open_shell.called) self.assertFalse(conn._terminal.on_authorize.called) conn._play_context.become = True + conn._play_context.become_method = 'enable' conn._play_context.become_pass = 'password' + conn._connected = False conn._connect() - conn._terminal.on_authorize.assert_called_with(passwd='password') - @patch("ansible.plugins.connection.network_cli._Connection.close") + @patch("ansible.plugins.connection.network_cli.ParamikoSshConnection.close") def test_network_cli_close(self, mocked_super): pc = PlayContext() new_stdin = StringIO() @@ -93,20 +92,14 @@ class TestConnectionClass(unittest.TestCase): terminal = MagicMock(supports_multiplexing=False) conn._terminal = terminal - - conn.close() - - conn._shell = MagicMock() + conn._ssh_shell = MagicMock() + conn._connected = True conn.close() self.assertTrue(terminal.on_close_shell.called) + self.assertIsNone(conn._ssh_shell) - terminal.supports_multiplexing = True - - conn.close() - self.assertIsNone(conn._shell) - - @patch("ansible.plugins.connection.network_cli._Connection._connect") + @patch("ansible.plugins.connection.network_cli.ParamikoSshConnection._connect") def test_network_cli_exec_command(self, mocked_super): pc = PlayContext() new_stdin = StringIO() @@ -114,23 +107,17 @@ class TestConnectionClass(unittest.TestCase): mock_send = MagicMock(return_value=b'command response') conn.send = mock_send + conn._ssh_shell = MagicMock() # test sending a single command and converting to dict - rc, out, err = conn.exec_command('command') + out = conn.exec_command('command') self.assertEqual(out, b'command response') - mock_send.assert_called_with(b'command', None, None, None) + mock_send.assert_called_with(command=b'command') # test sending a json string - rc, out, err = conn.exec_command(json.dumps({'command': 'command'})) - self.assertEqual(out, b'command response') - mock_send.assert_called_with(b'command', None, None, None) - - conn._shell = MagicMock() - - # test _shell already open - rc, out, err = conn.exec_command('command') + out = conn.exec_command(json.dumps({'command': 'command'})) self.assertEqual(out, b'command response') - mock_send.assert_called_with(b'command', None, None, None) + mock_send.assert_called_with(command=b'command') def test_network_cli_send(self): pc = PlayContext() @@ -142,7 +129,7 @@ class TestConnectionClass(unittest.TestCase): conn._terminal = mock__terminal mock__shell = MagicMock() - conn._shell = mock__shell + conn._ssh_shell = mock__shell response = b"""device#command command response @@ -155,7 +142,7 @@ class TestConnectionClass(unittest.TestCase): output = conn.send(b'command', None, None, None) mock__shell.sendall.assert_called_with(b'command\r') - self.assertEqual(output, b'command response') + self.assertEqual(output, 'command response') mock__shell.reset_mock() mock__shell.recv.return_value = b"ERROR: error message device#"