diff --git a/changelogs/fragments/76737-paramiko-rsa-sha2.yml b/changelogs/fragments/76737-paramiko-rsa-sha2.yml new file mode 100644 index 00000000000..696576f8c35 --- /dev/null +++ b/changelogs/fragments/76737-paramiko-rsa-sha2.yml @@ -0,0 +1,5 @@ +bugfixes: +- paramiko - Add a new option to allow paramiko >= 2.9 to easily work with + all devices now that rsa-sha2 support was added to paramiko, which + prevented communication with numerous platforms. + (https://github.com/ansible/ansible/issues/76737) diff --git a/lib/ansible/module_utils/compat/paramiko.py b/lib/ansible/module_utils/compat/paramiko.py index 3a508cae757..85478eae2fc 100644 --- a/lib/ansible/module_utils/compat/paramiko.py +++ b/lib/ansible/module_utils/compat/paramiko.py @@ -6,11 +6,14 @@ from __future__ import absolute_import, division, print_function __metaclass__ = type import types +import warnings PARAMIKO_IMPORT_ERR = None try: - import paramiko + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', message='Blowfish has been deprecated', category=UserWarning) + import paramiko # paramiko and gssapi are incompatible and raise AttributeError not ImportError # When running in FIPS mode, cryptography raises InternalError # https://bugzilla.redhat.com/show_bug.cgi?id=1778939 diff --git a/lib/ansible/plugins/connection/paramiko_ssh.py b/lib/ansible/plugins/connection/paramiko_ssh.py index bcac61ffc5a..1b31d42b3c7 100644 --- a/lib/ansible/plugins/connection/paramiko_ssh.py +++ b/lib/ansible/plugins/connection/paramiko_ssh.py @@ -58,6 +58,20 @@ DOCUMENTATION = """ - name: ansible_paramiko_pass - name: ansible_paramiko_password version_added: '2.5' + use_rsa_sha2_algorithms: + description: + - Whether or not to enable RSA SHA2 algorithms for pubkeys and hostkeys + - On paramiko versions older than 2.9, this only affects hostkeys + - For behavior matching paramiko<2.9 set this to C(False) + vars: + - name: ansible_paramiko_use_rsa_sha2_algorithms + ini: + - {key: use_rsa_sha2_algorithms, section: paramiko_connection} + env: + - {name: ANSIBLE_PARAMIKO_USE_RSA_SHA2_ALGORITHMS} + default: True + type: boolean + version_added: '2.14' host_key_auto_add: description: 'Automatically add host keys' env: [{name: ANSIBLE_PARAMIKO_HOST_KEY_AUTO_ADD}] @@ -374,6 +388,18 @@ class Connection(ConnectionBase): ssh = paramiko.SSHClient() + # Set pubkey and hostkey algorithms to disable, the only manipulation allowed currently + # is keeping or omitting rsa-sha2 algorithms + paramiko_preferred_pubkeys = getattr(paramiko.Transport, '_preferred_pubkeys', ()) + paramiko_preferred_hostkeys = getattr(paramiko.Transport, '_preferred_keys', ()) + use_rsa_sha2_algorithms = self.get_option('use_rsa_sha2_algorithms') + disabled_algorithms = {} + if not use_rsa_sha2_algorithms: + if paramiko_preferred_pubkeys: + disabled_algorithms['pubkeys'] = tuple(a for a in paramiko_preferred_pubkeys if 'rsa-sha2' in a) + if paramiko_preferred_hostkeys: + disabled_algorithms['keys'] = tuple(a for a in paramiko_preferred_hostkeys if 'rsa-sha2' in a) + # override paramiko's default logger name if self._log_channel is not None: ssh.set_log_channel(self._log_channel) @@ -423,7 +449,8 @@ class Connection(ConnectionBase): password=conn_password, timeout=self._play_context.timeout, port=port, - **ssh_connect_kwargs + disabled_algorithms=disabled_algorithms, + **ssh_connect_kwargs, ) except paramiko.ssh_exception.BadHostKeyException as e: raise AnsibleConnectionFailure('host key mismatch for %s' % e.hostname) diff --git a/test/lib/ansible_test/_data/requirements/constraints.txt b/test/lib/ansible_test/_data/requirements/constraints.txt index edac1b9395e..a65c7b7ab21 100644 --- a/test/lib/ansible_test/_data/requirements/constraints.txt +++ b/test/lib/ansible_test/_data/requirements/constraints.txt @@ -1,7 +1,6 @@ # do not add a cryptography or pyopenssl constraint to this file, they require special handling, see get_cryptography_requirements in python_requirements.py # do not add a coverage constraint to this file, it is handled internally by ansible-test packaging < 21.0 ; python_version < '3.6' # packaging 21.0 requires Python 3.6 or newer -paramiko < 2.9.0 # paramiko 2.9.0+ requires changes to the paramiko_ssh connection plugin to work with older systems pywinrm >= 0.3.0 ; python_version < '3.11' # message encryption support pywinrm >= 0.4.3 ; python_version >= '3.11' # support for Python 3.11 pytest < 5.0.0, >= 4.5.0 ; python_version == '2.7' # pytest 5.0.0 and later will no longer support python 2.7 diff --git a/test/lib/ansible_test/_internal/host_profiles.py b/test/lib/ansible_test/_internal/host_profiles.py index 50b80193397..b97152e2431 100644 --- a/test/lib/ansible_test/_internal/host_profiles.py +++ b/test/lib/ansible_test/_internal/host_profiles.py @@ -478,6 +478,7 @@ class NetworkRemoteProfile(RemoteProfile[NetworkRemoteConfig]): ansible_port=connection.port, ansible_user=connection.username, ansible_ssh_private_key_file=core_ci.ssh_key.key, + ansible_paramiko_use_rsa_sha2_algorithms='no', ansible_network_os=f'{self.config.collection}.{self.config.platform}' if self.config.collection else self.config.platform, ) diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/network_cli.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/network_cli.py index 8abcf8e898a..fef40810498 100644 --- a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/network_cli.py +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/network_cli.py @@ -6,13 +6,20 @@ from __future__ import absolute_import, division, print_function __metaclass__ = type -DOCUMENTATION = """author: Ansible Networking Team -connection: network_cli +DOCUMENTATION = """ +author: + - Ansible Networking Team (@ansible-network) +name: network_cli short_description: Use network_cli to run command on network appliances description: - This connection plugin provides a connection to remote devices over the SSH and implements a CLI shell. This connection plugin is typically used by network devices for sending and receiving CLi commands to network devices. +version_added: 1.0.0 +requirements: +- ansible-pylibssh if using I(ssh_type=libssh) +extends_documentation_fragment: +- ansible.netcommon.connection_persistent options: host: description: @@ -20,6 +27,7 @@ options: to. default: inventory_hostname vars: + - name: inventory_hostname - name: ansible_host port: type: int @@ -90,6 +98,31 @@ options: - name: ANSIBLE_BECOME vars: - name: ansible_become + become_errors: + type: str + description: + - This option determines how privilege escalation failures are handled when + I(become) is enabled. + - When set to C(ignore), the errors are silently ignored. + When set to C(warn), a warning message is displayed. + The default option C(fail), triggers a failure and halts execution. + vars: + - name: ansible_network_become_errors + default: fail + choices: ["ignore", "warn", "fail"] + terminal_errors: + type: str + description: + - This option determines how failures while setting terminal parameters + are handled. + - When set to C(ignore), the errors are silently ignored. + When set to C(warn), a warning message is displayed. + The default option C(fail), triggers a failure and halts execution. + vars: + - name: ansible_network_terminal_errors + default: fail + choices: ["ignore", "warn", "fail"] + version_added: 3.1.0 become_method: description: - This option allows the become method to be specified in for handling privilege @@ -118,34 +151,6 @@ options: key: host_key_auto_add env: - name: ANSIBLE_HOST_KEY_AUTO_ADD - persistent_connect_timeout: - type: int - description: - - Configures, in seconds, the amount of time to wait when trying to initially - establish a persistent connection. If this value expires before the connection - to the remote device is completed, the connection will fail. - default: 30 - ini: - - section: persistent_connection - key: connect_timeout - env: - - name: ANSIBLE_PERSISTENT_CONNECT_TIMEOUT - vars: - - name: ansible_connect_timeout - persistent_command_timeout: - type: int - description: - - Configures, in seconds, the amount of time to wait for a command to return from - the remote device. If this timer is exceeded before the command returns, the - connection plugin will raise an exception and close. - default: 30 - ini: - - section: persistent_connection - key: command_timeout - env: - - name: ANSIBLE_PERSISTENT_COMMAND_TIMEOUT - vars: - - name: ansible_command_timeout persistent_buffer_read_timeout: type: float description: @@ -161,23 +166,6 @@ options: - name: ANSIBLE_PERSISTENT_BUFFER_READ_TIMEOUT vars: - name: ansible_buffer_read_timeout - persistent_log_messages: - type: boolean - description: - - This flag will enable logging the command executed and response received from - target device in the ansible log file. For this option to work 'log_path' ansible - configuration option is required to be set to a file path with write access. - - Be sure to fully understand the security implications of enabling this option - as it could create a security vulnerability by logging sensitive information - in log file. - default: false - ini: - - section: persistent_connection - key: log_messages - env: - - name: ANSIBLE_PERSISTENT_LOG_MESSAGES - vars: - - name: ansible_persistent_log_messages terminal_stdout_re: type: list elements: dict @@ -204,6 +192,7 @@ options: - name: ansible_terminal_stderr_re terminal_initial_prompt: type: list + elements: string description: - A single regex pattern or a sequence of patterns to evaluate the expected prompt at the time of initial login to the remote host. @@ -211,6 +200,7 @@ options: - name: ansible_terminal_initial_prompt terminal_initial_answer: type: list + elements: string description: - The answer to reply with if the C(terminal_initial_prompt) is matched. The value can be a single answer or a list of answers for multiple terminal_initial_prompt. @@ -255,35 +245,90 @@ options: key: network_cli_retries vars: - name: ansible_network_cli_retries + ssh_type: + description: + - The python package that will be used by the C(network_cli) connection plugin to create a SSH connection to remote host. + - I(libssh) will use the ansible-pylibssh package, which needs to be installed in order to work. + - I(paramiko) will instead use the paramiko package to manage the SSH connection. + - I(auto) will use ansible-pylibssh if that package is installed, otherwise will fallback to paramiko. + default: auto + choices: ["libssh", "paramiko", "auto"] + env: + - name: ANSIBLE_NETWORK_CLI_SSH_TYPE + ini: + - section: persistent_connection + key: ssh_type + vars: + - name: ansible_network_cli_ssh_type + host_key_checking: + description: 'Set this to "False" if you want to avoid host key checking by the underlying tools Ansible uses to connect to the host' + type: boolean + default: True + env: + - name: ANSIBLE_HOST_KEY_CHECKING + - name: ANSIBLE_SSH_HOST_KEY_CHECKING + ini: + - section: defaults + key: host_key_checking + - section: persistent_connection + key: host_key_checking + vars: + - name: ansible_host_key_checking + - name: ansible_ssh_host_key_checking + single_user_mode: + type: boolean + default: false + version_added: 2.0.0 + description: + - This option enables caching of data fetched from the target for re-use. + The cache is invalidated when the target device enters configuration mode. + - Applicable only for platforms where this has been implemented. + env: + - name: ANSIBLE_NETWORK_SINGLE_USER_MODE + vars: + - name: ansible_network_single_user_mode """ -from functools import wraps import getpass import json import logging -import re import os +import re import signal import socket import time import traceback +from functools import wraps from io import BytesIO -from ansible.errors import AnsibleConnectionFailure +from ansible.errors import AnsibleConnectionFailure, AnsibleError +from ansible.module_utils._text import to_bytes, to_text +from ansible.module_utils.basic import missing_required_lib from ansible.module_utils.six import PY3 from ansible.module_utils.six.moves import cPickle -from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import ( - to_list, -) -from ansible.module_utils._text import to_bytes, to_text from ansible.playbook.play_context import PlayContext -from ansible.plugins.connection import NetworkConnectionBase from ansible.plugins.loader import ( + cache_loader, cliconf_loader, - terminal_loader, connection_loader, + terminal_loader, +) +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import ( + to_list, +) +from ansible_collections.ansible.netcommon.plugins.plugin_utils.connection_base import ( + NetworkConnectionBase, ) +try: + from scp import SCPClient + + HAS_SCP = True +except ImportError: + HAS_SCP = False + +HAS_PYLIBSSH = False + def ensure_connect(func): @wraps(func) @@ -301,7 +346,7 @@ class AnsibleCmdRespRecv(Exception): class Connection(NetworkConnectionBase): - """ CLI (shell) SSH connections on Paramiko """ + """CLI (shell) SSH connections on Paramiko""" transport = "ansible.netcommon.network_cli" has_pipelining = True @@ -319,17 +364,19 @@ class Connection(NetworkConnectionBase): self._history = list() self._command_response = None self._last_recv_window = None + self._cache = None self._terminal = None self.cliconf = None - self._paramiko_conn = None # Managing prompt context self._check_prompt = False + self._task_uuid = to_text(kwargs.get("task_uuid", "")) + self._ssh_type_conn = None + self._ssh_type = None - if self._play_context.verbosity > 3: - logging.getLogger("paramiko").setLevel(logging.DEBUG) + self._single_user_mode = False if self._network_os: self._terminal = terminal_loader.get(self._network_os, self) @@ -368,24 +415,65 @@ class Connection(NetworkConnectionBase): self.queue_message("log", "network_os is set to %s" % self._network_os) @property - def paramiko_conn(self): - if self._paramiko_conn is None: - self._paramiko_conn = connection_loader.get( - "paramiko", self._play_context, "/dev/null" + def ssh_type(self): + if self._ssh_type is None: + self._ssh_type = self.get_option("ssh_type") + self.queue_message( + "vvvv", "ssh type is set to %s" % self._ssh_type ) - self._paramiko_conn.set_options( - direct={ - "look_for_keys": not bool( - self._play_context.password - and not self._play_context.private_key_file + # Support autodetection of supported library + if self._ssh_type == "auto": + self.queue_message("vvvv", "autodetecting ssh_type") + if HAS_PYLIBSSH: + self._ssh_type = "libssh" + else: + self.queue_message( + "warning", + "ansible-pylibssh not installed, falling back to paramiko", ) - } + self._ssh_type = "paramiko" + self.queue_message( + "vvvv", "ssh type is now set to %s" % self._ssh_type + ) + + if self._ssh_type not in ["paramiko", "libssh"]: + raise AnsibleConnectionFailure( + "Invalid value '%s' set for ssh_type option." + " Expected value is either 'libssh' or 'paramiko'" + % self._ssh_type + ) + + return self._ssh_type + + @property + def ssh_type_conn(self): + if self._ssh_type_conn is None: + if self.ssh_type == "libssh": + connection_plugin = "ansible.netcommon.libssh" + elif self.ssh_type == "paramiko": + # NOTE: This MUST be paramiko or things will break + connection_plugin = "paramiko" + else: + raise AnsibleConnectionFailure( + "Invalid value '%s' set for ssh_type option." + " Expected value is either 'libssh' or 'paramiko'" + % self._ssh_type + ) + + self._ssh_type_conn = connection_loader.get( + connection_plugin, self._play_context, "/dev/null" ) - return self._paramiko_conn + + return self._ssh_type_conn + + # To maintain backward compatibility + @property + def paramiko_conn(self): + return self.ssh_type_conn def _get_log_channel(self): name = "p=%s u=%s | " % (os.getpid(), getpass.getuser()) - name += "paramiko [%s]" % self._play_context.remote_addr + name += "%s [%s]" % (self.ssh_type, self._play_context.remote_addr) return name @ensure_connect @@ -427,6 +515,44 @@ class Connection(NetworkConnectionBase): else: return super(Connection, self).exec_command(cmd, in_data, sudoable) + def get_options(self, hostvars=None): + options = super(Connection, self).get_options(hostvars=hostvars) + options.update(self.ssh_type_conn.get_options(hostvars=hostvars)) + return options + + def set_options(self, task_keys=None, var_options=None, direct=None): + super(Connection, self).set_options( + task_keys=task_keys, var_options=var_options, direct=direct + ) + self.ssh_type_conn.set_options( + task_keys=task_keys, var_options=var_options, direct=direct + ) + # Retain old look_for_keys behaviour, but only if not set + if not any( + [ + task_keys and ("look_for_keys" in task_keys), + var_options and ("look_for_keys" in var_options), + direct and ("look_for_keys" in direct), + ] + ): + look_for_keys = not bool( + self.get_option("password") + and not self.get_option("private_key_file") + ) + if not look_for_keys: + # This actually can't be overridden yet without changes in ansible-core + # TODO: Uncomment when appropriate + # self.queue_message( + # "warning", + # "Option look_for_keys has been implicitly set to {0} because " + # "it was not set explicitly. This is retained to maintain " + # "backwards compatibility with the old behavior. This behavior " + # "will be removed in some release after 2024-01-01".format( + # look_for_keys + # ), + # ) + self.ssh_type_conn.set_option("look_for_keys", look_for_keys) + def update_play_context(self, pc_data): """Updates the play context information for the connection""" pc_data = to_bytes(pc_data) @@ -441,19 +567,26 @@ class Connection(NetworkConnectionBase): if self._play_context.become ^ play_context.become: if play_context.become is True: auth_pass = play_context.become_pass - self._terminal.on_become(passwd=auth_pass) + self._on_become(become_pass=auth_pass) self.queue_message("vvvv", "authorizing connection") else: self._terminal.on_unbecome() self.queue_message("vvvv", "deauthorizing connection") self._play_context = play_context + if self._ssh_type_conn is not None: + # TODO: This works, but is not really ideal. We would rather use + # set_options, but then we need more custom handling in that + # method. + self._ssh_type_conn._play_context = play_context if hasattr(self, "reset_history"): self.reset_history() if hasattr(self, "disable_response_logging"): self.disable_response_logging() + self._single_user_mode = self.get_option("single_user_mode") + def set_check_prompt(self, task_uuid): self._check_prompt = task_uuid @@ -467,9 +600,18 @@ class Connection(NetworkConnectionBase): """ Connects to the remote device and starts the terminal """ + if self._play_context.verbosity > 3: + logging.getLogger(self.ssh_type).setLevel(logging.DEBUG) + + self.queue_message( + "vvvv", "invoked shell using ssh_type: %s" % self.ssh_type + ) + + self._single_user_mode = self.get_option("single_user_mode") + if not self.connected: - self.paramiko_conn._set_log_channel(self._get_log_channel()) - self.paramiko_conn.force_persistence = self.force_persistence + self.ssh_type_conn._set_log_channel(self._get_log_channel()) + self.ssh_type_conn.force_persistence = self.force_persistence command_timeout = self.get_option("persistent_command_timeout") max_pause = min( @@ -483,8 +625,10 @@ class Connection(NetworkConnectionBase): for attempt in range(retries + 1): try: - ssh = self.paramiko_conn._connect() + ssh = self.ssh_type_conn._connect() break + except AnsibleError: + raise except Exception as e: pause = 2 ** (attempt + 1) if attempt == retries or total_pause >= max_pause: @@ -493,8 +637,8 @@ class Connection(NetworkConnectionBase): ) else: msg = ( - u"network_cli_retry: attempt: %d, caught exception(%s), " - u"pausing for %d seconds" + "network_cli_retry: attempt: %d, caught exception(%s), " + "pausing for %d seconds" % ( attempt + 1, to_text(e, errors="surrogate_or_strict"), @@ -511,7 +655,8 @@ class Connection(NetworkConnectionBase): self._connected = True self._ssh_shell = ssh.ssh.invoke_shell() - self._ssh_shell.settimeout(command_timeout) + if self.ssh_type == "paramiko": + self._ssh_shell.settimeout(command_timeout) self.queue_message( "vvvv", @@ -544,10 +689,10 @@ class Connection(NetworkConnectionBase): if self._play_context.become: self.queue_message("vvvv", "firing event: on_become") auth_pass = self._play_context.become_pass - self._terminal.on_become(passwd=auth_pass) + self._on_become(become_pass=auth_pass) self.queue_message("vvvv", "firing event: on_open_shell()") - self._terminal.on_open_shell() + self._on_open_shell() self.queue_message( "vvvv", "ssh connection has completed successfully" @@ -555,6 +700,43 @@ class Connection(NetworkConnectionBase): return self + def _on_become(self, become_pass=None): + """ + Wraps terminal.on_become() to handle + privilege escalation failures based on user preference + """ + on_become_error = self.get_option("become_errors") + try: + self._terminal.on_become(passwd=become_pass) + except AnsibleConnectionFailure: + if on_become_error == "ignore": + pass + elif on_become_error == "warn": + self.queue_message( + "warning", "on_become: privilege escalation failed" + ) + else: + raise + + def _on_open_shell(self): + """ + Wraps terminal.on_open_shell() to handle + terminal setting failures based on user preference + """ + on_terminal_error = self.get_option("terminal_errors") + try: + self._terminal.on_open_shell() + except AnsibleConnectionFailure: + if on_terminal_error == "ignore": + pass + elif on_terminal_error == "warn": + self.queue_message( + "warning", + "on_open_shell: failed to set terminal parameters", + ) + else: + raise + def close(self): """ Close the active connection to the device @@ -569,14 +751,19 @@ class Connection(NetworkConnectionBase): self._ssh_shell = None self.queue_message("debug", "cli session is now closed") - self.paramiko_conn.close() - self._paramiko_conn = None + self.ssh_type_conn.close() + self._ssh_type_conn = None self.queue_message( "debug", "ssh connection has been closed successfully" ) super(Connection, self).close() - def receive( + def _read_post_command_prompt_match(self): + time.sleep(self.get_option("persistent_buffer_read_timeout")) + data = self._ssh_shell.read_bulk_response() + return data if data else None + + def receive_paramiko( self, command=None, prompts=None, @@ -584,50 +771,29 @@ class Connection(NetworkConnectionBase): newline=True, prompt_retry_check=False, check_all=False, + strip_prompt=True, ): - """ - Handles receiving of output from command - """ - self._matched_prompt = None - self._matched_cmd_prompt = None + recv = BytesIO() - handled = False + cache_socket_timeout = self.get_option("persistent_command_timeout") + self._ssh_shell.settimeout(cache_socket_timeout) command_prompt_matched = False - matched_prompt_window = window_count = 0 - - # set terminal regex values for command prompt and errors in response - self._terminal_stderr_re = self._get_terminal_std_re( - "terminal_stderr_re" - ) - self._terminal_stdout_re = self._get_terminal_std_re( - "terminal_stdout_re" - ) - - cache_socket_timeout = self._ssh_shell.gettimeout() - command_timeout = self.get_option("persistent_command_timeout") - self._validate_timeout_value( - command_timeout, "persistent_command_timeout" - ) - if cache_socket_timeout != command_timeout: - self._ssh_shell.settimeout(command_timeout) - - buffer_read_timeout = self.get_option("persistent_buffer_read_timeout") - self._validate_timeout_value( - buffer_read_timeout, "persistent_buffer_read_timeout" - ) + handled = False + errored_response = None - self._log_messages("command: %s" % command) while True: if command_prompt_matched: try: signal.signal( signal.SIGALRM, self._handle_buffer_read_timeout ) - signal.setitimer(signal.ITIMER_REAL, buffer_read_timeout) + signal.setitimer( + signal.ITIMER_REAL, self._buffer_read_timeout + ) data = self._ssh_shell.recv(256) signal.alarm(0) self._log_messages( - "response-%s: %s" % (window_count + 1, data) + "response-%s: %s" % (self._window_count + 1, data) ) # if data is still received on channel it indicates the prompt string # is wrongly matched in between response chunks, continue to read @@ -636,16 +802,15 @@ class Connection(NetworkConnectionBase): # restart command_timeout timer signal.signal(signal.SIGALRM, self._handle_command_timeout) - signal.alarm(command_timeout) + signal.alarm(self._command_timeout) except AnsibleCmdRespRecv: # reset socket timeout to global timeout - self._ssh_shell.settimeout(cache_socket_timeout) return self._command_response else: data = self._ssh_shell.recv(256) self._log_messages( - "response-%s: %s" % (window_count + 1, data) + "response-%s: %s" % (self._window_count + 1, data) ) # when a channel stream is closed, received data will be empty if not data: @@ -657,18 +822,18 @@ class Connection(NetworkConnectionBase): window = self._strip(recv.read()) self._last_recv_window = window - window_count += 1 + self._window_count += 1 if prompts and not handled: handled = self._handle_prompt( window, prompts, answer, newline, False, check_all ) - matched_prompt_window = window_count + self._matched_prompt_window = self._window_count elif ( prompts and handled and prompt_retry_check - and matched_prompt_window + 1 == window_count + and self._matched_prompt_window + 1 == self._window_count ): # check again even when handled, if same prompt repeats in next window # (like in the case of a wrong enable password, etc) indicates @@ -686,17 +851,167 @@ class Connection(NetworkConnectionBase): % self._matched_cmd_prompt ) + if self._find_error(window): + # We can't exit here, as we need to drain the buffer in case + # the error isn't fatal, and will be using the buffer again + errored_response = window + if self._find_prompt(window): + if errored_response: + raise AnsibleConnectionFailure(errored_response) self._last_response = recv.getvalue() resp = self._strip(self._last_response) - self._command_response = self._sanitize(resp, command) - if buffer_read_timeout == 0.0: + self._command_response = self._sanitize( + resp, command, strip_prompt + ) + if self._buffer_read_timeout == 0.0: # reset socket timeout to global timeout - self._ssh_shell.settimeout(cache_socket_timeout) return self._command_response else: command_prompt_matched = True + def receive_libssh( + self, + command=None, + prompts=None, + answer=None, + newline=True, + prompt_retry_check=False, + check_all=False, + strip_prompt=True, + ): + self._command_response = resp = b"" + command_prompt_matched = False + handled = False + errored_response = None + + while True: + + if command_prompt_matched: + data = self._read_post_command_prompt_match() + if data: + command_prompt_matched = False + else: + return self._command_response + else: + try: + data = self._ssh_shell.read_bulk_response() + # TODO: Should be ConnectionError when pylibssh drops Python 2 support + except OSError: + # Socket has closed + break + + if not data: + continue + self._last_recv_window = self._strip(data) + resp += self._last_recv_window + self._window_count += 1 + + self._log_messages("response-%s: %s" % (self._window_count, data)) + + if prompts and not handled: + handled = self._handle_prompt( + resp, prompts, answer, newline, False, check_all + ) + self._matched_prompt_window = self._window_count + elif ( + prompts + and handled + and prompt_retry_check + and self._matched_prompt_window + 1 == self._window_count + ): + # check again even when handled, if same prompt repeats in next window + # (like in the case of a wrong enable password, etc) indicates + # value of answer is wrong, report this as error. + if self._handle_prompt( + resp, + prompts, + answer, + newline, + prompt_retry_check, + check_all, + ): + raise AnsibleConnectionFailure( + "For matched prompt '%s', answer is not valid" + % self._matched_cmd_prompt + ) + + if self._find_error(resp): + # We can't exit here, as we need to drain the buffer in case + # the error isn't fatal, and will be using the buffer again + errored_response = resp + + if self._find_prompt(resp): + if errored_response: + raise AnsibleConnectionFailure(errored_response) + self._last_response = data + self._command_response += self._sanitize( + resp, command, strip_prompt + ) + command_prompt_matched = True + + def receive( + self, + command=None, + prompts=None, + answer=None, + newline=True, + prompt_retry_check=False, + check_all=False, + strip_prompt=True, + ): + """ + Handles receiving of output from command + """ + self._matched_prompt = None + self._matched_cmd_prompt = None + self._matched_prompt_window = 0 + self._window_count = 0 + + # set terminal regex values for command prompt and errors in response + self._terminal_stderr_re = self._get_terminal_std_re( + "terminal_stderr_re" + ) + self._terminal_stdout_re = self._get_terminal_std_re( + "terminal_stdout_re" + ) + + self._command_timeout = self.get_option("persistent_command_timeout") + self._validate_timeout_value( + self._command_timeout, "persistent_command_timeout" + ) + + self._buffer_read_timeout = self.get_option( + "persistent_buffer_read_timeout" + ) + self._validate_timeout_value( + self._buffer_read_timeout, "persistent_buffer_read_timeout" + ) + + self._log_messages("command: %s" % command) + if self.ssh_type == "libssh": + response = self.receive_libssh( + command, + prompts, + answer, + newline, + prompt_retry_check, + check_all, + strip_prompt, + ) + elif self.ssh_type == "paramiko": + response = self.receive_paramiko( + command, + prompts, + answer, + newline, + prompt_retry_check, + check_all, + strip_prompt, + ) + + return response + @ensure_connect def send( self, @@ -707,10 +1022,20 @@ class Connection(NetworkConnectionBase): sendonly=False, prompt_retry_check=False, check_all=False, + strip_prompt=True, ): """ Sends the command to the device in the opened shell """ + # try cache first + if (not prompt) and (self._single_user_mode): + out = self.get_cache().lookup(command) + if out: + self.queue_message( + "vvvv", "cache hit for command: %s" % command + ) + return out + if check_all: prompt_len = len(to_list(prompt)) answer_len = len(to_list(answer)) @@ -727,9 +1052,32 @@ class Connection(NetworkConnectionBase): if sendonly: return response = self.receive( - command, prompt, answer, newline, prompt_retry_check, check_all + command, + prompt, + answer, + newline, + prompt_retry_check, + check_all, + strip_prompt, ) - return to_text(response, errors="surrogate_then_replace") + response = to_text(response, errors="surrogate_then_replace") + + if (not prompt) and (self._single_user_mode): + if self._needs_cache_invalidation(command): + # invalidate the existing cache + if self.get_cache().keys(): + self.queue_message( + "vvvv", "invalidating existing cache" + ) + self.get_cache().invalidate() + else: + # populate cache + self.queue_message( + "vvvv", "populating cache for command: %s" % command + ) + self.get_cache().populate(command, response) + + return response except (socket.timeout, AttributeError): self.queue_message("error", traceback.format_exc()) raise AnsibleConnectionFailure( @@ -789,7 +1137,13 @@ class Connection(NetworkConnectionBase): single_prompt = True if not isinstance(answer, list): answer = [answer] - prompts_regex = [re.compile(to_bytes(r), re.I) for r in prompts] + try: + prompts_regex = [re.compile(to_bytes(r), re.I) for r in prompts] + except re.error as exc: + raise ConnectionError( + "Failed to compile one or more terminal prompt regexes: %s.\n" + "Prompts provided: %s" % (to_text(exc), prompts) + ) for index, regex in enumerate(prompts_regex): match = regex.search(resp) if match: @@ -801,13 +1155,12 @@ class Connection(NetworkConnectionBase): # if prompt_retry_check is enabled to check if same prompt is # repeated don't send answer again. if not prompt_retry_check: - prompt_answer = ( + prompt_answer = to_bytes( answer[index] if len(answer) > index else answer[0] ) - self._ssh_shell.sendall(b"%s" % prompt_answer) if newline: - self._ssh_shell.sendall(b"\r") prompt_answer += b"\r" + self._ssh_shell.sendall(prompt_answer) self._log_messages( "matched command prompt answer: %s" % prompt_answer ) @@ -818,7 +1171,7 @@ class Connection(NetworkConnectionBase): return True return False - def _sanitize(self, resp, command=None): + def _sanitize(self, resp, command=None, strip_prompt=True): """ Removes elements from the response before returning to the caller """ @@ -828,55 +1181,42 @@ class Connection(NetworkConnectionBase): continue for prompt in self._matched_prompt.strip().splitlines(): - if prompt.strip() in line: + if prompt.strip() in line and strip_prompt: break else: cleaned.append(line) + return b"\n".join(cleaned).strip() - def _find_prompt(self, response): - """Searches the buffered response for a matching command prompt - """ - errored_response = None - is_error_message = False - - for regex in self._terminal_stderr_re: - if regex.search(response): - is_error_message = True - - # Check if error response ends with command prompt if not - # receive it buffered prompt - for regex in self._terminal_stdout_re: - match = regex.search(response) - if match: - errored_response = response - self._matched_pattern = regex.pattern - self._matched_prompt = match.group() - self._log_messages( - "matched error regex '%s' from response '%s'" - % (self._matched_pattern, errored_response) - ) - break - - if not is_error_message: - for regex in self._terminal_stdout_re: - match = regex.search(response) - if match: - self._matched_pattern = regex.pattern - self._matched_prompt = match.group() - self._log_messages( - "matched cli prompt '%s' with regex '%s' from response '%s'" - % ( - self._matched_prompt, - self._matched_pattern, - response, - ) - ) - if not errored_response: - return True + def _find_error(self, response): + """Searches the buffered response for a matching error condition""" + for stderr_regex in self._terminal_stderr_re: + if stderr_regex.search(response): + self._log_messages( + "matched error regex (terminal_stderr_re) '%s' from response '%s'" + % (stderr_regex.pattern, response) + ) - if errored_response: - raise AnsibleConnectionFailure(errored_response) + self._log_messages( + "matched stdout regex (terminal_stdout_re) '%s' from error response '%s'" + % (self._matched_pattern, response) + ) + return True + + return False + + def _find_prompt(self, response): + """Searches the buffered response for a matching command prompt""" + for stdout_regex in self._terminal_stdout_re: + match = stdout_regex.search(response) + if match: + self._matched_pattern = stdout_regex.pattern + self._matched_prompt = match.group() + self._log_messages( + "matched cli prompt '%s' with regex '%s' from response '%s'" + % (self._matched_prompt, self._matched_pattern, response) + ) + return True return False @@ -912,7 +1252,7 @@ class Connection(NetworkConnectionBase): "'pattern' is a required key for option '%s'," " received option value is %s" % (option, item) ) - pattern = br"%s" % to_bytes(item["pattern"]) + pattern = rb"%s" % to_bytes(item["pattern"]) flag = item.get("flags", 0) if flag: flag = getattr(re, flag.split(".")[1]) @@ -922,3 +1262,125 @@ class Connection(NetworkConnectionBase): terminal_std_re = getattr(self._terminal, option) return terminal_std_re + + def copy_file( + self, source=None, destination=None, proto="scp", timeout=30 + ): + """Copies file over scp/sftp to remote device + + :param source: Source file path + :param destination: Destination file path on remote device + :param proto: Protocol to be used for file transfer, + supported protocol: scp and sftp + :param timeout: Specifies the wait time to receive response from + remote host before triggering timeout exception + :return: None + """ + ssh = self.ssh_type_conn._connect_uncached() + if self.ssh_type == "libssh": + self.ssh_type_conn.put_file(source, destination, proto=proto) + elif self.ssh_type == "paramiko": + if proto == "scp": + if not HAS_SCP: + raise AnsibleError(missing_required_lib("scp")) + with SCPClient( + ssh.get_transport(), socket_timeout=timeout + ) as scp: + scp.put(source, destination) + elif proto == "sftp": + with ssh.open_sftp() as sftp: + sftp.put(source, destination) + else: + raise AnsibleError( + "Do not know how to do transfer file over protocol %s" + % proto + ) + else: + raise AnsibleError( + "Do not know how to do SCP with ssh_type %s" % self.ssh_type + ) + + def get_file(self, source=None, destination=None, proto="scp", timeout=30): + """Fetch file over scp/sftp from remote device + :param source: Source file path + :param destination: Destination file path + :param proto: Protocol to be used for file transfer, + supported protocol: scp and sftp + :param timeout: Specifies the wait time to receive response from + remote host before triggering timeout exception + :return: None + """ + """Fetch file over scp/sftp from remote device""" + ssh = self.ssh_type_conn._connect_uncached() + if self.ssh_type == "libssh": + self.ssh_type_conn.fetch_file(source, destination, proto=proto) + elif self.ssh_type == "paramiko": + if proto == "scp": + if not HAS_SCP: + raise AnsibleError(missing_required_lib("scp")) + try: + with SCPClient( + ssh.get_transport(), socket_timeout=timeout + ) as scp: + scp.get(source, destination) + except EOFError: + # This appears to be benign. + pass + elif proto == "sftp": + with ssh.open_sftp() as sftp: + sftp.get(source, destination) + else: + raise AnsibleError( + "Do not know how to do transfer file over protocol %s" + % proto + ) + else: + raise AnsibleError( + "Do not know how to do SCP with ssh_type %s" % self.ssh_type + ) + + def get_cache(self): + if not self._cache: + # TO-DO: support jsonfile or other modes of caching with + # a configurable option + self._cache = cache_loader.get("ansible.netcommon.memory") + return self._cache + + def _is_in_config_mode(self): + """ + Check if the target device is in config mode by comparing + the current prompt with the platform's `terminal_config_prompt`. + Returns False if `terminal_config_prompt` is not defined. + + :returns: A boolean indicating if the device is in config mode or not. + """ + cfg_mode = False + cur_prompt = to_text( + self.get_prompt(), errors="surrogate_then_replace" + ).strip() + cfg_prompt = getattr(self._terminal, "terminal_config_prompt", None) + if cfg_prompt and cfg_prompt.match(cur_prompt): + cfg_mode = True + return cfg_mode + + def _needs_cache_invalidation(self, command): + """ + This method determines if it is necessary to invalidate + the existing cache based on whether the device has entered + configuration mode or if the last command sent to the device + is potentially capable of making configuration changes. + + :param command: The last command sent to the target device. + :returns: A boolean indicating if cache invalidation is required or not. + """ + invalidate = False + cfg_cmds = [] + try: + # AnsiblePlugin base class in Ansible 2.9 does not have has_option() method. + # TO-DO: use has_option() when we drop 2.9 support. + cfg_cmds = self.cliconf.get_option("config_commands") + except AttributeError: + cfg_cmds = [] + if (self._is_in_config_mode()) or (to_text(command) in cfg_cmds): + invalidate = True + return invalidate diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/connection_persistent.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/connection_persistent.py new file mode 100644 index 00000000000..d572c30b90e --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/connection_persistent.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +class ModuleDocFragment(object): + + # Standard files documentation fragment + DOCUMENTATION = r""" +options: + import_modules: + type: boolean + description: + - Reduce CPU usage and network module execution time + by enabling direct execution. Instead of the module being packaged + and executed by the shell, it will be directly executed by the Ansible + control node using the same python interpreter as the Ansible process. + Note- Incompatible with C(asynchronous mode). + Note- Python 3 and Ansible 2.9.16 or greater required. + Note- With Ansible 2.9.x fully qualified modules names are required in tasks. + default: true + ini: + - section: ansible_network + key: import_modules + env: + - name: ANSIBLE_NETWORK_IMPORT_MODULES + vars: + - name: ansible_network_import_modules + persistent_connect_timeout: + type: int + description: + - Configures, in seconds, the amount of time to wait when trying to initially + establish a persistent connection. If this value expires before the connection + to the remote device is completed, the connection will fail. + default: 30 + ini: + - section: persistent_connection + key: connect_timeout + env: + - name: ANSIBLE_PERSISTENT_CONNECT_TIMEOUT + vars: + - name: ansible_connect_timeout + persistent_command_timeout: + type: int + description: + - Configures, in seconds, the amount of time to wait for a command to + return from the remote device. If this timer is exceeded before the + command returns, the connection plugin will raise an exception and + close. + default: 30 + ini: + - section: persistent_connection + key: command_timeout + env: + - name: ANSIBLE_PERSISTENT_COMMAND_TIMEOUT + vars: + - name: ansible_command_timeout + persistent_log_messages: + type: boolean + description: + - This flag will enable logging the command executed and response received from + target device in the ansible log file. For this option to work 'log_path' ansible + configuration option is required to be set to a file path with write access. + - Be sure to fully understand the security implications of enabling this + option as it could create a security vulnerability by logging sensitive information in log file. + default: False + ini: + - section: persistent_connection + key: log_messages + env: + - name: ANSIBLE_PERSISTENT_LOG_MESSAGES + vars: + - name: ansible_persistent_log_messages +""" diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/plugin_utils/connection_base.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/plugin_utils/connection_base.py new file mode 100644 index 00000000000..a38a775b93f --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/plugin_utils/connection_base.py @@ -0,0 +1,185 @@ +# (c) 2012-2014, Michael DeHaan +# (c) 2015 Toshio Kuratomi +# (c) 2017, Peter Sprygada +# (c) 2017 Ansible Project +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import os + +from ansible import constants as C +from ansible.plugins.connection import ConnectionBase +from ansible.plugins.loader import connection_loader +from ansible.utils.display import Display +from ansible.utils.path import unfrackpath + +display = Display() + + +__all__ = ["NetworkConnectionBase"] + +BUFSIZE = 65536 + + +class NetworkConnectionBase(ConnectionBase): + """ + A base class for network-style connections. + """ + + force_persistence = True + # Do not use _remote_is_local in other connections + _remote_is_local = True + + def __init__(self, play_context, new_stdin, *args, **kwargs): + super(NetworkConnectionBase, self).__init__( + play_context, new_stdin, *args, **kwargs + ) + self._messages = [] + self._conn_closed = False + + self._network_os = self._play_context.network_os + + self._local = connection_loader.get("local", play_context, "/dev/null") + self._local.set_options() + + self._sub_plugin = {} + self._cached_variables = (None, None, None) + + # reconstruct the socket_path and set instance values accordingly + self._ansible_playbook_pid = kwargs.get("ansible_playbook_pid") + self._update_connection_state() + + def __getattr__(self, name): + try: + return self.__dict__[name] + except KeyError: + if not name.startswith("_"): + plugin = self._sub_plugin.get("obj") + if plugin: + method = getattr(plugin, name, None) + if method is not None: + return method + raise AttributeError( + "'%s' object has no attribute '%s'" + % (self.__class__.__name__, name) + ) + + def exec_command(self, cmd, in_data=None, sudoable=True): + return self._local.exec_command(cmd, in_data, sudoable) + + def queue_message(self, level, message): + """ + Adds a message to the queue of messages waiting to be pushed back to the controller process. + + :arg level: A string which can either be the name of a method in display, or 'log'. When + the messages are returned to task_executor, a value of log will correspond to + ``display.display(message, log_only=True)``, while another value will call ``display.[level](message)`` + """ + self._messages.append((level, message)) + + def pop_messages(self): + messages, self._messages = self._messages, [] + return messages + + def put_file(self, in_path, out_path): + """Transfer a file from local to remote""" + return self._local.put_file(in_path, out_path) + + def fetch_file(self, in_path, out_path): + """Fetch a file from remote to local""" + return self._local.fetch_file(in_path, out_path) + + def reset(self): + """ + Reset the connection + """ + if self._socket_path: + self.queue_message( + "vvvv", + "resetting persistent connection for socket_path %s" + % self._socket_path, + ) + self.close() + self.queue_message("vvvv", "reset call on connection instance") + + def close(self): + self._conn_closed = True + if self._connected: + self._connected = False + + def get_options(self, hostvars=None): + options = super(NetworkConnectionBase, self).get_options( + hostvars=hostvars + ) + + if ( + self._sub_plugin.get("obj") + and self._sub_plugin.get("type") != "external" + ): + try: + options.update( + self._sub_plugin["obj"].get_options(hostvars=hostvars) + ) + except AttributeError: + pass + + return options + + def set_options(self, task_keys=None, var_options=None, direct=None): + super(NetworkConnectionBase, self).set_options( + task_keys=task_keys, var_options=var_options, direct=direct + ) + if self.get_option("persistent_log_messages"): + warning = ( + "Persistent connection logging is enabled for %s. This will log ALL interactions" + % self._play_context.remote_addr + ) + logpath = getattr(C, "DEFAULT_LOG_PATH") + if logpath is not None: + warning += " to %s" % logpath + self.queue_message( + "warning", + "%s and WILL NOT redact sensitive configuration like passwords. USE WITH CAUTION!" + % warning, + ) + + if ( + self._sub_plugin.get("obj") + and self._sub_plugin.get("type") != "external" + ): + try: + self._sub_plugin["obj"].set_options( + task_keys=task_keys, var_options=var_options, direct=direct + ) + except AttributeError: + pass + + def _update_connection_state(self): + """ + Reconstruct the connection socket_path and check if it exists + + If the socket path exists then the connection is active and set + both the _socket_path value to the path and the _connected value + to True. If the socket path doesn't exist, leave the socket path + value to None and the _connected value to False + """ + ssh = connection_loader.get("ssh", class_only=True) + control_path = ssh._create_control_path( + self._play_context.remote_addr, + self._play_context.port, + self._play_context.remote_user, + self._play_context.connection, + self._ansible_playbook_pid, + ) + + tmp_path = unfrackpath(C.PERSISTENT_CONTROL_PATH_DIR) + socket_path = unfrackpath(control_path % dict(directory=tmp_path)) + + if os.path.exists(socket_path): + self._connected = True + self._socket_path = socket_path + + def _log_messages(self, message): + if self.get_option("persistent_log_messages"): + self.queue_message("log", message)