#!/usr/bin/env python
# Copyright: (c) 2017, Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)

from __future__ import (absolute_import, division, print_function)
__metaclass__ = type


__requires__ = ['ansible']
try:
    import pkg_resources
except Exception:
    pass

import fcntl
import hashlib
import os
import signal
import socket
import sys
import traceback
import errno
import json

from contextlib import contextmanager

from ansible import constants as C
from ansible.module_utils._text import to_bytes, to_native, to_text
from ansible.module_utils.six import PY3
from ansible.module_utils.six.moves import cPickle, StringIO
from ansible.module_utils.connection import Connection, ConnectionError, send_data, recv_data
from ansible.module_utils.service import fork_process
from ansible.playbook.play_context import PlayContext
from ansible.plugins.loader import connection_loader
from ansible.utils.path import unfrackpath, makedirs_safe
from ansible.utils.display import Display
from ansible.utils.jsonrpc import JsonRpcServer


def read_stream(byte_stream):
    """
    Read one framed payload from ``byte_stream`` and return it as bytes.

    A frame has three parts:

    - a line holding the payload size in decimal;
    - exactly that many payload bytes;
    - a line holding the hex SHA1 digest of those payload bytes.

    :raises ValueError: if the size line is not an integer (e.g. at EOF).
    :raises Exception: if the stream ends before ``size`` bytes are read,
        or if the digest line does not match the payload.
    """
    size = int(byte_stream.readline().strip())

    data = byte_stream.read(size)
    if len(data) < size:
        raise Exception("EOF found before data was complete")

    data_hash = to_text(byte_stream.readline().strip())
    if data_hash != hashlib.sha1(data).hexdigest():
        raise Exception("Read {0} bytes, but data did not match checksum".format(size))

    # The digest is checked against the payload as received (still escaped).
    # Only after that are literal backslash-r sequences turned back into
    # carriage returns.
    data = data.replace(br'\r', b'\r')

    return data


@contextmanager
def file_lock(lock_path):
    """
    Hold an exclusive ``fcntl.lockf`` lock on ``lock_path`` for the duration
    of a ``with file_lock(...)`` block.

    The lock file is created with mode 0600 if it does not exist. The lock is
    released and the descriptor closed even when the body of the ``with``
    block raises, so a failure inside the block cannot leave the lock held.
    """
    lock_fd = os.open(lock_path, os.O_RDWR | os.O_CREAT, 0o600)
    try:
        fcntl.lockf(lock_fd, fcntl.LOCK_EX)
        try:
            yield
        finally:
            fcntl.lockf(lock_fd, fcntl.LOCK_UN)
    finally:
        os.close(lock_fd)


class ConnectionProcess(object):
    '''
    The connection process wraps around a Connection object that manages
    the connection to a remote device that persists over the playbook
    '''

    def __init__(self, fd, play_context, socket_path, original_path, ansible_playbook_pid=None):
        """
        :arg fd: writable file object. ``start()`` writes a JSON status
            document to it and then closes it.
        :arg play_context: the PlayContext describing the connection.
        :arg socket_path: filesystem path of the local domain socket to serve.
        :arg original_path: directory used to resolve a relative private key
            file path.
        :arg ansible_playbook_pid: passed through to the connection plugin.
        """
        self.play_context = play_context
        self.socket_path = socket_path
        self.original_path = original_path

        self.fd = fd
        # traceback text of the last unexpected error seen by run(), or None
        self.exception = None

        self.srv = JsonRpcServer()
        self.sock = None

        self.connection = None
        self._ansible_playbook_pid = ansible_playbook_pid

    def start(self, variables):
        """
        Open the connection to the remote device and bind the control socket.

        The outcome is reported by writing a JSON document to ``self.fd``,
        which is always closed afterwards. The document always has a
        ``messages`` list; on failure it also has ``error`` and ``exception``.
        Exceptions are recorded in that document, not re-raised.

        :arg variables: task variables passed to the connection plugin's
            ``set_options``.
        """
        messages = []
        result = {}
        try:
            messages.append('control socket path is %s' % self.socket_path)

            # If this is a relative path (~ gets expanded later) then plug the
            # key's path on to the directory we originally came from, so we can
            # find it now that our cwd is /
            if self.play_context.private_key_file and self.play_context.private_key_file[0] not in '~/':
                self.play_context.private_key_file = os.path.join(self.original_path, self.play_context.private_key_file)
            self.connection = connection_loader.get(self.play_context.connection, self.play_context, '/dev/null',
                                                    ansible_playbook_pid=self._ansible_playbook_pid)
            self.connection.set_options(var_options=variables)
            self.connection._connect()
            self.connection._socket_path = self.socket_path
            self.srv.register(self.connection)
            # main() replaces sys.stdout with a StringIO; forward what the
            # plugin printed while connecting.
            messages.extend(sys.stdout.getvalue().splitlines())
            messages.append('connection to remote device started successfully')

            self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            self.sock.bind(self.socket_path)
            self.sock.listen(1)
            messages.append('local domain socket listeners started successfully')
        except Exception as exc:
            result['error'] = to_text(exc)
            result['exception'] = traceback.format_exc()
        finally:
            result['messages'] = messages
            self.fd.write(json.dumps(result))
            self.fd.close()

    def run(self):
        """
        Serve JSON-RPC requests on the control socket while the connection
        reports itself as connected.

        Each accepted client is served until it sends no more data. Two
        SIGALRM alarms guard the loop, and each one's handler calls
        ``shutdown()``:

        - one bounded by ``persistent_connect_timeout`` covers waiting for a
          client;
        - one bounded by the play context ``timeout`` covers each request.

        SIGTERM also calls ``shutdown()``. ``shutdown()`` always runs when
        this method exits, and unexpected errors are kept in
        ``self.exception``.
        """
        try:
            while self.connection.connected:
                signal.signal(signal.SIGALRM, self.connect_timeout)
                signal.signal(signal.SIGTERM, self.handler)
                signal.alarm(self.connection.get_option('persistent_connect_timeout'))

                self.exception = None
                (s, addr) = self.sock.accept()
                signal.alarm(0)

                signal.signal(signal.SIGALRM, self.command_timeout)
                try:
                    while True:
                        data = recv_data(s)
                        if not data:
                            break

                        signal.alarm(self.connection._play_context.timeout)
                        resp = self.srv.handle_request(data)
                        signal.alarm(0)

                        send_data(s, to_bytes(resp))
                finally:
                    # Release the client socket even if a request fails midway.
                    s.close()

        except Exception as e:
            # socket.accept() will raise EINTR if the socket.close() is called,
            # so EINTR is not recorded as a failure. Exceptions without an
            # errno attribute are always recorded.
            if getattr(e, 'errno', None) != errno.EINTR:
                self.exception = traceback.format_exc()

        finally:
            # when done, close the connection properly and cleanup
            # the socket file so it can be recreated
            self.shutdown()

    def connect_timeout(self, signum, frame):
        """SIGALRM handler used while waiting for a client to connect."""
        display.display('persistent connection idle timeout triggered, timeout value is %s secs'
                        % self.connection.get_option('persistent_connect_timeout'), log_only=True)
        self.shutdown()

    def command_timeout(self, signum, frame):
        """SIGALRM handler used while a request is being handled."""
        display.display('command timeout triggered, timeout value is %s secs' % self.play_context.timeout,
                        log_only=True)
        self.shutdown()

    def handler(self, signum, frame):
        """SIGTERM handler: log the signal and shut down."""
        display.display('signal handler called with signal %s' % signum, log_only=True)
        self.shutdown()

    def shutdown(self):
        """ Shuts down the local domain socket

        Closes the listening socket and the device connection, removes the
        socket file, and marks the connection as disconnected. All of this is
        skipped when the socket file does not exist. As a result, a second
        call after a completed shutdown only logs 'shutdown complete'.
        """
        if os.path.exists(self.socket_path):
            try:
                if self.sock:
                    self.sock.close()
                if self.connection:
                    self.connection.close()
            except Exception as exc:
                # Best-effort cleanup: log and fall through so the socket file
                # is still removed below.
                display.display('error during shutdown: %s' % to_text(exc), log_only=True)
            finally:
                if os.path.exists(self.socket_path):
                    os.remove(self.socket_path)
                    setattr(self.connection, '_socket_path', None)
                    setattr(self.connection, '_connected', False)
        display.display('shutdown complete', log_only=True)


def main():
    """ Called to initiate the connect to the remote device

    Reads two framed, pickled payloads from stdin: the task variables, then
    the serialized play context.

    - If the control socket does not exist yet, it forks a child that runs a
      ConnectionProcess serving that socket.
    - Otherwise it sends the play context to the existing socket.

    The result is written as JSON to stdout on success or to stderr on
    failure, and the process exits with status 0 or 1 accordingly.
    The first command-line argument is used as the ansible-playbook pid.
    """
    rc = 0
    result = {}
    messages = []
    socket_path = None

    # Need stdin as a byte stream
    if PY3:
        stdin = sys.stdin.buffer
    else:
        stdin = sys.stdin

    # Note: update the below log capture code after Display.display() is refactored.
    saved_stdout = sys.stdout
    sys.stdout = StringIO()

    try:
        # read the play context data via stdin, which means depickling it
        # NOTE(review): cPickle.loads can execute arbitrary code from its
        # input. Presumably stdin is written by the invoking ansible process;
        # confirm that, and never feed this script untrusted data.
        vars_data = read_stream(stdin)
        init_data = read_stream(stdin)

        if PY3:
            pc_data = cPickle.loads(init_data, encoding='bytes')
            variables = cPickle.loads(vars_data, encoding='bytes')
        else:
            pc_data = cPickle.loads(init_data)
            variables = cPickle.loads(vars_data)

        play_context = PlayContext()
        play_context.deserialize(pc_data)
        display.verbosity = play_context.verbosity

    except Exception as e:
        rc = 1
        result.update({
            'error': to_text(e),
            'exception': traceback.format_exc()
        })

    if rc == 0:
        ssh = connection_loader.get('ssh', class_only=True)
        ansible_playbook_pid = sys.argv[1]
        cp = ssh._create_control_path(play_context.remote_addr, play_context.port, play_context.remote_user,
                                      play_context.connection, ansible_playbook_pid)

        # create the persistent connection dir if need be and create the paths
        # which we will be using later
        tmp_path = unfrackpath(C.PERSISTENT_CONTROL_PATH_DIR)
        makedirs_safe(tmp_path)

        lock_path = unfrackpath("%s/.ansible_pc_lock_%s" % (tmp_path, play_context.remote_addr))
        socket_path = unfrackpath(cp % dict(directory=tmp_path))

        with file_lock(lock_path):
            if not os.path.exists(socket_path):
                messages.append('local domain socket does not exist, starting it')
                original_path = os.getcwd()
                r, w = os.pipe()
                pid = fork_process()

                if pid == 0:
                    # Child: start the connection, report its status over the
                    # pipe, then serve requests until shutdown.
                    process = None
                    try:
                        os.close(r)
                        wfd = os.fdopen(w, 'w')
                        process = ConnectionProcess(wfd, play_context, socket_path, original_path, ansible_playbook_pid)
                        process.start(variables)
                    except Exception:
                        messages.append(traceback.format_exc())
                        rc = 1

                    if rc == 0:
                        process.run()
                    elif process is not None:
                        # process stays None if setup failed before it was built
                        process.shutdown()

                    sys.exit(rc)

                else:
                    # Parent: read() returns once the child closes its end of
                    # the pipe, which start() does after writing its status.
                    os.close(w)
                    with os.fdopen(r, 'r') as rfd:
                        data = json.loads(rfd.read())
                    messages.extend(data.pop('messages'))
                    result.update(data)

            else:
                messages.append('found existing local domain socket, using it!')
                conn = Connection(socket_path)
                pc_data = to_text(init_data)
                try:
                    messages.extend(conn.update_play_context(pc_data))
                except Exception as exc:
                    # Only network_cli has update_play context, so missing this is
                    # not fatal e.g. netconf. -32601 is the JSON-RPC 2.0
                    # "Method not found" error code.
                    if not (isinstance(exc, ConnectionError) and getattr(exc, 'code', None) == -32601):
                        result.update({
                            'error': to_text(exc),
                            'exception': traceback.format_exc()
                        })

    messages.append(sys.stdout.getvalue())
    result.update({
        'messages': messages,
        'socket_path': socket_path
    })

    sys.stdout = saved_stdout
    if 'exception' in result:
        rc = 1
        sys.stderr.write(json.dumps(result))
    else:
        rc = 0
        sys.stdout.write(json.dumps(result))

    sys.exit(rc)


if __name__ == '__main__':
    display = Display()
    main()