diff --git a/changelogs/fragments/persistent_loading_2.yml b/changelogs/fragments/persistent_loading_2.yml new file mode 100644 index 00000000000..794a0d87097 --- /dev/null +++ b/changelogs/fragments/persistent_loading_2.yml @@ -0,0 +1,2 @@ +bugfixes: + - pass correct loading context to persistent connections other than local diff --git a/lib/ansible/executor/task_executor.py b/lib/ansible/executor/task_executor.py index 2eb16e76491..dcdb3114dae 100644 --- a/lib/ansible/executor/task_executor.py +++ b/lib/ansible/executor/task_executor.py @@ -24,7 +24,7 @@ from ansible.module_utils._text import to_text, to_native from ansible.module_utils.connection import write_to_file_descriptor from ansible.playbook.conditional import Conditional from ansible.playbook.task import Task -from ansible.plugins.loader import become_loader +from ansible.plugins.loader import become_loader, cliconf_loader, connection_loader, httpapi_loader, netconf_loader, terminal_loader from ansible.template import Templar from ansible.utils.listify import listify_lookup_plugin_terms from ansible.utils.unsafe_proxy import UnsafeProxy, wrap_var @@ -915,7 +915,7 @@ class TaskExecutor: display.vvvv('using connection plugin %s' % connection.transport, host=self._play_context.remote_addr) options = self._get_persistent_connection_options(connection, variables, templar) - socket_path = self._start_connection(options) + socket_path = start_connection(self._play_context, options) display.vvvv('local domain socket path is %s' % socket_path, host=self._play_context.remote_addr) setattr(connection, '_socket_path', socket_path) @@ -1034,71 +1034,81 @@ class TaskExecutor: return handler - def _start_connection(self, variables): - ''' - Starts the persistent connection - ''' - candidate_paths = [C.ANSIBLE_CONNECTION_PATH or os.path.dirname(sys.argv[0])] - candidate_paths.extend(os.environ['PATH'].split(os.pathsep)) - for dirname in candidate_paths: - ansible_connection = os.path.join(dirname, 'ansible-connection') - if os.path.isfile(ansible_connection): - break - else: - raise AnsibleError("Unable to find location of 'ansible-connection'. " - "Please set or check the value of ANSIBLE_CONNECTION_PATH") - - python = sys.executable - master, slave = pty.openpty() - p = subprocess.Popen( - [python, ansible_connection, to_text(os.getppid())], - stdin=slave, stdout=subprocess.PIPE, stderr=subprocess.PIPE - ) - os.close(slave) - - # We need to set the pty into noncanonical mode. This ensures that we - # can receive lines longer than 4095 characters (plus newline) without - # truncating. - old = termios.tcgetattr(master) - new = termios.tcgetattr(master) - new[3] = new[3] & ~termios.ICANON +def start_connection(play_context, variables): + ''' + Starts the persistent connection + ''' + candidate_paths = [C.ANSIBLE_CONNECTION_PATH or os.path.dirname(sys.argv[0])] + candidate_paths.extend(os.environ['PATH'].split(os.pathsep)) + for dirname in candidate_paths: + ansible_connection = os.path.join(dirname, 'ansible-connection') + if os.path.isfile(ansible_connection): + break + else: + raise AnsibleError("Unable to find location of 'ansible-connection'. " + "Please set or check the value of ANSIBLE_CONNECTION_PATH") + + env = os.environ.copy() + env.update({ + 'ANSIBLE_BECOME_PLUGINS': become_loader.print_paths(), + 'ANSIBLE_CLICONF_PLUGINS': cliconf_loader.print_paths(), + 'ANSIBLE_CONNECTION_PLUGINS': connection_loader.print_paths(), + 'ANSIBLE_HTTPAPI_PLUGINS': httpapi_loader.print_paths(), + 'ANSIBLE_NETCONF_PLUGINS': netconf_loader.print_paths(), + 'ANSIBLE_TERMINAL_PLUGINS': terminal_loader.print_paths(), + }) + python = sys.executable + master, slave = pty.openpty() + p = subprocess.Popen( + [python, ansible_connection, to_text(os.getppid())], + stdin=slave, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env + ) + os.close(slave) + + # We need to set the pty into noncanonical mode. This ensures that we + # can receive lines longer than 4095 characters (plus newline) without + # truncating. + old = termios.tcgetattr(master) + new = termios.tcgetattr(master) + new[3] = new[3] & ~termios.ICANON + + try: + termios.tcsetattr(master, termios.TCSANOW, new) + write_to_file_descriptor(master, variables) + write_to_file_descriptor(master, play_context.serialize()) + + (stdout, stderr) = p.communicate() + finally: + termios.tcsetattr(master, termios.TCSANOW, old) + os.close(master) + + if p.returncode == 0: + result = json.loads(to_text(stdout, errors='surrogate_then_replace')) + else: try: - termios.tcsetattr(master, termios.TCSANOW, new) - write_to_file_descriptor(master, variables) - write_to_file_descriptor(master, self._play_context.serialize()) + result = json.loads(to_text(stderr, errors='surrogate_then_replace')) + except getattr(json.decoder, 'JSONDecodeError', ValueError): + # JSONDecodeError only available on Python 3.5+ + result = {'error': to_text(stderr, errors='surrogate_then_replace')} + + if 'messages' in result: + for level, message in result['messages']: + if level == 'log': + display.display(message, log_only=True) + elif level in ('debug', 'v', 'vv', 'vvv', 'vvvv', 'vvvvv', 'vvvvvv'): + getattr(display, level)(message, host=play_context.remote_addr) + else: + if hasattr(display, level): + getattr(display, level)(message) + else: + display.vvvv(message, host=play_context.remote_addr) - (stdout, stderr) = p.communicate() - finally: - termios.tcsetattr(master, termios.TCSANOW, old) - os.close(master) + if 'error' in result: + if play_context.verbosity > 2: + if result.get('exception'): + msg = "The full traceback is:\n" + result['exception'] + display.display(msg, color=C.COLOR_ERROR) + raise AnsibleError(result['error']) - if p.returncode == 0: - result = json.loads(to_text(stdout, errors='surrogate_then_replace')) - else: - try: - result = json.loads(to_text(stderr, errors='surrogate_then_replace')) - except getattr(json.decoder, 'JSONDecodeError', ValueError): - # JSONDecodeError only available on Python 3.5+ - result = {'error': to_text(stderr, errors='surrogate_then_replace')} - - if 'messages' in result: - for level, message in result['messages']: - if level == 'log': - display.display(message, log_only=True) - elif level in ('debug', 'v', 'vv', 'vvv', 'vvvv', 'vvvvv', 'vvvvvv'): - getattr(display, level)(message, host=self._play_context.remote_addr) - else: - if hasattr(display, level): - getattr(display, level)(message) - else: - display.vvvv(message, host=self._play_context.remote_addr) - - if 'error' in result: - if self._play_context.verbosity > 2: - if result.get('exception'): - msg = "The full traceback is:\n" + result['exception'] - display.display(msg, color=C.COLOR_ERROR) - raise AnsibleError(result['error']) - - return result['socket_path'] + return result['socket_path'] diff --git a/lib/ansible/plugins/connection/persistent.py b/lib/ansible/plugins/connection/persistent.py index 28eee8b0a81..fc4f5a44d06 100644 --- a/lib/ansible/plugins/connection/persistent.py +++ b/lib/ansible/plugins/connection/persistent.py @@ -29,19 +29,9 @@ options: vars: - name: ansible_command_timeout """ -import os -import pty -import json -import subprocess -import sys -import termios - -from ansible import constants as C -from ansible.plugins.loader import become_loader, cliconf_loader, connection_loader, httpapi_loader, netconf_loader, terminal_loader +from ansible.executor.task_executor import start_connection from ansible.plugins.connection import ConnectionBase -from ansible.module_utils._text import to_text -from ansible.module_utils.connection import Connection as SocketConnection, write_to_file_descriptor -from ansible.errors import AnsibleError +from ansible.module_utils.connection import Connection as SocketConnection from ansible.utils.display import Display display = Display() @@ -80,85 +70,8 @@ class Connection(ConnectionBase): returns the socket path. """ display.vvvv('starting connection from persistent connection plugin', host=self._play_context.remote_addr) - socket_path = self._start_connection() + variables = {'ansible_command_timeout': self.get_option('persistent_command_timeout')} + socket_path = start_connection(self._play_context, variables) display.vvvv('local domain socket path is %s' % socket_path, host=self._play_context.remote_addr) setattr(self, '_socket_path', socket_path) return socket_path - - def _start_connection(self): - ''' - Starts the persistent connection - ''' - candidate_paths = [C.ANSIBLE_CONNECTION_PATH or os.path.dirname(sys.argv[0])] - candidate_paths.extend(os.environ['PATH'].split(os.pathsep)) - for dirname in candidate_paths: - ansible_connection = os.path.join(dirname, 'ansible-connection') - if os.path.isfile(ansible_connection): - break - else: - raise AnsibleError("Unable to find location of 'ansible-connection'. " - "Please set or check the value of ANSIBLE_CONNECTION_PATH") - - env = os.environ.copy() - env.update({ - 'ANSIBLE_BECOME_PLUGINS': become_loader.print_paths(), - 'ANSIBLE_CLICONF_PLUGINS': cliconf_loader.print_paths(), - 'ANSIBLE_CONNECTION_PLUGINS': connection_loader.print_paths(), - 'ANSIBLE_HTTPAPI_PLUGINS': httpapi_loader.print_paths(), - 'ANSIBLE_NETCONF_PLUGINS': netconf_loader.print_paths(), - 'ANSIBLE_TERMINAL_PLUGINS': terminal_loader.print_paths(), - }) - python = sys.executable - master, slave = pty.openpty() - p = subprocess.Popen( - [python, ansible_connection, to_text(os.getppid())], - stdin=slave, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env - ) - os.close(slave) - - # We need to set the pty into noncanonical mode. This ensures that we - # can receive lines longer than 4095 characters (plus newline) without - # truncating. - old = termios.tcgetattr(master) - new = termios.tcgetattr(master) - new[3] = new[3] & ~termios.ICANON - - try: - termios.tcsetattr(master, termios.TCSANOW, new) - write_to_file_descriptor(master, {'ansible_command_timeout': self.get_option('persistent_command_timeout')}) - write_to_file_descriptor(master, self._play_context.serialize()) - - (stdout, stderr) = p.communicate() - finally: - termios.tcsetattr(master, termios.TCSANOW, old) - os.close(master) - - if p.returncode == 0: - result = json.loads(to_text(stdout, errors='surrogate_then_replace')) - else: - try: - result = json.loads(to_text(stderr, errors='surrogate_then_replace')) - except getattr(json.decoder, 'JSONDecodeError', ValueError): - # JSONDecodeError only available on Python 3.5+ - result = {'error': to_text(stderr, errors='surrogate_then_replace')} - - if 'messages' in result: - for level, message in result['messages']: - if level == 'log': - display.display(message, log_only=True) - elif level in ('debug', 'v', 'vv', 'vvv', 'vvvv', 'vvvvv', 'vvvvvv'): - getattr(display, level)(message, host=self._play_context.remote_addr) - else: - if hasattr(display, level): - getattr(display, level)(message) - else: - display.vvvv(message, host=self._play_context.remote_addr) - - if 'error' in result: - if self._play_context.verbosity > 2: - if result.get('exception'): - msg = "The full traceback is:\n" + result['exception'] - display.display(msg, color=C.COLOR_ERROR) - raise AnsibleError(result['error']) - - return result['socket_path']