Connection plugins network_cli and netconf (#32521)

* implements jsonrpc message passing for ansible-connection

* implements more generic mechanism for persistent connections
* starts persistent connection in task_executor if enabled and supported
* supports using network_cli as top level connection plugin
* enhances logging for persistent connection to stdout

* Update action plugins

* Fix Python3 RPC

* Fix Junos bytes<-->str issues

* supports using netconf as top level connection plugin

* Error message when running netconf on an unsupported platform
* Update tests

* Fix `authorize: yes` for `connection: local`

* Handle potentially JSON data in terminal

* Add clarifying detail if possible on ConnectionError
pull/32770/head
Nathaniel Case 7 years ago committed by GitHub
parent 897b31f249
commit 9c0275a879
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,23 +1,6 @@
#!/usr/bin/env python
# (c) 2017, Ansible, Inc. <support@ansible.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
########################################################
# Copyright: (c) 2017, Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
@ -36,91 +19,68 @@ import socket
import sys
import time
import traceback
import datetime
import errno
import json
from ansible import constants as C
from ansible.module_utils._text import to_bytes, to_native, to_text
from ansible.module_utils.six import PY3
from ansible.module_utils.six.moves import cPickle
from ansible.module_utils.connection import send_data, recv_data
from ansible.module_utils.service import fork_process
from ansible.playbook.play_context import PlayContext
from ansible.plugins.loader import connection_loader
from ansible.utils.path import unfrackpath, makedirs_safe
from ansible.errors import AnsibleConnectionFailure
from ansible.errors import AnsibleError
from ansible.utils.display import Display
from ansible.utils.jsonrpc import JsonRpcServer
def do_fork():
class ConnectionProcess(object):
'''
Does the required double fork for a daemon process. Based on
http://code.activestate.com/recipes/66012-fork-a-daemon-process-on-unix/
The connection process wraps around a Connection object that manages
the connection to a remote device that persists over the playbook
'''
try:
pid = os.fork()
if pid > 0:
return pid
# This is done as a 'good practice' for daemons, but we need to keep the cwd
# leaving it here as a note that we KNOW its good practice but are not doing it on purpose.
# os.chdir("/")
os.setsid()
os.umask(0)
try:
pid = os.fork()
if pid > 0:
sys.exit(0)
if C.DEFAULT_LOG_PATH != '':
out_file = open(C.DEFAULT_LOG_PATH, 'ab+')
err_file = open(C.DEFAULT_LOG_PATH, 'ab+', 0)
else:
out_file = open('/dev/null', 'ab+')
err_file = open('/dev/null', 'ab+', 0)
os.dup2(out_file.fileno(), sys.stdout.fileno())
os.dup2(err_file.fileno(), sys.stderr.fileno())
os.close(sys.stdin.fileno())
return pid
except OSError as e:
sys.exit(1)
except OSError as e:
sys.exit(1)
class Server():
def __init__(self, socket_path, play_context):
self.socket_path = socket_path
def __init__(self, fd, play_context, socket_path, original_path):
self.play_context = play_context
self.socket_path = socket_path
self.original_path = original_path
display.display(
'creating new control socket for host %s:%s as user %s' %
(play_context.remote_addr, play_context.port, play_context.remote_user),
log_only=True
)
display.display('control socket path is %s' % socket_path, log_only=True)
display.display('current working directory is %s' % os.getcwd(), log_only=True)
self.fd = fd
self.exception = None
self._start_time = datetime.datetime.now()
self.srv = JsonRpcServer()
self.sock = None
display.display("using connection plugin %s" % self.play_context.connection, log_only=True)
def start(self):
try:
messages = list()
result = {}
self.connection = connection_loader.get(play_context.connection, play_context, sys.stdin)
self.connection._connect()
messages.append('control socket path is %s' % self.socket_path)
if not self.connection.connected:
raise AnsibleConnectionFailure('unable to connect to remote host %s' % self._play_context.remote_addr)
# If this is a relative path (~ gets expanded later) then plug the
# key's path on to the directory we originally came from, so we can
# find it now that our cwd is /
if self.play_context.private_key_file and self.play_context.private_key_file[0] not in '~/':
self.play_context.private_key_file = os.path.join(self.original_path, self.play_context.private_key_file)
connection_time = datetime.datetime.now() - self._start_time
display.display('connection established to %s in %s' % (play_context.remote_addr, connection_time), log_only=True)
self.connection = connection_loader.get(self.play_context.connection, self.play_context, '/dev/null')
self.connection._connect()
self.srv.register(self.connection)
messages.append('connection to remote device started successfully')
self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.socket.bind(self.socket_path)
self.socket.listen(1)
display.display('local socket is set to listening', log_only=True)
self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.sock.bind(self.socket_path)
self.sock.listen(1)
messages.append('local domain socket listeners started successfully')
except Exception as exc:
result['error'] = to_text(exc)
result['exception'] = traceback.format_exc()
finally:
result['messages'] = messages
self.fd.write(json.dumps(result))
self.fd.close()
def run(self):
try:
@ -129,53 +89,36 @@ class Server():
signal.signal(signal.SIGTERM, self.handler)
signal.alarm(C.PERSISTENT_CONNECT_TIMEOUT)
(s, addr) = self.socket.accept()
display.display('incoming request accepted on persistent socket', log_only=True)
self.exception = None
(s, addr) = self.sock.accept()
signal.alarm(0)
signal.signal(signal.SIGALRM, self.command_timeout)
while True:
data = recv_data(s)
if not data:
break
signal.signal(signal.SIGALRM, self.command_timeout)
signal.alarm(self.play_context.timeout)
op = to_text(data.split(b':')[0])
display.display('socket operation is %s' % op, log_only=True)
method = getattr(self, 'do_%s' % op, None)
rc = 255
stdout = stderr = ''
if not method:
stderr = 'Invalid action specified'
else:
rc, stdout, stderr = method(data)
signal.alarm(self.connection._play_context.timeout)
resp = self.srv.handle_request(data)
signal.alarm(0)
display.display('socket operation completed with rc %s' % rc, log_only=True)
send_data(s, to_bytes(rc))
send_data(s, to_bytes(stdout))
send_data(s, to_bytes(stderr))
send_data(s, to_bytes(resp))
s.close()
except Exception as e:
# socket.accept() will raise EINTR if the socket.close() is called
if hasattr(e, 'errno'):
if e.errno != errno.EINTR:
display.display(traceback.format_exc(), log_only=True)
self.exception = traceback.format_exc()
else:
self.exception = traceback.format_exc()
finally:
# when done, close the connection properly and cleanup
# the socket file so it can be recreated
self.shutdown()
end_time = datetime.datetime.now()
delta = end_time - self._start_time
display.display('shutdown local socket, connection was active for %s secs' % delta, log_only=True)
def connect_timeout(self, signum, frame):
display.display('persistent connection idle timeout triggered, timeout value is %s secs' % C.PERSISTENT_CONNECT_TIMEOUT, log_only=True)
@ -190,25 +133,25 @@ class Server():
self.shutdown()
def shutdown(self):
display.display('shutdown persistent connection requested', log_only=True)
""" Shuts down the local domain socket
"""
if not os.path.exists(self.socket_path):
display.display('persistent connection is not active', log_only=True)
return
try:
if self.socket:
display.display('closing local listener', log_only=True)
self.socket.close()
if self.sock:
self.sock.close()
if self.connection:
display.display('closing the connection', log_only=True)
self.connection.close()
except:
pass
finally:
if os.path.exists(self.socket_path):
display.display('removing the local control socket', log_only=True)
os.remove(self.socket_path)
setattr(self.connection, '_socket_path', None)
setattr(self.connection, '_connected', False)
display.display('shutdown complete', log_only=True)
@ -262,6 +205,13 @@ def communicate(sock, data):
def main():
""" Called to initiate the connect to the remote device
"""
rc = 0
result = {}
messages = list()
socket_path = None
# Need stdin as a byte stream
if PY3:
stdin = sys.stdin.buffer
@ -270,36 +220,39 @@ def main():
try:
# read the play context data via stdin, which means depickling it
# FIXME: as noted above, we will probably need to deserialize the
# connection loader here as well at some point, otherwise this
# won't find role- or playbook-based connection plugins
cur_line = stdin.readline()
init_data = b''
while cur_line.strip() != b'#END_INIT#':
if cur_line == b'':
raise Exception("EOF found before init data was complete")
init_data += cur_line
cur_line = stdin.readline()
if PY3:
pc_data = cPickle.loads(init_data, encoding='bytes')
else:
pc_data = cPickle.loads(init_data)
pc = PlayContext()
pc.deserialize(pc_data)
play_context = PlayContext()
play_context.deserialize(pc_data)
except Exception as e:
# FIXME: better error message/handling/logging
sys.stderr.write(traceback.format_exc())
sys.exit("FAIL: %s" % e)
rc = 1
result.update({
'error': to_text(e),
'exception': traceback.format_exc()
})
if rc == 0:
ssh = connection_loader.get('ssh', class_only=True)
cp = ssh._create_control_path(pc.remote_addr, pc.port, pc.remote_user, pc.connection)
cp = ssh._create_control_path(play_context.remote_addr, play_context.port, play_context.remote_user, play_context.connection)
# create the persistent connection dir if need be and create the paths
# which we will be using later
tmp_path = unfrackpath(C.PERSISTENT_CONTROL_PATH_DIR)
makedirs_safe(tmp_path)
lock_path = unfrackpath("%s/.ansible_pc_lock" % tmp_path)
socket_path = unfrackpath(cp % dict(directory=tmp_path))
@ -308,78 +261,50 @@ def main():
fcntl.lockf(lock_fd, fcntl.LOCK_EX)
if not os.path.exists(socket_path):
pid = do_fork()
messages.append('local domain socket does not exist, starting it')
original_path = os.getcwd()
r, w = os.pipe()
pid = fork_process()
if pid == 0:
rc = 0
try:
server = Server(socket_path, pc)
except AnsibleConnectionFailure as exc:
display.display('connecting to host %s returned an error' % pc.remote_addr, log_only=True)
display.display(str(exc), log_only=True)
rc = 1
os.close(r)
wfd = os.fdopen(w, 'w')
process = ConnectionProcess(wfd, play_context, socket_path, original_path)
process.start()
except Exception as exc:
display.display('failed to create control socket for host %s' % pc.remote_addr, log_only=True)
display.display(traceback.format_exc(), log_only=True)
messages.append(traceback.format_exc())
rc = 1
fcntl.lockf(lock_fd, fcntl.LOCK_UN)
os.close(lock_fd)
if rc == 0:
server.run()
sys.exit(rc)
else:
display.display('re-using existing socket for %s@%s:%s' % (pc.remote_user, pc.remote_addr, pc.port), log_only=True)
fcntl.lockf(lock_fd, fcntl.LOCK_UN)
os.close(lock_fd)
timeout = pc.timeout
while bool(timeout):
if os.path.exists(socket_path):
display.vvvv('connected to local socket in %s' % (pc.timeout - timeout), pc.remote_addr)
break
time.sleep(1)
timeout -= 1
else:
raise AnsibleConnectionFailure('timeout waiting for local socket', pc.remote_addr)
# now connect to the daemon process
# FIXME: if the socket file existed but the daemonized process was killed,
# the connection will timeout here. Need to make this more resilient.
while True:
data = stdin.readline()
if data == b'':
break
if data.strip() == b'':
continue
if rc == 0:
process.run()
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sys.exit(rc)
connect_retry_timeout = C.PERSISTENT_CONNECT_RETRY_TIMEOUT
while bool(connect_retry_timeout):
try:
sock.connect(socket_path)
break
except socket.error:
time.sleep(1)
connect_retry_timeout -= 1
else:
display.display('connect retry timeout expired, unable to connect to control socket', pc.remote_addr, pc.remote_user, log_only=True)
display.display('persistent_connect_retry_timeout is %s secs' % (C.PERSISTENT_CONNECT_RETRY_TIMEOUT), pc.remote_addr, pc.remote_user, log_only=True)
sys.stderr.write('failed to connect to control socket')
sys.exit(255)
os.close(w)
rfd = os.fdopen(r, 'r')
data = json.loads(rfd.read())
messages.extend(data.pop('messages'))
result.update(data)
# send the play_context back into the connection so the connection
# can handle any privilege escalation activities
pc_data = b'CONTEXT: %s' % init_data
communicate(sock, pc_data)
rc, stdout, stderr = communicate(sock, data.strip())
else:
messages.append('found existing local domain socket, using it!')
sys.stdout.write(to_native(stdout))
sys.stderr.write(to_native(stderr))
result.update({
'messages': messages,
'socket_path': socket_path
})
sock.close()
break
if 'exception' in result:
rc = 1
sys.stderr.write(json.dumps(result))
else:
rc = 0
sys.stdout.write(json.dumps(result))
sys.exit(rc)

@ -1,31 +1,21 @@
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
# (c) 2017 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import os
import pty
import time
import json
import subprocess
import traceback
from ansible import constants as C
from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleConnectionFailure, AnsibleActionFail, AnsibleActionSkip
from ansible.executor.task_result import TaskResult
from ansible.module_utils.six import iteritems, string_types, binary_type
from ansible.module_utils.six.moves import cPickle
from ansible.module_utils._text import to_text
from ansible.playbook.conditional import Conditional
from ansible.playbook.task import Task
@ -490,6 +480,8 @@ class TaskExecutor:
not getattr(self._connection, 'connected', False) or
self._play_context.remote_addr != self._connection._play_context.remote_addr):
self._connection = self._get_connection(variables=variables, templar=templar)
if getattr(self._connection, '_socket_path'):
variables['ansible_socket'] = self._connection._socket_path
# only template the vars if the connection actually implements set_host_overrides
# NB: this is expensive, and should be removed once connection-specific vars are being handled by play_context
sho_impl = getattr(type(self._connection), 'set_host_overrides', None)
@ -736,11 +728,6 @@ class TaskExecutor:
if isinstance(i, string_types) and i.startswith("ansible_") and i.endswith("_interpreter"):
variables[i] = delegated_vars[i]
# if using persistent paramiko connections (or the action has set the FORCE_PERSISTENT_CONNECTION attribute to True),
# then we use the persistent connection plugion. Otherwise load the requested connection plugin
if C.USE_PERSISTENT_CONNECTIONS or getattr(self, 'FORCE_PERSISTENT_CONNECTION', False):
conn_type = 'persistent'
else:
conn_type = self._play_context.connection
connection = self._shared_loader_obj.connection_loader.get(conn_type, self._play_context, self._new_stdin)
@ -749,6 +736,13 @@ class TaskExecutor:
self._play_context.set_options_from_plugin(connection)
if any(((connection.supports_persistence and C.USE_PERSISTENT_CONNECTIONS), connection.force_persistence)):
display.vvvv('attempting to start connection', host=self._play_context.remote_addr)
display.vvvv('using connection plugin %s' % connection.transport, host=self._play_context.remote_addr)
socket_path = self._start_connection()
display.vvvv('local domain socket path is %s' % socket_path, host=self._play_context.remote_addr)
setattr(connection, '_socket_path', socket_path)
return connection
def _get_action_handler(self, connection, templar):
@ -780,3 +774,42 @@ class TaskExecutor:
raise AnsibleError("the handler '%s' was not found" % handler_name)
return handler
def _start_connection(self):
'''
Starts the persistent connection
'''
master, slave = pty.openpty()
p = subprocess.Popen(["ansible-connection"], stdin=slave, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdin = os.fdopen(master, 'wb', 0)
os.close(slave)
# Need to force a protocol that is compatible with both py2 and py3.
# That would be protocol=2 or less.
# Also need to force a protocol that excludes certain control chars as
# stdin in this case is a pty and control chars will cause problems.
# that means only protocol=0 will work.
src = cPickle.dumps(self._play_context.serialize(), protocol=0)
stdin.write(src)
stdin.write(b'\n#END_INIT#\n')
(stdout, stderr) = p.communicate()
stdin.close()
if p.returncode == 0:
result = json.loads(stdout)
else:
result = json.loads(stderr)
if 'messages' in result:
for msg in result.get('messages'):
display.vvvv('%s' % msg, host=self._play_context.remote_addr)
if 'error' in result:
if self._play_context.verbosity > 2:
msg = "The full traceback is:\n" + result['exception']
display.display(result['exception'], color=C.COLOR_ERROR)
raise AnsibleError(result['error'])
return result['socket_path']

@ -27,6 +27,7 @@
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import json
import socket
import struct
import traceback
@ -35,6 +36,7 @@ import uuid
from functools import partial
from ansible.module_utils._text import to_bytes, to_native, to_text
from ansible.module_utils.six import iteritems
def send_data(s, data):
@ -61,23 +63,14 @@ def recv_data(s):
def exec_command(module, command):
connection = Connection(module._socket_path)
try:
sf = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sf.connect(module._socket_path)
data = "EXEC: %s" % command
send_data(sf, to_bytes(data.strip()))
rc = int(recv_data(sf), 10)
stdout = recv_data(sf)
stderr = recv_data(sf)
except socket.error as e:
sf.close()
module.fail_json(msg='unable to connect to socket', err=to_native(e), exception=traceback.format_exc())
sf.close()
return rc, to_native(stdout, errors='surrogate_or_strict'), to_native(stderr, errors='surrogate_or_strict')
out = connection.exec_command(command)
except ConnectionError as exc:
code = getattr(exc, 'code', 1)
message = getattr(exc, 'err', exc)
return code, '', to_text(message, errors='surrogate_then_replace')
return 0, out, ''
def request_builder(method, *args, **kwargs):
@ -91,10 +84,19 @@ def request_builder(method, *args, **kwargs):
return req
class ConnectionError(Exception):
def __init__(self, message, *args, **kwargs):
super(ConnectionError, self).__init__(message)
for k, v in iteritems(kwargs):
setattr(self, k, v)
class Connection:
def __init__(self, module):
self._module = module
def __init__(self, socket_path):
assert socket_path is not None, 'socket_path must be a value'
self.socket_path = socket_path
def __getattr__(self, name):
try:
@ -116,30 +118,40 @@ class Connection:
req = request_builder(name, *args, **kwargs)
reqid = req['id']
if not self._module._socket_path:
self._module.fail_json(msg='provider support not available for this host')
if not os.path.exists(self._module._socket_path):
self._module.fail_json(msg='provider socket does not exist, is the provider running?')
if not os.path.exists(self.socket_path):
raise ConnectionError('socket_path does not exist or cannot be found')
try:
data = self._module.jsonify(req)
rc, out, err = exec_command(self._module, data)
data = json.dumps(req)
out = self.send(data)
response = json.loads(out)
except socket.error as e:
self._module.fail_json(msg='unable to connect to socket', err=to_native(e),
exception=traceback.format_exc())
try:
response = self._module.from_json(to_text(out, errors='surrogate_then_replace'))
except ValueError as exc:
self._module.fail_json(msg=to_text(exc, errors='surrogate_then_replace'))
raise ConnectionError('unable to connect to socket', err=to_text(e, errors='surrogate_then_replace'), exception=traceback.format_exc())
if response['id'] != reqid:
self._module.fail_json(msg='invalid id received')
raise ConnectionError('invalid json-rpc id received')
if 'error' in response:
msg = response['error'].get('data') or response['error']['message']
self._module.fail_json(msg=to_text(msg, errors='surrogate_then_replace'))
err = response.get('error')
msg = err.get('data') or err['message']
code = err['code']
raise ConnectionError(to_text(msg, errors='surrogate_then_replace'), code=code)
return response['result']
def send(self, data):
try:
sf = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sf.connect(self.socket_path)
send_data(sf, to_bytes(data))
response = recv_data(sf)
except socket.error as e:
sf.close()
raise ConnectionError('unable to connect to socket', err=to_text(e, errors='surrogate_then_replace'), exception=traceback.format_exc())
sf.close()
return to_text(response, errors='surrogate_or_strict')

@ -27,6 +27,7 @@
#
from contextlib import contextmanager
from ansible.module_utils._text import to_bytes, to_text
from ansible.module_utils.connection import exec_command
try:
@ -38,7 +39,7 @@ NS_MAP = {'nc': "urn:ietf:params:xml:ns:netconf:base:1.0"}
def send_request(module, obj, check_rc=True, ignore_warning=True):
request = tostring(obj)
request = to_text(tostring(obj), errors='surrogate_or_strict')
rc, out, err = exec_command(module, request)
if rc != 0 and check_rc:
error_root = fromstring(err)
@ -59,7 +60,7 @@ def send_request(module, obj, check_rc=True, ignore_warning=True):
else:
module.fail_json(msg=str(err))
return warnings
return fromstring(out)
return fromstring(to_bytes(out, errors='surrogate_or_strict'))
def children(root, iterable):

@ -91,33 +91,14 @@ def fail_if_missing(module, found, service, msg=''):
module.fail_json(msg='Could not find the requested service %s: %s' % (service, msg))
def daemonize(module, cmd):
def fork_process():
'''
Execute a command while detaching as a daemon, returns rc, stdout, and stderr.
:arg module: is an AnsibleModule object, used for it's utility methods
:arg cmd: is a list or string representing the command and options to run
This is complex because daemonization is hard for people.
What we do is daemonize a part of this module, the daemon runs the command,
picks up the return code and output, and returns it to the main process.
This function performs the double fork process to detach from the
parent process and execute.
'''
# init some vars
chunk = 4096 # FIXME: pass in as arg?
errors = 'surrogate_or_strict'
# start it!
try:
pipe = os.pipe()
pid = os.fork()
except OSError:
module.fail_json(msg="Error while attempting to fork: %s", exception=traceback.format_exc())
# we don't do any locking as this should be a unique module/process
if pid == 0:
os.close(pipe[0])
# Set stdin/stdout/stderr to /dev/null
fd = os.open(os.devnull, os.O_RDWR)
@ -140,7 +121,7 @@ def daemonize(module, cmd):
# get new process session and detach
sid = os.setsid()
if sid == -1:
module.fail_json(msg="Unable to detach session while daemonizing")
raise Exception("Unable to detach session while daemonizing")
# avoid possible problems with cwd being removed
os.chdir("/")
@ -149,6 +130,38 @@ def daemonize(module, cmd):
if pid > 0:
os._exit(0)
return pid
def daemonize(module, cmd):
'''
Execute a command while detaching as a daemon, returns rc, stdout, and stderr.
:arg module: is an AnsibleModule object, used for it's utility methods
:arg cmd: is a list or string representing the command and options to run
This is complex because daemonization is hard for people.
What we do is daemonize a part of this module, the daemon runs the command,
picks up the return code and output, and returns it to the main process.
'''
# init some vars
chunk = 4096 # FIXME: pass in as arg?
errors = 'surrogate_or_strict'
# start it!
try:
pipe = os.pipe()
pid = fork_process()
except OSError:
module.fail_json(msg="Error while attempting to fork: %s", exception=traceback.format_exc())
except Exception as exc:
module.fail_json(msg=to_text(exc), exception=traceback.format_exc())
# we don't do any locking as this should be a unique module/process
if pid == 0:
os.close(pipe[0])
# if command is string deal with py2 vs py3 conversions for shlex
if not isinstance(cmd, list):
if PY2:

@ -427,6 +427,8 @@ class PlayContext(Base):
# if the final connection type is local, reset the remote_user value to that of the currently logged in user
# this ensures any become settings are obeyed correctly
# we store original in 'connection_user' for use of network/other modules that fallback to it as login user
# connection_user to be deprecated once connection=local is removed for
# network modules
if new_info.connection == 'local':
if not new_info.connection_user:
new_info.connection_user = new_info.remote_user

@ -36,6 +36,7 @@ from ansible.module_utils.json_utils import _filter_non_json_lines
from ansible.module_utils.six import binary_type, string_types, text_type, iteritems, with_metaclass
from ansible.module_utils.six.moves import shlex_quote
from ansible.module_utils._text import to_bytes, to_native, to_text
from ansible.module_utils.connection import Connection
from ansible.parsing.utils.jsonify import jsonify
from ansible.release import __version__
from ansible.utils.unsafe_proxy import wrap_var
@ -604,6 +605,8 @@ class ActionBase(with_metaclass(ABCMeta, object)):
module_args['_ansible_selinux_special_fs'] = C.DEFAULT_SELINUX_SPECIAL_FS
# give the module the socket for persistent connections
module_args['_ansible_socket'] = getattr(self._connection, 'socket_path')
if not module_args['_ansible_socket']:
module_args['_ansible_socket'] = task_vars.get('ansible_socket')
# make sure all commands use the designated shell executable
@ -818,6 +821,7 @@ class ActionBase(with_metaclass(ABCMeta, object)):
same_user = self._play_context.become_user == self._play_context.remote_user
if sudoable and self._play_context.become and (allow_same_user or not same_user):
display.debug("_low_level_execute_command(): using become for this command")
if self._connection.transport != 'network_cli' and self._play_context.become_method != 'enable':
cmd = self._play_context.make_become_cmd(cmd, executable=executable)
if self._connection.allow_executable:

@ -40,16 +40,10 @@ class ActionModule(_ActionModule):
provider = load_provider(eos_provider_spec, self._task.args)
transport = provider['transport'] or 'cli'
if self._play_context.connection != 'local' and transport == 'cli':
return dict(
failed=True,
msg='invalid connection specified, expected connection=local, '
'got %s' % self._play_context.connection
)
display.vvvv('connection transport is %s' % transport, self._play_context.remote_addr)
if transport == 'cli':
if self._play_context.connection == 'local':
pc = copy.deepcopy(self._play_context)
pc.connection = 'network_cli'
pc.network_os = 'eos'
@ -60,6 +54,8 @@ class ActionModule(_ActionModule):
pc.private_key_file = provider['ssh_keyfile'] or self._play_context.private_key_file
pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT)
pc.become = provider['authorize'] or False
if pc.become:
pc.become_method = 'enable'
pc.become_pass = provider['auth_pass']
display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr)
@ -72,14 +68,6 @@ class ActionModule(_ActionModule):
'msg': 'unable to open shell. Please see: ' +
'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'}
# make sure we are in the right cli context which should be
# enable mode and not config module
rc, out, err = connection.exec_command('prompt()')
while '(config' in str(out):
display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr)
connection.exec_command('exit')
rc, out, err = connection.exec_command('prompt()')
task_vars['ansible_socket'] = socket_path
else:

@ -38,13 +38,7 @@ class ActionModule(_ActionModule):
def run(self, tmp=None, task_vars=None):
if self._play_context.connection != 'local':
return dict(
failed=True,
msg='invalid connection specified, expected connection=local, '
'got %s' % self._play_context.connection
)
if self._play_context.connection == 'local':
provider = load_provider(ios_provider_spec, self._task.args)
pc = copy.deepcopy(self._play_context)
@ -57,6 +51,8 @@ class ActionModule(_ActionModule):
pc.private_key_file = provider['ssh_keyfile'] or self._play_context.private_key_file
pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT)
pc.become = provider['authorize'] or False
if pc.become:
pc.become_method = 'enable'
pc.become_pass = provider['auth_pass']
display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr)
@ -69,14 +65,6 @@ class ActionModule(_ActionModule):
'msg': 'unable to open shell. Please see: ' +
'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'}
# make sure we are in the right cli context which should be
# enable mode and not config module
rc, out, err = connection.exec_command('prompt()')
while str(out).strip().endswith(')#'):
display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr)
connection.exec_command('exit')
rc, out, err = connection.exec_command('prompt()')
task_vars['ansible_socket'] = socket_path
if self._play_context.become_method == 'enable':

@ -38,13 +38,7 @@ class ActionModule(_ActionModule):
def run(self, tmp=None, task_vars=None):
if self._play_context.connection != 'local':
return dict(
failed=True,
msg='invalid connection specified, expected connection=local, '
'got %s' % self._play_context.connection
)
if self._play_context.connection == 'local':
provider = load_provider(iosxr_provider_spec, self._task.args)
pc = copy.deepcopy(self._play_context)
@ -66,14 +60,6 @@ class ActionModule(_ActionModule):
'msg': 'unable to open shell. Please see: ' +
'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'}
# make sure we are in the right cli context which should be
# enable mode and not config module
rc, out, err = connection.exec_command('prompt()')
while str(out).strip().endswith(')#'):
display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr)
connection.exec_command('exit')
rc, out, err = connection.exec_command('prompt()')
task_vars['ansible_socket'] = socket_path
result = super(ActionModule, self).run(tmp, task_vars)

@ -38,14 +38,6 @@ except ImportError:
class ActionModule(_ActionModule):
def run(self, tmp=None, task_vars=None):
if self._play_context.connection != 'local':
return dict(
failed=True,
msg='invalid connection specified, expected connection=local, '
'got %s' % self._play_context.connection
)
module = module_loader._load_module_source(self._task.action, module_loader.find_plugin(self._task.action))
if not getattr(module, 'USE_PERSISTENT_CONNECTION', False):
@ -72,6 +64,8 @@ class ActionModule(_ActionModule):
pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT)
display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr)
if self._play_context.connection == 'local':
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
socket_path = connection.run()

@ -37,13 +37,6 @@ except ImportError:
class ActionModule(ActionBase):
def run(self, tmp=None, task_vars=None):
if self._play_context.connection != 'local':
return dict(
failed=True,
msg='invalid connection specified, expected connection=local, '
'got %s' % self._play_context.connection
)
play_context = copy.deepcopy(self._play_context)
play_context.network_os = self._get_network_os(task_vars)
@ -74,6 +67,7 @@ class ActionModule(ActionBase):
play_context.become = self.provider['authorize'] or False
play_context.become_pass = self.provider['auth_pass']
if self._play_context.connection == 'local':
socket_path = self._start_connection(play_context)
task_vars['ansible_socket'] = socket_path

@ -40,16 +40,10 @@ class ActionModule(_ActionModule):
provider = load_provider(nxos_provider_spec, self._task.args)
transport = provider['transport'] or 'cli'
if self._play_context.connection != 'local' and transport == 'cli':
return dict(
failed=True,
msg='invalid connection specified, expected connection=local, '
'got %s' % self._play_context.connection
)
display.vvvv('connection transport is %s' % transport, self._play_context.remote_addr)
if transport == 'cli':
if self._play_context.connection == 'local':
pc = copy.deepcopy(self._play_context)
pc.connection = 'network_cli'
pc.network_os = 'nxos'
@ -60,6 +54,7 @@ class ActionModule(_ActionModule):
pc.private_key_file = provider['ssh_keyfile'] or self._play_context.private_key_file
pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT)
display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr)
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
socket_path = connection.run()
@ -69,14 +64,6 @@ class ActionModule(_ActionModule):
'msg': 'unable to open shell. Please see: ' +
'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'}
# make sure we are in the right cli context which should be
# enable mode and not config module
rc, out, err = connection.exec_command('prompt()')
while str(out).strip().endswith(')#'):
display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr)
connection.exec_command('exit')
rc, out, err = connection.exec_command('prompt()')
task_vars['ansible_socket'] = socket_path
else:

@ -37,13 +37,6 @@ except ImportError:
class ActionModule(_ActionModule):
def run(self, tmp=None, task_vars=None):
if self._play_context.connection != 'local':
return dict(
failed=True,
msg='invalid connection specified, expected connection=local, '
'got %s' % self._play_context.connection
)
provider = load_provider(vyos_provider_spec, self._task.args)
pc = copy.deepcopy(self._play_context)
@ -57,6 +50,8 @@ class ActionModule(_ActionModule):
pc.timeout = int(provider['timeout'] or C.PERSISTENT_COMMAND_TIMEOUT)
display.vvv('using connection plugin %s' % pc.connection, pc.remote_addr)
if self._play_context.connection == 'local':
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
socket_path = connection.run()
@ -66,14 +61,6 @@ class ActionModule(_ActionModule):
'msg': 'unable to open shell. Please see: ' +
'https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell'}
# make sure we are in the right cli context which should be
# enable mode and not config module
rc, out, err = connection.exec_command('prompt()')
while str(out).strip().endswith('#'):
display.vvvv('wrong context, sending exit to device', self._play_context.remote_addr)
connection.exec_command('exit')
rc, out, err = connection.exec_command('prompt()')
task_vars['ansible_socket'] = socket_path
result = super(ActionModule, self).run(tmp, task_vars)

@ -1,21 +1,7 @@
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
# (c) 2015 Toshio Kuratomi <tkuratomi@ansible.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
# (c) 2017, Peter Sprygada <psprygad@redhat.com>
# (c) 2017 Ansible Project
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
@ -69,6 +55,11 @@ class ConnectionBase(AnsiblePlugin):
module_implementation_preferences = ('',)
allow_executable = True
# the following control whether or not the connection supports the
# persistent connection framework or not
supports_persistence = False
force_persistence = False
def __init__(self, play_context, new_stdin, *args, **kwargs):
super(ConnectionBase, self).__init__()
@ -88,6 +79,8 @@ class ConnectionBase(AnsiblePlugin):
self.prompt = None
self._connected = False
self._socket_path = None
# load the shell plugin for this action/connection
if play_context.shell:
shell_type = play_context.shell
@ -110,6 +103,11 @@ class ConnectionBase(AnsiblePlugin):
'''Read-only property holding whether the connection to the remote host is active or closed.'''
return self._connected
@property
def socket_path(self):
'''Read-only property holding the connection socket path for this remote host'''
return self._socket_path
def _become_method_supported(self):
''' Checks if the current class supports this privilege escalation method '''

@ -71,15 +71,14 @@ DOCUMENTATION = """
import os
import logging
import json
from ansible import constants as C
from ansible.errors import AnsibleConnectionFailure, AnsibleError
from ansible.module_utils._text import to_bytes, to_native, to_text
from ansible.module_utils._text import to_bytes, to_native
from ansible.module_utils.parsing.convert_bool import BOOLEANS_TRUE
from ansible.plugins.loader import netconf_loader
from ansible.plugins.connection import ConnectionBase, ensure_connect
from ansible.utils.jsonrpc import Rpc
from ansible.plugins.connection.local import Connection as LocalConnection
try:
from ncclient import manager
@ -98,11 +97,12 @@ except ImportError:
logging.getLogger('ncclient').setLevel(logging.INFO)
class Connection(Rpc, ConnectionBase):
class Connection(ConnectionBase):
"""NetConf connections"""
transport = 'netconf'
has_pipelining = False
force_persistence = True
def __init__(self, play_context, new_stdin, *args, **kwargs):
super(Connection, self).__init__(play_context, new_stdin, *args, **kwargs)
@ -113,18 +113,50 @@ class Connection(Rpc, ConnectionBase):
self._manager = None
self._connected = False
self._local = LocalConnection(play_context, new_stdin, *args, **kwargs)
def exec_command(self, request, in_data=None, sudoable=True):
"""Sends the request to the node and returns the reply
The method accepts two forms of request. The first form is as a byte
string that represents xml string be send over netconf session.
The second form is a json-rpc (2.0) byte string.
"""
if self._manager:
# to_ele operates on native strings
request = to_ele(to_native(request, errors='surrogate_or_strict'))
if request is None:
return 'unable to parse request'
try:
reply = self._manager.rpc(request)
except RPCError as exc:
return to_xml(exc.xml)
return reply.data_xml
else:
return self._local.exec_command(request, in_data, sudoable)
def put_file(self, in_path, out_path):
"""Transfer a file from local to remote"""
return self._local.put_file(in_path, out_path)
def fetch_file(self, in_path, out_path):
"""Fetch a file from remote to local"""
return self._local.fetch_file(in_path, out_path)
def _connect(self):
super(Connection, self)._connect()
display.display('ssh connection done, stating ncclient', log_only=True)
display.display('ssh connection done, starting ncclient', log_only=True)
self.allow_agent = True
allow_agent = True
if self._play_context.password is not None:
self.allow_agent = False
allow_agent = False
self.key_filename = None
key_filename = None
if self._play_context.private_key_file:
self.key_filename = os.path.expanduser(self._play_context.private_key_file)
key_filename = os.path.expanduser(self._play_context.private_key_file)
network_os = self._play_context.network_os
@ -149,16 +181,18 @@ class Connection(Rpc, ConnectionBase):
port=self._play_context.port or 830,
username=self._play_context.remote_user,
password=self._play_context.password,
key_filename=str(self.key_filename),
key_filename=str(key_filename),
hostkey_verify=C.HOST_KEY_CHECKING,
look_for_keys=C.PARAMIKO_LOOK_FOR_KEYS,
allow_agent=self.allow_agent,
allow_agent=allow_agent,
timeout=self._play_context.timeout,
device_params={'name': network_os},
ssh_config=ssh_config
)
except SSHUnknownHostError as exc:
raise AnsibleConnectionFailure(str(exc))
except ImportError as exc:
raise AnsibleError("connection=netconf is not supported on {0}".format(network_os))
if not self._manager.connected:
return 1, b'', b'not connected'
@ -169,7 +203,6 @@ class Connection(Rpc, ConnectionBase):
self._netconf = netconf_loader.get(network_os, self)
if self._netconf:
self._rpc.add(self._netconf)
display.display('loaded netconf plugin for network_os %s' % network_os, log_only=True)
else:
display.display('unable to load netconf for network_os %s' % network_os)
@ -181,46 +214,3 @@ class Connection(Rpc, ConnectionBase):
self._manager.close_session()
self._connected = False
super(Connection, self).close()
@ensure_connect
def exec_command(self, request):
"""Sends the request to the node and returns the reply
The method accepts two forms of request. The first form is as a byte
string that represents xml string be send over netconf session.
The second form is a json-rpc (2.0) byte string.
"""
try:
obj = json.loads(to_text(request, errors='surrogate_or_strict'))
if 'jsonrpc' in obj:
if self._netconf:
out = self._exec_rpc(obj)
else:
out = self.internal_error("netconf plugin is not supported for network_os %s" % self._play_context.network_os)
return 0, to_bytes(out, errors='surrogate_or_strict'), b''
else:
err = self.invalid_request(obj)
return 1, b'', to_bytes(err, errors='surrogate_or_strict')
except (ValueError, TypeError):
# to_ele operates on native strings
request = to_native(request, errors='surrogate_or_strict')
req = to_ele(request)
if req is None:
return 1, b'', b'unable to parse request'
try:
reply = self._manager.rpc(req)
except RPCError as exc:
return 1, b'', to_bytes(to_xml(exc.xml), errors='surrogate_or_strict')
return 0, to_bytes(reply.data_xml, errors='surrogate_or_strict'), b''
def put_file(self, in_path, out_path):
"""Transfer a file from local to remote"""
pass
def fetch_file(self, in_path, out_path):
"""Fetch a file from remote to local"""
pass

@ -47,6 +47,7 @@ DOCUMENTATION = """
import json
import logging
import re
import os
import signal
import socket
import traceback
@ -57,9 +58,11 @@ from ansible import constants as C
from ansible.errors import AnsibleConnectionFailure
from ansible.module_utils.six import BytesIO, binary_type
from ansible.module_utils._text import to_bytes, to_text
from ansible.plugins.loader import cliconf_loader, terminal_loader
from ansible.plugins.connection.paramiko_ssh import Connection as _Connection
from ansible.utils.jsonrpc import Rpc
from ansible.plugins.loader import cliconf_loader, terminal_loader, connection_loader
from ansible.plugins.connection import ConnectionBase
from ansible.plugins.connection.local import Connection as LocalConnection
from ansible.plugins.connection.paramiko_ssh import Connection as ParamikoSshConnection
from ansible.utils.path import unfrackpath, makedirs_safe
try:
from __main__ import display
@ -68,31 +71,73 @@ except ImportError:
display = Display()
class Connection(Rpc, _Connection):
class Connection(ConnectionBase):
''' CLI (shell) SSH connections on Paramiko '''
transport = 'network_cli'
has_pipelining = True
force_persistence = True
def __init__(self, play_context, new_stdin, *args, **kwargs):
super(Connection, self).__init__(play_context, new_stdin, *args, **kwargs)
self._terminal = None
self._cliconf = None
self._shell = None
self.ssh = None
self._ssh_shell = None
self._matched_prompt = None
self._matched_pattern = None
self._last_response = None
self._history = list()
self._play_context = play_context
if play_context.verbosity > 3:
self._local = LocalConnection(play_context, new_stdin, *args, **kwargs)
self._terminal = None
self._cliconf = None
if self._play_context.verbosity > 3:
logging.getLogger('paramiko').setLevel(logging.DEBUG)
# reconstruct the socket_path and set instance values accordingly
self._update_connection_state()
def __getattr__(self, name):
try:
return self.__dict__[name]
except KeyError:
if name.startswith('_'):
raise AttributeError("'%s' object has no attribute '%s'" % (self.__class__.__name__, name))
return getattr(self._cliconf, name)
def exec_command(self, cmd, in_data=None, sudoable=True):
# this try..except block is just to handle the transition to supporting
# network_cli as a toplevel connection. Once connection=local is gone,
# this block can be removed as well and all calls passed directly to
# the local connection
if self._ssh_shell:
try:
cmd = json.loads(to_text(cmd, errors='surrogate_or_strict'))
kwargs = {'command': to_bytes(cmd['command'], errors='surrogate_or_strict')}
for key in ('prompts', 'answer', 'send_only'):
if key in cmd:
kwargs[key] = to_bytes(cmd[key], errors='surrogate_or_strict')
return self.send(**kwargs)
except ValueError:
cmd = to_bytes(cmd, errors='surrogate_or_strict')
return self.send(command=cmd)
else:
return self._local.exec_command(cmd, in_data, sudoable)
def put_file(self, in_path, out_path):
return self._local.put_file(in_path, out_path)
def fetch_file(self, in_path, out_path):
return self._local.fetch_file(in_path, out_path)
def update_play_context(self, play_context):
"""Updates the play context information for the connection"""
display.display('updating play_context for connection', log_only=True)
display.vvvv('updating play_context for connection', host=self._play_context.remote_addr)
if self._play_context.become is False and play_context.become is True:
auth_pass = play_context.become_pass
@ -104,17 +149,22 @@ class Connection(Rpc, _Connection):
self._play_context = play_context
def _connect(self):
"""Connections to the device and sets the terminal type"""
'''
Connects to the remote device and starts the terminal
'''
if self.connected:
return
if self._play_context.password and not self._play_context.private_key_file:
C.PARAMIKO_LOOK_FOR_KEYS = False
super(Connection, self)._connect()
ssh = ParamikoSshConnection(self._play_context, '/dev/null')._connect()
self.ssh = ssh.ssh
display.display('ssh connection done, setting terminal', log_only=True)
display.vvvv('ssh connection done, setting terminal', host=self._play_context.remote_addr)
self._shell = self.ssh.invoke_shell()
self._shell.settimeout(self._play_context.timeout)
self._ssh_shell = self.ssh.invoke_shell()
self._ssh_shell.settimeout(self._play_context.timeout)
network_os = self._play_context.network_os
if not network_os:
@ -127,53 +177,83 @@ class Connection(Rpc, _Connection):
if not self._terminal:
raise AnsibleConnectionFailure('network os %s is not supported' % network_os)
display.display('loaded terminal plugin for network_os %s' % network_os, log_only=True)
display.vvvv('loaded terminal plugin for network_os %s' % network_os, host=self._play_context.remote_addr)
self._cliconf = cliconf_loader.get(network_os, self)
if self._cliconf:
self._rpc.add(self._cliconf)
display.display('loaded cliconf plugin for network_os %s' % network_os, log_only=True)
display.vvvv('loaded cliconf plugin for network_os %s' % network_os, host=self._play_context.remote_addr)
else:
display.display('unable to load cliconf for network_os %s' % network_os)
display.vvvv('unable to load cliconf for network_os %s' % network_os)
self.receive()
display.display('firing event: on_open_shell()', log_only=True)
display.vvvv('firing event: on_open_shell()', host=self._play_context.remote_addr)
self._terminal.on_open_shell()
if getattr(self._play_context, 'become', None):
display.display('firing event: on_authorize', log_only=True)
if self._play_context.become and self._play_context.become_method == 'enable':
display.vvvv('firing event: on_authorize', host=self._play_context.remote_addr)
auth_pass = self._play_context.become_pass
self._terminal.on_authorize(passwd=auth_pass)
display.vvvv('ssh connection has completed successfully', host=self._play_context.remote_addr)
self._connected = True
display.display('ssh connection has completed successfully', log_only=True)
def close(self):
"""Close the active connection to the device
"""
display.display("closing ssh connection to device", log_only=True)
if self._shell:
display.display("firing event: on_close_shell()", log_only=True)
self._terminal.on_close_shell()
self._shell.close()
self._shell = None
display.display("cli session is now closed", log_only=True)
return self
super(Connection, self).close()
def _update_connection_state(self):
'''
Reconstruct the connection socket_path and check if it exists
If the socket path exists then the connection is active and set
both the _socket_path value to the path and the _connected value
to True. If the socket path doesn't exist, leave the socket path
value to None and the _connected value to False
'''
ssh = connection_loader.get('ssh', class_only=True)
cp = ssh._create_control_path(self._play_context.remote_addr, self._play_context.port, self._play_context.remote_user)
tmp_path = unfrackpath(C.PERSISTENT_CONTROL_PATH_DIR)
socket_path = unfrackpath(cp % dict(directory=tmp_path))
if os.path.exists(socket_path):
self._connected = True
self._socket_path = socket_path
def reset(self):
'''
Reset the connection
'''
if self._socket_path:
display.vvvv('resetting persistent connection for socket_path %s' % self._socket_path, host=self._play_context.remote_addr)
self.shutdown()
def close(self):
'''
Close the active connection to the device
'''
# only close the connection if its connected.
if self._connected:
display.debug("closing ssh connection to device")
if self._ssh_shell:
display.debug("firing event: on_close_shell()")
self._terminal.on_close_shell()
self._ssh_shell.close()
self._ssh_shell = None
display.debug("cli session is now closed")
self._connected = False
display.display("ssh connection has been closed successfully", log_only=True)
display.debug("ssh connection has been closed successfully")
def receive(self, command=None, prompts=None, answer=None):
"""Handles receiving of output from command"""
'''
Handles receiving of output from command
'''
recv = BytesIO()
handled = False
self._matched_prompt = None
while True:
data = self._shell.recv(256)
data = self._ssh_shell.recv(256)
recv.write(data)
offset = recv.tell() - 256 if recv.tell() > 256 else 0
@ -190,25 +270,30 @@ class Connection(Rpc, _Connection):
return self._sanitize(resp, command)
def send(self, command, prompts=None, answer=None, send_only=False):
"""Sends the command to the device in the opened shell"""
'''
Sends the command to the device in the opened shell
'''
try:
self._history.append(command)
self._shell.sendall(b'%s\r' % command)
self._ssh_shell.sendall(b'%s\r' % command)
if send_only:
return
return self.receive(command, prompts, answer)
response = self.receive(command, prompts, answer)
return to_text(response, errors='surrogate_or_strict')
except (socket.timeout, AttributeError):
display.display(traceback.format_exc(), log_only=True)
display.vvvv(traceback.format_exc(), host=self._play_context.remote_addr)
raise AnsibleConnectionFailure("timeout trying to send command: %s" % command.strip())
def _strip(self, data):
"""Removes ANSI codes from device response"""
'''
Removes ANSI codes from device response
'''
for regex in self._terminal.ansi_re:
data = regex.sub(b'', data)
return data
def _handle_prompt(self, resp, prompts, answer):
"""
'''
Matches the command prompt and responds
:arg resp: Byte string containing the raw response from the remote
@ -216,17 +301,19 @@ class Connection(Rpc, _Connection):
:arg answer: Byte string to send back to the remote if we find a prompt.
A carriage return is automatically appended to this string.
:returns: True if a prompt was found in ``resp``. False otherwise
"""
'''
prompts = [re.compile(r, re.I) for r in prompts]
for regex in prompts:
match = regex.search(resp)
if match:
self._shell.sendall(b'%s\r' % answer)
self._ssh_shell.sendall(b'%s\r' % answer)
return True
return False
def _sanitize(self, resp, command=None):
"""Removes elements from the response before returning to the caller"""
'''
Removes elements from the response before returning to the caller
'''
cleaned = []
for line in resp.splitlines():
if (command and line.strip() == command.strip()) or self._matched_prompt.strip() in line:
@ -235,7 +322,8 @@ class Connection(Rpc, _Connection):
return b'\n'.join(cleaned).strip()
def _find_prompt(self, response):
"""Searches the buffered response for a matching command prompt"""
'''Searches the buffered response for a matching command prompt
'''
errored_response = None
is_error_message = False
for regex in self._terminal.terminal_stderr_re:
@ -264,64 +352,3 @@ class Connection(Rpc, _Connection):
raise AnsibleConnectionFailure(errored_response)
return False
def alarm_handler(self, signum, frame):
"""Alarm handler raised in case of command timeout """
display.display('closing shell due to sigalarm', log_only=True)
self.close()
def exec_command(self, cmd):
"""Executes the cmd on in the shell and returns the output
The method accepts three forms of cmd. The first form is as a byte
string that represents the command to be executed in the shell. The
second form is as a utf8 JSON byte string with additional keywords.
The third form is a json-rpc (2.0)
Keywords supported for cmd:
:command: the command string to execute
:prompt: the expected prompt generated by executing command.
This can be a string or a list of strings
:answer: the string to respond to the prompt with
:sendonly: bool to disable waiting for response
:arg cmd: the byte string that represents the command to be executed
which can be a single command or a json encoded string.
:returns: a tuple of (return code, stdout, stderr). The return
code is an integer and stdout and stderr are byte strings
"""
try:
obj = json.loads(to_text(cmd, errors='surrogate_or_strict'))
except (ValueError, TypeError):
obj = {'command': to_bytes(cmd.strip(), errors='surrogate_or_strict')}
obj = dict((k, to_bytes(v, errors='surrogate_or_strict', nonstring='passthru')) for k, v in obj.items())
if 'prompt' in obj:
if isinstance(obj['prompt'], binary_type):
# Prompt was a string
obj['prompt'] = [obj['prompt']]
elif not isinstance(obj['prompt'], Sequence):
# Convert nonstrings into byte strings (to_bytes(5) => b'5')
if obj['prompt'] is not None:
obj['prompt'] = [to_bytes(obj['prompt'], errors='surrogate_or_strict')]
else:
# Prompt was a Sequence of strings. Make sure they're byte strings
obj['prompt'] = [to_bytes(p, errors='surrogate_or_strict') for p in obj['prompt'] if p is not None]
if 'jsonrpc' in obj:
if self._cliconf:
out = self._exec_rpc(obj)
else:
out = self.internal_error("cliconf is not supported for network_os %s" % self._play_context.network_os)
return 0, to_bytes(out, errors='surrogate_or_strict'), b''
if obj['command'] == b'prompt()':
return 0, self._matched_prompt, b''
try:
if not signal.getsignal(signal.SIGALRM):
signal.signal(signal.SIGALRM, self.alarm_handler)
signal.alarm(self._play_context.timeout)
out = self.send(obj['command'], obj.get('prompt'), obj.get('answer'), obj.get('sendonly'))
signal.alarm(0)
return 0, out, b''
except (AnsibleConnectionFailure, ValueError) as exc:
return 1, b'', to_bytes(exc)

@ -100,6 +100,7 @@ with warnings.catch_warnings():
class MyAddPolicy(object):
"""
Based on AutoAddPolicy in paramiko so we can determine when keys are added
and also prompt for input.
Policy for automatically adding the hostname and new host key to the
@ -114,8 +115,13 @@ class MyAddPolicy(object):
if all((C.HOST_KEY_CHECKING, not C.PARAMIKO_HOST_KEY_AUTO_ADD)):
fingerprint = hexlify(key.get_fingerprint())
ktype = key.get_name()
if C.USE_PERSISTENT_CONNECTIONS:
raise AnsibleConnectionFailure('rejected %s host key for host %s: %s' % (key.get_name(), hostname, hexlify(key.get_fingerprint())))
# don't print the prompt string since the user cannot respond
# to the question anyway
raise AnsibleError(AUTHENTICITY_MSG[1:92] % (hostname, ktype, fingerprint))
self.connection.connection_lock()
@ -125,9 +131,6 @@ class MyAddPolicy(object):
# clear out any premature input on sys.stdin
tcflush(sys.stdin, TCIFLUSH)
fingerprint = hexlify(key.get_fingerprint())
ktype = key.get_name()
inp = input(AUTHENTICITY_MSG % (hostname, ktype, fingerprint))
sys.stdin = old_stdin

@ -1,4 +1,4 @@
# (c) 2017 Red Hat Inc.
# 2017 Red Hat Inc.
# (c) 2017 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
@ -13,15 +13,19 @@ DOCUMENTATION = """
- This is a helper plugin to allow making other connections persistent.
version_added: "2.3"
"""
import re
import os
import sys
import pty
import json
import subprocess
from ansible.module_utils._text import to_bytes, to_text
from ansible.module_utils.six.moves import cPickle
from ansible import constants as C
from ansible.plugins.loader import connection_loader
from ansible.plugins.connection import ConnectionBase
from ansible.module_utils._text import to_text
from ansible.module_utils.six.moves import cPickle
from ansible.module_utils.connection import Connection as SocketConnection
from ansible.errors import AnsibleError
try:
from __main__ import display
@ -40,8 +44,38 @@ class Connection(ConnectionBase):
self._connected = True
return self
def _do_it(self, action):
def exec_command(self, cmd, in_data=None, sudoable=True):
display.vvvv('exec_command(), socket_path=%s' % self.socket_path, host=self._play_context.remote_addr)
connection = SocketConnection(self.socket_path)
out = connection.exec_command(cmd, in_data=in_data, sudoable=sudoable)
return 0, out, ''
def put_file(self, in_path, out_path):
pass
def fetch_file(self, in_path, out_path):
pass
def close(self):
self._connected = False
def run(self):
"""Returns the path of the persistent connection socket.
Attempts to ensure (within playcontext.timeout seconds) that the
socket path exists. If the path exists (or the timeout has expired),
returns the socket path.
"""
display.vvvv('starting connection from persistent connection plugin', host=self._play_context.remote_addr)
socket_path = self._start_connection()
display.vvvv('local domain socket path is %s' % socket_path, host=self._play_context.remote_addr)
setattr(self, '_socket_path', socket_path)
return socket_path
def _start_connection(self):
'''
Starts the persistent connection
'''
master, slave = pty.openpty()
p = subprocess.Popen(["ansible-connection"], stdin=slave, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdin = os.fdopen(master, 'wb', 0)
@ -56,40 +90,23 @@ class Connection(ConnectionBase):
stdin.write(src)
stdin.write(b'\n#END_INIT#\n')
stdin.write(to_bytes(action))
stdin.write(b'\n\n')
(stdout, stderr) = p.communicate()
stdin.close()
return (p.returncode, stdout, stderr)
def exec_command(self, cmd, in_data=None, sudoable=True):
super(Connection, self).exec_command(cmd, in_data=in_data, sudoable=sudoable)
return self._do_it('EXEC: ' + cmd)
def put_file(self, in_path, out_path):
super(Connection, self).put_file(in_path, out_path)
self._do_it('PUT: %s %s' % (in_path, out_path))
def fetch_file(self, in_path, out_path):
super(Connection, self).fetch_file(in_path, out_path)
self._do_it('FETCH: %s %s' % (in_path, out_path))
def close(self):
self._connected = False
if p.returncode == 0:
result = json.loads(to_text(stdout, errors='surrogate_then_replace'))
else:
result = json.loads(to_text(stderr, errors='surrogate_then_replace'))
def run(self):
"""Returns the path of the persistent connection socket.
if 'messages' in result:
for msg in result.get('messages'):
display.vvvv('%s' % msg, host=self._play_context.remote_addr)
Attempts to ensure (within playcontext.timeout seconds) that the
socket path exists. If the path exists (or the timeout has expired),
returns the socket path.
"""
socket_path = None
rc, out, err = self._do_it('RUN:')
match = re.search(br"#SOCKET_PATH#: (\S+)", out)
if match:
socket_path = to_text(match.group(1).strip(), errors='surrogate_or_strict')
if 'error' in result:
if self._play_context.verbosity > 2:
msg = "The full traceback is:\n" + result['exception']
display.display(result['exception'], color=C.COLOR_ERROR)
raise AnsibleError(result['error'])
return socket_path
return result['socket_path']

@ -56,20 +56,12 @@ class TerminalBase(with_metaclass(ABCMeta, object)):
self._connection = connection
def _exec_cli_command(self, cmd, check_rc=True):
"""
Executes a CLI command on the device
:arg cmd: Byte string consisting of the command to execute
:kwarg check_rc: If True, the default, raise an
:exc:`AnsibleConnectionFailure` if the return code from the
command is nonzero
:returns: A tuple of return code, stdout, and stderr from running the
command. stdout and stderr are both byte strings.
"""
rc, out, err = self._connection.exec_command(cmd)
if check_rc and rc != 0:
raise AnsibleConnectionFailure(err)
return rc, out, err
'''
Executes the CLI command on the remote device and returns the output
:arg cmd: Byte string command to be executed
'''
return self._connection.exec_command(cmd)
def _get_prompt(self):
"""
@ -77,9 +69,8 @@ class TerminalBase(with_metaclass(ABCMeta, object)):
:returns: A byte string of the prompt
"""
for cmd in (b'\n', b'prompt()'):
rc, out, err = self._exec_cli_command(cmd)
return out
self._exec_cli_command(b'\n')
return self._connection._matched_prompt
def on_open_shell(self):
"""Called after the SSH session is established

@ -36,21 +36,21 @@ except ImportError:
class TerminalModule(TerminalBase):
terminal_stdout_re = [
re.compile(r"[\r\n]?[\w+\-\.:\/\[\]]+(?:\([^\)]+\)){,3}(?:>|#) ?$|%"),
re.compile(br"[\r\n]?[\w+\-\.:\/\[\]]+(?:\([^\)]+\)){,3}(?:>|#) ?$|%"),
]
terminal_stderr_re = [
re.compile(r"unknown command"),
re.compile(r"syntax error,")
re.compile(br"unknown command"),
re.compile(br"syntax error,")
]
def on_open_shell(self):
try:
prompt = self._get_prompt()
if prompt.strip().endswith('%'):
if prompt.strip().endswith(b'%'):
display.vvv('starting cli', self._connection._play_context.remote_addr)
self._exec_cli_command('cli')
for c in ['set cli timestamp disable', 'set cli screen-length 0', 'set cli screen-width 1024']:
for c in (b'set cli timestamp disable', b'set cli screen-length 0', b'set cli screen-width 1024'):
self._exec_cli_command(c)
except AnsibleConnectionFailure:
raise AnsibleConnectionFailure('unable to set terminal parameters')

@ -1,28 +1,16 @@
#
# (c) 2016 Red Hat Inc.
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# (c) 2017, Peter Sprygada <psprygad@redhat.com>
# (c) 2017 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import json
import traceback
from ansible import constants as C
from ansible.module_utils._text import to_text
try:
from __main__ import display
except ImportError:
@ -30,13 +18,13 @@ except ImportError:
display = Display()
class Rpc:
class JsonRpcServer(object):
_objects = set()
def __init__(self, *args, **kwargs):
self._rpc = set()
super(Rpc, self).__init__(*args, **kwargs)
def handle_request(self, request):
request = json.loads(to_text(request, errors='surrogate_then_replace'))
def _exec_rpc(self, request):
method = request.get('method')
if method.startswith('rpc.') or method.startswith('_'):
@ -45,6 +33,7 @@ class Rpc:
params = request.get('params')
setattr(self, '_identifier', request.get('id'))
args = []
kwargs = {}
@ -54,7 +43,12 @@ class Rpc:
kwargs = params
rpc_method = None
for obj in self._rpc:
if method in ('shutdown', 'reset'):
rpc_method = getattr(self, 'shutdown')
else:
for obj in self._objects:
rpc_method = getattr(obj, method, None)
if rpc_method:
break
@ -66,7 +60,7 @@ class Rpc:
try:
result = rpc_method(*args, **kwargs)
except Exception as exc:
display.display(traceback.format_exc(), log_only=True)
display.vvv(traceback.format_exc())
error = self.internal_error(data=to_text(exc, errors='surrogate_then_replace'))
response = json.dumps(error)
else:
@ -78,8 +72,12 @@ class Rpc:
response = json.dumps(response)
delattr(self, '_identifier')
return response
def register(self, obj):
self._objects.add(obj)
def header(self):
return {'jsonrpc': '2.0', 'id': self._identifier}

@ -405,6 +405,7 @@ class TestActionBase(unittest.TestCase):
mock_connection = MagicMock()
mock_connection.build_module_command.side_effect = build_module_command
mock_connection.socket_path = None
mock_connection._shell.get_remote_filename.return_value = 'copy.py'
mock_connection._shell.join_path.side_effect = os.path.join

@ -37,6 +37,7 @@ from ansible.plugins.connection.paramiko_ssh import Connection as ParamikoConnec
from ansible.plugins.connection.ssh import Connection as SSHConnection
from ansible.plugins.connection.docker import Connection as DockerConnection
# from ansible.plugins.connection.winrm import Connection as WinRmConnection
from ansible.plugins.connection.netconf import Connection as NetconfConnection
from ansible.plugins.connection.network_cli import Connection as NetworkCliConnection
@ -140,7 +141,9 @@ class TestConnectionBaseClass(unittest.TestCase):
def test_network_cli_connection_module(self):
self.assertIsInstance(NetworkCliConnection(self.play_context, self.in_stream), NetworkCliConnection)
self.assertIsInstance(NetworkCliConnection(self.play_context, self.in_stream), ParamikoConnection)
def test_netconf_connection_module(self):
self.assertIsInstance(NetconfConnection(self.play_context, self.in_stream), NetconfConnection)
def test_check_password_prompt(self):
local = (

@ -69,9 +69,9 @@ class TestNetconfConnectionClass(unittest.TestCase):
conn = netconf.Connection(pc, new_stdin)
mock_manager = MagicMock(name='self._manager.connect')
type(mock_manager).session_id = PropertyMock(return_value='123456789')
netconf.manager.connect.return_value = mock_manager
mock_manager = MagicMock()
mock_manager.session_id = '123456789'
netconf.manager.connect = MagicMock(return_value=mock_manager)
conn._play_context.network_os = 'default'
rc, out, err = conn._connect()
@ -88,22 +88,16 @@ class TestNetconfConnectionClass(unittest.TestCase):
conn = netconf.Connection(pc, new_stdin)
conn._connected = True
mock_manager = MagicMock(name='self._manager')
mock_reply = MagicMock(name='reply')
type(mock_reply).data_xml = PropertyMock(return_value='<test/>')
mock_manager = MagicMock(name='self._manager')
mock_manager.rpc.return_value = mock_reply
conn._manager = mock_manager
rc, out, err = conn.exec_command('<test/>')
netconf.to_ele.assert_called_with('<test/>')
out = conn.exec_command('<test/>')
self.assertEqual(0, rc)
self.assertEqual(b'<test/>', out)
self.assertEqual(b'', err)
self.assertEqual('<test/>', out)
def test_netconf_exec_command_invalid_request(self):
pc = PlayContext()
@ -112,10 +106,11 @@ class TestNetconfConnectionClass(unittest.TestCase):
conn = netconf.Connection(pc, new_stdin)
conn._connected = True
mock_manager = MagicMock(name='self._manager')
conn._manager = mock_manager
netconf.to_ele.return_value = None
rc, out, err = conn.exec_command('test string')
out = conn.exec_command('test string')
self.assertEqual(1, rc)
self.assertEqual(b'', out)
self.assertEqual(b'unable to parse request', err)
self.assertEqual('unable to parse request', out)

@ -35,7 +35,7 @@ from ansible.plugins.connection import network_cli
class TestConnectionClass(unittest.TestCase):
@patch("ansible.plugins.connection.network_cli._Connection._connect")
@patch("ansible.plugins.connection.network_cli.ParamikoSshConnection._connect")
def test_network_cli__connect_error(self, mocked_super):
pc = PlayContext()
new_stdin = StringIO()
@ -47,7 +47,7 @@ class TestConnectionClass(unittest.TestCase):
pc.network_os = None
self.assertRaises(AnsibleConnectionFailure, conn._connect)
@patch("ansible.plugins.connection.network_cli._Connection._connect")
@patch("ansible.plugins.connection.network_cli.ParamikoSshConnection._connect")
def test_network_cli__invalid_os(self, mocked_super):
pc = PlayContext()
new_stdin = StringIO()
@ -60,7 +60,7 @@ class TestConnectionClass(unittest.TestCase):
self.assertRaises(AnsibleConnectionFailure, conn._connect)
@patch("ansible.plugins.connection.network_cli.terminal_loader")
@patch("ansible.plugins.connection.network_cli._Connection._connect")
@patch("ansible.plugins.connection.network_cli.ParamikoSshConnection._connect")
def test_network_cli__connect(self, mocked_super, mocked_terminal_loader):
pc = PlayContext()
new_stdin = StringIO()
@ -70,22 +70,21 @@ class TestConnectionClass(unittest.TestCase):
conn.ssh = MagicMock()
conn.receive = MagicMock()
mock_terminal = MagicMock()
conn._terminal = mock_terminal
conn._terminal = MagicMock()
conn._connect()
self.assertTrue(conn._terminal.on_open_shell.called)
self.assertFalse(conn._terminal.on_authorize.called)
conn._play_context.become = True
conn._play_context.become_method = 'enable'
conn._play_context.become_pass = 'password'
conn._connected = False
conn._connect()
conn._terminal.on_authorize.assert_called_with(passwd='password')
@patch("ansible.plugins.connection.network_cli._Connection.close")
@patch("ansible.plugins.connection.network_cli.ParamikoSshConnection.close")
def test_network_cli_close(self, mocked_super):
pc = PlayContext()
new_stdin = StringIO()
@ -93,20 +92,14 @@ class TestConnectionClass(unittest.TestCase):
terminal = MagicMock(supports_multiplexing=False)
conn._terminal = terminal
conn.close()
conn._shell = MagicMock()
conn._ssh_shell = MagicMock()
conn._connected = True
conn.close()
self.assertTrue(terminal.on_close_shell.called)
self.assertIsNone(conn._ssh_shell)
terminal.supports_multiplexing = True
conn.close()
self.assertIsNone(conn._shell)
@patch("ansible.plugins.connection.network_cli._Connection._connect")
@patch("ansible.plugins.connection.network_cli.ParamikoSshConnection._connect")
def test_network_cli_exec_command(self, mocked_super):
pc = PlayContext()
new_stdin = StringIO()
@ -114,23 +107,17 @@ class TestConnectionClass(unittest.TestCase):
mock_send = MagicMock(return_value=b'command response')
conn.send = mock_send
conn._ssh_shell = MagicMock()
# test sending a single command and converting to dict
rc, out, err = conn.exec_command('command')
out = conn.exec_command('command')
self.assertEqual(out, b'command response')
mock_send.assert_called_with(b'command', None, None, None)
mock_send.assert_called_with(command=b'command')
# test sending a json string
rc, out, err = conn.exec_command(json.dumps({'command': 'command'}))
self.assertEqual(out, b'command response')
mock_send.assert_called_with(b'command', None, None, None)
conn._shell = MagicMock()
# test _shell already open
rc, out, err = conn.exec_command('command')
out = conn.exec_command(json.dumps({'command': 'command'}))
self.assertEqual(out, b'command response')
mock_send.assert_called_with(b'command', None, None, None)
mock_send.assert_called_with(command=b'command')
def test_network_cli_send(self):
pc = PlayContext()
@ -142,7 +129,7 @@ class TestConnectionClass(unittest.TestCase):
conn._terminal = mock__terminal
mock__shell = MagicMock()
conn._shell = mock__shell
conn._ssh_shell = mock__shell
response = b"""device#command
command response
@ -155,7 +142,7 @@ class TestConnectionClass(unittest.TestCase):
output = conn.send(b'command', None, None, None)
mock__shell.sendall.assert_called_with(b'command\r')
self.assertEqual(output, b'command response')
self.assertEqual(output, 'command response')
mock__shell.reset_mock()
mock__shell.recv.return_value = b"ERROR: error message device#"

Loading…
Cancel
Save