Add type annotation for connection plugins (#78552)

* Add type annotation for connection plugins

* Use new | syntax instead of Union/Optional

* Fix pep issue

* Use ParamSpec and other minor fixes

* Fix up ParmaSpec args and kwargs type
pull/77614/merge
Jordan Borean 1 year ago committed by GitHub
parent 67b78a17c4
commit c3f479e378
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,2 @@
minor_changes:
- Added Python type annotation to connection plugins

@ -55,6 +55,9 @@ class AnsiblePlugin(ABC):
# allow extra passthrough parameters
allow_extras = False
# Set by plugin loader
_load_name: str
def __init__(self):
self._options = {}
self._defs = None

@ -2,10 +2,12 @@
# (c) 2015 Toshio Kuratomi <tkuratomi@ansible.com>
# (c) 2017, Peter Sprygada <psprygad@redhat.com>
# (c) 2017 Ansible Project
from __future__ import (absolute_import, division, print_function)
from __future__ import (annotations, absolute_import, division, print_function)
__metaclass__ = type
import collections.abc as c
import fcntl
import io
import os
import shlex
import typing as t
@ -15,7 +17,10 @@ from functools import wraps
from ansible import constants as C
from ansible.module_utils.common.text.converters import to_bytes, to_text
from ansible.playbook.play_context import PlayContext
from ansible.plugins import AnsiblePlugin
from ansible.plugins.become import BecomeBase
from ansible.plugins.shell import ShellBase
from ansible.utils.display import Display
from ansible.plugins.loader import connection_loader, get_shell_plugin
from ansible.utils.path import unfrackpath
@ -27,10 +32,15 @@ __all__ = ['ConnectionBase', 'ensure_connect']
BUFSIZE = 65536
P = t.ParamSpec('P')
T = t.TypeVar('T')
def ensure_connect(func):
def ensure_connect(
func: c.Callable[t.Concatenate[ConnectionBase, P], T],
) -> c.Callable[t.Concatenate[ConnectionBase, P], T]:
@wraps(func)
def wrapped(self, *args, **kwargs):
def wrapped(self: ConnectionBase, *args: P.args, **kwargs: P.kwargs) -> T:
if not self._connected:
self._connect()
return func(self, *args, **kwargs)
@ -57,9 +67,16 @@ class ConnectionBase(AnsiblePlugin):
supports_persistence = False
force_persistence = False
default_user = None
default_user: str | None = None
def __init__(self, play_context, new_stdin=None, shell=None, *args, **kwargs):
def __init__(
self,
play_context: PlayContext,
new_stdin: io.TextIOWrapper | None = None,
shell: ShellBase | None = None,
*args: t.Any,
**kwargs: t.Any,
) -> None:
super(ConnectionBase, self).__init__()
@ -77,7 +94,7 @@ class ConnectionBase(AnsiblePlugin):
self.success_key = None
self.prompt = None
self._connected = False
self._socket_path = None
self._socket_path: str | None = None
# helper plugins
self._shell = shell
@ -87,10 +104,10 @@ class ConnectionBase(AnsiblePlugin):
shell_type = play_context.shell if play_context.shell else getattr(self, '_shell_type', None)
self._shell = get_shell_plugin(shell_type=shell_type, executable=self._play_context.executable)
self.become = None
self.become: BecomeBase | None = None
@property
def _new_stdin(self):
def _new_stdin(self) -> io.TextIOWrapper | None:
display.deprecated(
"The connection's stdin object is deprecated. "
"Call display.prompt_until(msg) instead.",
@ -98,21 +115,21 @@ class ConnectionBase(AnsiblePlugin):
)
return self.__new_stdin
def set_become_plugin(self, plugin):
def set_become_plugin(self, plugin: BecomeBase) -> None:
self.become = plugin
@property
def connected(self):
def connected(self) -> bool:
'''Read-only property holding whether the connection to the remote host is active or closed.'''
return self._connected
@property
def socket_path(self):
def socket_path(self) -> str | None:
'''Read-only property holding the connection socket path for this remote host'''
return self._socket_path
@staticmethod
def _split_ssh_args(argstring):
def _split_ssh_args(argstring: str) -> list[str]:
"""
Takes a string like '-o Foo=1 -o Bar="foo bar"' and returns a
list ['-o', 'Foo=1', '-o', 'Bar=foo bar'] that can be added to
@ -123,17 +140,17 @@ class ConnectionBase(AnsiblePlugin):
@property
@abstractmethod
def transport(self):
def transport(self) -> str:
"""String used to identify this Connection class from other classes"""
pass
@abstractmethod
def _connect(self):
def _connect(self: T) -> T:
"""Connect to the host we've been initialized with"""
@ensure_connect
@abstractmethod
def exec_command(self, cmd, in_data=None, sudoable=True):
def exec_command(self, cmd: str, in_data: bytes | None = None, sudoable: bool = True) -> tuple[int, bytes, bytes]:
"""Run a command on the remote host.
:arg cmd: byte string containing the command
@ -201,36 +218,36 @@ class ConnectionBase(AnsiblePlugin):
@ensure_connect
@abstractmethod
def put_file(self, in_path, out_path):
def put_file(self, in_path: str, out_path: str) -> None:
"""Transfer a file from local to remote"""
pass
@ensure_connect
@abstractmethod
def fetch_file(self, in_path, out_path):
def fetch_file(self, in_path: str, out_path: str) -> None:
"""Fetch a file from remote to local; callers are expected to have pre-created the directory chain for out_path"""
pass
@abstractmethod
def close(self):
def close(self) -> None:
"""Terminate the connection"""
pass
def connection_lock(self):
def connection_lock(self) -> None:
f = self._play_context.connection_lockfd
display.vvvv('CONNECTION: pid %d waiting for lock on %d' % (os.getpid(), f), host=self._play_context.remote_addr)
fcntl.lockf(f, fcntl.LOCK_EX)
display.vvvv('CONNECTION: pid %d acquired lock on %d' % (os.getpid(), f), host=self._play_context.remote_addr)
def connection_unlock(self):
def connection_unlock(self) -> None:
f = self._play_context.connection_lockfd
fcntl.lockf(f, fcntl.LOCK_UN)
display.vvvv('CONNECTION: pid %d released lock on %d' % (os.getpid(), f), host=self._play_context.remote_addr)
def reset(self):
def reset(self) -> None:
display.warning("Reset is not implemented for this connection")
def update_vars(self, variables):
def update_vars(self, variables: dict[str, t.Any]) -> None:
'''
Adds 'magic' variables relating to connections to the variable dictionary provided.
In case users need to access from the play, this is a legacy from runner.
@ -246,7 +263,7 @@ class ConnectionBase(AnsiblePlugin):
elif varname == 'ansible_connection':
# its me mom!
value = self._load_name
elif varname == 'ansible_shell_type':
elif varname == 'ansible_shell_type' and self._shell:
# its my cousin ...
value = self._shell._load_name
else:
@ -279,9 +296,15 @@ class NetworkConnectionBase(ConnectionBase):
# Do not use _remote_is_local in other connections
_remote_is_local = True
def __init__(self, play_context, new_stdin=None, *args, **kwargs):
def __init__(
self,
play_context: PlayContext,
new_stdin: io.TextIOWrapper | None = None,
*args: t.Any,
**kwargs: t.Any,
) -> None:
super(NetworkConnectionBase, self).__init__(play_context, new_stdin, *args, **kwargs)
self._messages = []
self._messages: list[tuple[str, str]] = []
self._conn_closed = False
self._network_os = self._play_context.network_os
@ -289,7 +312,7 @@ class NetworkConnectionBase(ConnectionBase):
self._local = connection_loader.get('local', play_context, '/dev/null')
self._local.set_options()
self._sub_plugin = {}
self._sub_plugin: dict[str, t.Any] = {}
self._cached_variables = (None, None, None)
# reconstruct the socket_path and set instance values accordingly
@ -308,10 +331,10 @@ class NetworkConnectionBase(ConnectionBase):
return method
raise AttributeError("'%s' object has no attribute '%s'" % (self.__class__.__name__, name))
def exec_command(self, cmd, in_data=None, sudoable=True):
def exec_command(self, cmd: str, in_data: bytes | None = None, sudoable: bool = True) -> tuple[int, bytes, bytes]:
return self._local.exec_command(cmd, in_data, sudoable)
def queue_message(self, level, message):
def queue_message(self, level: str, message: str) -> None:
"""
Adds a message to the queue of messages waiting to be pushed back to the controller process.
@ -321,19 +344,19 @@ class NetworkConnectionBase(ConnectionBase):
"""
self._messages.append((level, message))
def pop_messages(self):
def pop_messages(self) -> list[tuple[str, str]]:
messages, self._messages = self._messages, []
return messages
def put_file(self, in_path, out_path):
def put_file(self, in_path: str, out_path: str) -> None:
"""Transfer a file from local to remote"""
return self._local.put_file(in_path, out_path)
def fetch_file(self, in_path, out_path):
def fetch_file(self, in_path: str, out_path: str) -> None:
"""Fetch a file from remote to local"""
return self._local.fetch_file(in_path, out_path)
def reset(self):
def reset(self) -> None:
'''
Reset the connection
'''
@ -342,12 +365,17 @@ class NetworkConnectionBase(ConnectionBase):
self.close()
self.queue_message('vvvv', 'reset call on connection instance')
def close(self):
def close(self) -> None:
self._conn_closed = True
if self._connected:
self._connected = False
def set_options(self, task_keys=None, var_options=None, direct=None):
def set_options(
self,
task_keys: dict[str, t.Any] | None = None,
var_options: dict[str, t.Any] | None = None,
direct: dict[str, t.Any] | None = None,
) -> None:
super(NetworkConnectionBase, self).set_options(task_keys=task_keys, var_options=var_options, direct=direct)
if self.get_option('persistent_log_messages'):
warning = "Persistent connection logging is enabled for %s. This will log ALL interactions" % self._play_context.remote_addr
@ -362,7 +390,7 @@ class NetworkConnectionBase(ConnectionBase):
except AttributeError:
pass
def _update_connection_state(self):
def _update_connection_state(self) -> None:
'''
Reconstruct the connection socket_path and check if it exists
@ -385,6 +413,6 @@ class NetworkConnectionBase(ConnectionBase):
self._connected = True
self._socket_path = socket_path
def _log_messages(self, message):
def _log_messages(self, message: str) -> None:
if self.get_option('persistent_log_messages'):
self.queue_message('log', message)

@ -2,7 +2,7 @@
# (c) 2015, 2017 Toshio Kuratomi <tkuratomi@ansible.com>
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import (absolute_import, division, print_function)
from __future__ import (annotations, absolute_import, division, print_function)
__metaclass__ = type
DOCUMENTATION = '''
@ -24,6 +24,7 @@ import os
import pty
import shutil
import subprocess
import typing as t
import ansible.constants as C
from ansible.errors import AnsibleError, AnsibleFileNotFound
@ -43,7 +44,7 @@ class Connection(ConnectionBase):
transport = 'local'
has_pipelining = True
def __init__(self, *args, **kwargs):
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super(Connection, self).__init__(*args, **kwargs)
self.cwd = None
@ -53,7 +54,7 @@ class Connection(ConnectionBase):
display.vv("Current user (uid=%s) does not seem to exist on this system, leaving user empty." % os.getuid())
self.default_user = ""
def _connect(self):
def _connect(self) -> Connection:
''' connect to the local host; nothing to do here '''
# Because we haven't made any remote connection we're running as
@ -65,7 +66,7 @@ class Connection(ConnectionBase):
self._connected = True
return self
def exec_command(self, cmd, in_data=None, sudoable=True):
def exec_command(self, cmd: str, in_data: bytes | None = None, sudoable: bool = True) -> tuple[int, bytes, bytes]:
''' run a command on the local host '''
super(Connection, self).exec_command(cmd, in_data=in_data, sudoable=sudoable)
@ -163,7 +164,7 @@ class Connection(ConnectionBase):
display.debug("done with local.exec_command()")
return (p.returncode, stdout, stderr)
def put_file(self, in_path, out_path):
def put_file(self, in_path: str, out_path: str) -> None:
''' transfer a file from local to local '''
super(Connection, self).put_file(in_path, out_path)
@ -181,7 +182,7 @@ class Connection(ConnectionBase):
except IOError as e:
raise AnsibleError("failed to transfer file to {0}: {1}".format(to_native(out_path), to_native(e)))
def fetch_file(self, in_path, out_path):
def fetch_file(self, in_path: str, out_path: str) -> None:
''' fetch a file from local to local -- for compatibility '''
super(Connection, self).fetch_file(in_path, out_path)
@ -189,6 +190,6 @@ class Connection(ConnectionBase):
display.vvv(u"FETCH {0} TO {1}".format(in_path, out_path), host=self._play_context.remote_addr)
self.put_file(in_path, out_path)
def close(self):
def close(self) -> None:
''' terminate the connection; nothing to do here '''
self._connected = False

@ -1,7 +1,7 @@
# (c) 2012, Michael DeHaan <michael.dehaan@gmail.com>
# (c) 2017 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import (absolute_import, division, print_function)
from __future__ import (annotations, absolute_import, division, print_function)
__metaclass__ = type
DOCUMENTATION = """
@ -293,6 +293,7 @@ import tempfile
import traceback
import fcntl
import re
import typing as t
from ansible.module_utils.compat.version import LooseVersion
from binascii import hexlify
@ -321,8 +322,12 @@ Are you sure you want to continue connecting (yes/no)?
# SSH Options Regex
SETTINGS_REGEX = re.compile(r'(\w+)(?:\s*=\s*|\s+)(.+)')
MissingHostKeyPolicy: type = object
if paramiko:
MissingHostKeyPolicy = paramiko.MissingHostKeyPolicy
class MyAddPolicy(object):
class MyAddPolicy(MissingHostKeyPolicy):
"""
Based on AutoAddPolicy in paramiko so we can determine when keys are added
@ -332,11 +337,11 @@ class MyAddPolicy(object):
local L{HostKeys} object, and saving it. This is used by L{SSHClient}.
"""
def __init__(self, connection):
def __init__(self, connection: Connection) -> None:
self.connection = connection
self._options = connection._options
def missing_host_key(self, client, hostname, key):
def missing_host_key(self, client, hostname, key) -> None:
if all((self.connection.get_option('host_key_checking'), not self.connection.get_option('host_key_auto_add'))):
@ -367,20 +372,20 @@ class MyAddPolicy(object):
# keep connection objects on a per host basis to avoid repeated attempts to reconnect
SSH_CONNECTION_CACHE = {} # type: dict[str, paramiko.client.SSHClient]
SFTP_CONNECTION_CACHE = {} # type: dict[str, paramiko.sftp_client.SFTPClient]
SSH_CONNECTION_CACHE: dict[str, paramiko.client.SSHClient] = {}
SFTP_CONNECTION_CACHE: dict[str, paramiko.sftp_client.SFTPClient] = {}
class Connection(ConnectionBase):
''' SSH based connections with Paramiko '''
transport = 'paramiko'
_log_channel = None
_log_channel: str | None = None
def _cache_key(self):
def _cache_key(self) -> str:
return "%s__%s__" % (self.get_option('remote_addr'), self.get_option('remote_user'))
def _connect(self):
def _connect(self) -> Connection:
cache_key = self._cache_key()
if cache_key in SSH_CONNECTION_CACHE:
self.ssh = SSH_CONNECTION_CACHE[cache_key]
@ -390,11 +395,11 @@ class Connection(ConnectionBase):
self._connected = True
return self
def _set_log_channel(self, name):
def _set_log_channel(self, name: str) -> None:
'''Mimic paramiko.SSHClient.set_log_channel'''
self._log_channel = name
def _parse_proxy_command(self, port=22):
def _parse_proxy_command(self, port: int = 22) -> dict[str, t.Any]:
proxy_command = None
# Parse ansible_ssh_common_args, specifically looking for ProxyCommand
ssh_args = [
@ -439,7 +444,7 @@ class Connection(ConnectionBase):
return sock_kwarg
def _connect_uncached(self):
def _connect_uncached(self) -> paramiko.SSHClient:
''' activates the connection object '''
if paramiko is None:
@ -453,10 +458,11 @@ class Connection(ConnectionBase):
# Set pubkey and hostkey algorithms to disable, the only manipulation allowed currently
# is keeping or omitting rsa-sha2 algorithms
# default_keys: t.Tuple[str] = ()
paramiko_preferred_pubkeys = getattr(paramiko.Transport, '_preferred_pubkeys', ())
paramiko_preferred_hostkeys = getattr(paramiko.Transport, '_preferred_keys', ())
use_rsa_sha2_algorithms = self.get_option('use_rsa_sha2_algorithms')
disabled_algorithms = {}
disabled_algorithms: t.Dict[str, t.Iterable[str]] = {}
if not use_rsa_sha2_algorithms:
if paramiko_preferred_pubkeys:
disabled_algorithms['pubkeys'] = tuple(a for a in paramiko_preferred_pubkeys if 'rsa-sha2' in a)
@ -533,7 +539,7 @@ class Connection(ConnectionBase):
return ssh
def exec_command(self, cmd, in_data=None, sudoable=True):
def exec_command(self, cmd: str, in_data: bytes | None = None, sudoable: bool = True) -> tuple[int, bytes, bytes]:
''' run a command on the remote host '''
super(Connection, self).exec_command(cmd, in_data=in_data, sudoable=sudoable)
@ -576,7 +582,7 @@ class Connection(ConnectionBase):
display.debug('Waiting for Privilege Escalation input')
chunk = chan.recv(bufsize)
display.debug("chunk is: %s" % chunk)
display.debug("chunk is: %r" % chunk)
if not chunk:
if b'unknown user' in become_output:
n_become_user = to_native(self.become.get_option('become_user'))
@ -606,14 +612,14 @@ class Connection(ConnectionBase):
no_prompt_out += become_output
no_prompt_err += become_output
except socket.timeout:
raise AnsibleError('ssh timed out waiting for privilege escalation.\n' + become_output)
raise AnsibleError('ssh timed out waiting for privilege escalation.\n' + to_text(become_output))
stdout = b''.join(chan.makefile('rb', bufsize))
stderr = b''.join(chan.makefile_stderr('rb', bufsize))
return (chan.recv_exit_status(), no_prompt_out + stdout, no_prompt_out + stderr)
def put_file(self, in_path, out_path):
def put_file(self, in_path: str, out_path: str) -> None:
''' transfer a file from local to remote '''
super(Connection, self).put_file(in_path, out_path)
@ -633,7 +639,7 @@ class Connection(ConnectionBase):
except IOError:
raise AnsibleError("failed to transfer file to %s" % out_path)
def _connect_sftp(self):
def _connect_sftp(self) -> paramiko.sftp_client.SFTPClient:
cache_key = "%s__%s__" % (self.get_option('remote_addr'), self.get_option('remote_user'))
if cache_key in SFTP_CONNECTION_CACHE:
@ -642,7 +648,7 @@ class Connection(ConnectionBase):
result = SFTP_CONNECTION_CACHE[cache_key] = self._connect().ssh.open_sftp()
return result
def fetch_file(self, in_path, out_path):
def fetch_file(self, in_path: str, out_path: str) -> None:
''' save a remote file to the specified path '''
super(Connection, self).fetch_file(in_path, out_path)
@ -659,7 +665,7 @@ class Connection(ConnectionBase):
except IOError:
raise AnsibleError("failed to transfer file from %s" % in_path)
def _any_keys_added(self):
def _any_keys_added(self) -> bool:
for hostname, keys in self.ssh._host_keys.items():
for keytype, key in keys.items():
@ -668,14 +674,14 @@ class Connection(ConnectionBase):
return True
return False
def _save_ssh_host_keys(self, filename):
def _save_ssh_host_keys(self, filename: str) -> None:
'''
not using the paramiko save_ssh_host_keys function as we want to add new SSH keys at the bottom so folks
don't complain about it :)
'''
if not self._any_keys_added():
return False
return
path = os.path.expanduser("~/.ssh")
makedirs_safe(path)
@ -698,13 +704,13 @@ class Connection(ConnectionBase):
if added_this_time:
f.write("%s %s %s\n" % (hostname, keytype, key.get_base64()))
def reset(self):
def reset(self) -> None:
if not self._connected:
return
self.close()
self._connect()
def close(self):
def close(self) -> None:
''' terminate the connection '''
cache_key = self._cache_key()

@ -1,7 +1,7 @@
# Copyright (c) 2018 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import (absolute_import, division, print_function)
from __future__ import (annotations, absolute_import, division, print_function)
__metaclass__ = type
DOCUMENTATION = """
@ -309,6 +309,7 @@ import base64
import json
import logging
import os
import typing as t
from ansible import constants as C
from ansible.errors import AnsibleConnectionFailure, AnsibleError
@ -316,6 +317,7 @@ from ansible.errors import AnsibleFileNotFound
from ansible.module_utils.parsing.convert_bool import boolean
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
from ansible.plugins.connection import ConnectionBase
from ansible.plugins.shell.powershell import ShellModule as PowerShellPlugin
from ansible.plugins.shell.powershell import _common_args
from ansible.utils.display import Display
from ansible.utils.hashing import sha1
@ -345,13 +347,16 @@ class Connection(ConnectionBase):
has_pipelining = True
allow_extras = True
def __init__(self, *args, **kwargs):
# Satifies mypy as this connection only ever runs with this plugin
_shell: PowerShellPlugin
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
self.always_pipeline_modules = True
self.has_native_async = True
self.runspace = None
self.host = None
self._last_pipeline = False
self.runspace: RunspacePool | None = None
self.host: PSHost | None = None
self._last_pipeline: PowerShell | None = None
self._shell_type = 'powershell'
super(Connection, self).__init__(*args, **kwargs)
@ -361,7 +366,7 @@ class Connection(ConnectionBase):
logging.getLogger('requests_credssp').setLevel(logging.INFO)
logging.getLogger('urllib3').setLevel(logging.INFO)
def _connect(self):
def _connect(self) -> Connection:
if not HAS_PYPSRP:
raise AnsibleError("pypsrp or dependencies are not installed: %s"
% to_native(PYPSRP_IMP_ERR))
@ -408,7 +413,7 @@ class Connection(ConnectionBase):
self._last_pipeline = None
return self
def reset(self):
def reset(self) -> None:
if not self._connected:
self.runspace = None
return
@ -424,26 +429,27 @@ class Connection(ConnectionBase):
self.runspace = None
self._connect()
def exec_command(self, cmd, in_data=None, sudoable=True):
def exec_command(self, cmd: str, in_data: bytes | None = None, sudoable: bool = True) -> tuple[int, bytes, bytes]:
super(Connection, self).exec_command(cmd, in_data=in_data,
sudoable=sudoable)
pwsh_in_data: bytes | str | None = None
if cmd.startswith(" ".join(_common_args) + " -EncodedCommand"):
# This is a PowerShell script encoded by the shell plugin, we will
# decode the script and execute it in the runspace instead of
# starting a new interpreter to save on time
b_command = base64.b64decode(cmd.split(" ")[-1])
script = to_text(b_command, 'utf-16-le')
in_data = to_text(in_data, errors="surrogate_or_strict", nonstring="passthru")
pwsh_in_data = to_text(in_data, errors="surrogate_or_strict", nonstring="passthru")
if in_data and in_data.startswith(u"#!"):
if pwsh_in_data and isinstance(pwsh_in_data, str) and pwsh_in_data.startswith("#!"):
# ANSIBALLZ wrapper, we need to get the interpreter and execute
# that as the script - note this won't work as basic.py relies
# on packages not available on Windows, once fixed we can enable
# this path
interpreter = to_native(in_data.splitlines()[0][2:])
interpreter = to_native(pwsh_in_data.splitlines()[0][2:])
# script = "$input | &'%s' -" % interpreter
# in_data = to_text(in_data)
raise AnsibleError("cannot run the interpreter '%s' on the psrp "
"connection plugin" % interpreter)
@ -458,12 +464,13 @@ class Connection(ConnectionBase):
# In other cases we want to execute the cmd as the script. We add on the 'exit $LASTEXITCODE' to ensure the
# rc is propagated back to the connection plugin.
script = to_text(u"%s\nexit $LASTEXITCODE" % cmd)
pwsh_in_data = in_data
display.vvv(u"PSRP: EXEC %s" % script, host=self._psrp_host)
rc, stdout, stderr = self._exec_psrp_script(script, in_data)
rc, stdout, stderr = self._exec_psrp_script(script, pwsh_in_data)
return rc, stdout, stderr
def put_file(self, in_path, out_path):
def put_file(self, in_path: str, out_path: str) -> None:
super(Connection, self).put_file(in_path, out_path)
out_path = self._shell._unquote(out_path)
@ -611,7 +618,7 @@ end {
raise AnsibleError("Remote sha1 hash %s does not match local hash %s"
% (to_native(remote_sha1), to_native(local_sha1)))
def fetch_file(self, in_path, out_path):
def fetch_file(self, in_path: str, out_path: str) -> None:
super(Connection, self).fetch_file(in_path, out_path)
display.vvv("FETCH %s TO %s" % (in_path, out_path),
host=self._psrp_host)
@ -689,7 +696,7 @@ if ($bytes_read -gt 0) {
display.warning("failed to close remote file stream of file "
"'%s': %s" % (in_path, to_native(stderr)))
def close(self):
def close(self) -> None:
if self.runspace and self.runspace.state == RunspacePoolState.OPENED:
display.vvvvv("PSRP CLOSE RUNSPACE: %s" % (self.runspace.id),
host=self._psrp_host)
@ -698,7 +705,7 @@ if ($bytes_read -gt 0) {
self._connected = False
self._last_pipeline = None
def _build_kwargs(self):
def _build_kwargs(self) -> None:
self._psrp_host = self.get_option('remote_addr')
self._psrp_user = self.get_option('remote_user')
self._psrp_pass = self.get_option('remote_password')
@ -802,7 +809,13 @@ if ($bytes_read -gt 0) {
option = self.get_option('_extras')['ansible_psrp_%s' % arg]
self._psrp_conn_kwargs[arg] = option
def _exec_psrp_script(self, script, input_data=None, use_local_scope=True, arguments=None):
def _exec_psrp_script(
self,
script: str,
input_data: bytes | str | t.Iterable | None = None,
use_local_scope: bool = True,
arguments: t.Iterable[str] | None = None,
) -> tuple[int, bytes, bytes]:
# Check if there's a command on the current pipeline that still needs to be closed.
if self._last_pipeline:
# Current pypsrp versions raise an exception if the current state was not RUNNING. We manually set it so we
@ -828,7 +841,7 @@ if ($bytes_read -gt 0) {
return rc, stdout, stderr
def _parse_pipeline_result(self, pipeline):
def _parse_pipeline_result(self, pipeline: PowerShell) -> tuple[int, bytes, bytes]:
"""
PSRP doesn't have the same concept as other protocols with its output.
We need some extra logic to convert the pipeline streams and host

@ -4,7 +4,7 @@
# Copyright (c) 2017 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import (absolute_import, division, print_function)
from __future__ import (annotations, absolute_import, division, print_function)
__metaclass__ = type
DOCUMENTATION = '''
@ -381,15 +381,18 @@ DOCUMENTATION = '''
- name: ansible_ssh_pkcs11_provider
'''
import collections.abc as c
import errno
import fcntl
import hashlib
import io
import os
import pty
import re
import shlex
import subprocess
import time
import typing as t
from functools import wraps
from ansible.errors import (
@ -410,6 +413,8 @@ from ansible.utils.path import unfrackpath, makedirs_safe
display = Display()
P = t.ParamSpec('P')
# error messages that indicate 255 return code is not from ssh itself.
b_NOT_SSH_ERRORS = (b'Traceback (most recent call last):', # Python-2.6 when there's an exception
# while invoking a script via -m
@ -427,7 +432,14 @@ class AnsibleControlPersistBrokenPipeError(AnsibleError):
pass
def _handle_error(remaining_retries, command, return_tuple, no_log, host, display=display):
def _handle_error(
remaining_retries: int,
command: bytes,
return_tuple: tuple[int, bytes, bytes],
no_log: bool,
host: str,
display: Display = display,
) -> None:
# sshpass errors
if command == b'sshpass':
@ -483,7 +495,9 @@ def _handle_error(remaining_retries, command, return_tuple, no_log, host, displa
display.vvv(msg, host=host)
def _ssh_retry(func):
def _ssh_retry(
func: c.Callable[t.Concatenate[Connection, P], tuple[int, bytes, bytes]],
) -> c.Callable[t.Concatenate[Connection, P], tuple[int, bytes, bytes]]:
"""
Decorator to retry ssh/scp/sftp in the case of a connection failure
@ -496,12 +510,12 @@ def _ssh_retry(func):
* retries limit reached
"""
@wraps(func)
def wrapped(self, *args, **kwargs):
def wrapped(self: Connection, *args: P.args, **kwargs: P.kwargs) -> tuple[int, bytes, bytes]:
remaining_tries = int(self.get_option('reconnection_retries')) + 1
cmd_summary = u"%s..." % to_text(args[0])
conn_password = self.get_option('password') or self._play_context.password
for attempt in range(remaining_tries):
cmd = args[0]
cmd = t.cast(list[bytes], args[0])
if attempt != 0 and conn_password and isinstance(cmd, list):
# If this is a retry, the fd/pipe for sshpass is closed, and we need a new one
self.sshpass_pipe = os.pipe()
@ -520,7 +534,7 @@ def _ssh_retry(func):
# 255 could be a failure from the ssh command itself
except (AnsibleControlPersistBrokenPipeError):
# Retry one more time because of the ControlPersist broken pipe (see #16731)
cmd = args[0]
cmd = t.cast(list[bytes], args[0])
if conn_password and isinstance(cmd, list):
# This is a retry, so the fd/pipe for sshpass is closed, and we need a new one
self.sshpass_pipe = os.pipe()
@ -568,15 +582,15 @@ class Connection(ConnectionBase):
transport = 'ssh'
has_pipelining = True
def __init__(self, *args, **kwargs):
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super(Connection, self).__init__(*args, **kwargs)
# TODO: all should come from get_option(), but not might be set at this point yet
self.host = self._play_context.remote_addr
self.port = self._play_context.port
self.user = self._play_context.remote_user
self.control_path = None
self.control_path_dir = None
self.control_path: str | None = None
self.control_path_dir: str | None = None
# Windows operates differently from a POSIX connection/shell plugin,
# we need to set various properties to ensure SSH on Windows continues
@ -591,11 +605,17 @@ class Connection(ConnectionBase):
# put_file, and fetch_file methods, so we don't need to do any connection
# management here.
def _connect(self):
def _connect(self) -> Connection:
return self
@staticmethod
def _create_control_path(host, port, user, connection=None, pid=None):
def _create_control_path(
host: str | None,
port: int | None,
user: str | None,
connection: ConnectionBase | None = None,
pid: int | None = None,
) -> str:
'''Make a hash for the controlpath based on con attributes'''
pstring = '%s-%s-%s' % (host, port, user)
if connection:
@ -609,7 +629,7 @@ class Connection(ConnectionBase):
return cpath
@staticmethod
def _sshpass_available():
def _sshpass_available() -> bool:
global SSHPASS_AVAILABLE
# We test once if sshpass is available, and remember the result. It
@ -627,7 +647,7 @@ class Connection(ConnectionBase):
return SSHPASS_AVAILABLE
@staticmethod
def _persistence_controls(b_command):
def _persistence_controls(b_command: list[bytes]) -> tuple[bool, bool]:
'''
Takes a command array and scans it for ControlPersist and ControlPath
settings and returns two booleans indicating whether either was found.
@ -646,7 +666,7 @@ class Connection(ConnectionBase):
return controlpersist, controlpath
def _add_args(self, b_command, b_args, explanation):
def _add_args(self, b_command: list[bytes], b_args: t.Iterable[bytes], explanation: str) -> None:
"""
Adds arguments to the ssh command and displays a caller-supplied explanation of why.
@ -662,7 +682,7 @@ class Connection(ConnectionBase):
display.vvvvv(u'SSH: %s: (%s)' % (explanation, ')('.join(to_text(a) for a in b_args)), host=self.host)
b_command += b_args
def _build_command(self, binary, subsystem, *other_args):
def _build_command(self, binary: str, subsystem: str, *other_args: bytes | str) -> list[bytes]:
'''
Takes a executable (ssh, scp, sftp or wrapper) and optional extra arguments and returns the remote command
wrapped in local ssh shell commands and ready for execution.
@ -719,6 +739,7 @@ class Connection(ConnectionBase):
# be disabled if the client side doesn't support the option. However,
# sftp batch mode does not prompt for passwords so it must be disabled
# if not using controlpersist and using sshpass
b_args: t.Iterable[bytes]
if subsystem == 'sftp' and self.get_option('sftp_batch_mode'):
if conn_password:
b_args = [b'-o', b'BatchMode=no']
@ -818,7 +839,7 @@ class Connection(ConnectionBase):
return b_command
def _send_initial_data(self, fh, in_data, ssh_process):
def _send_initial_data(self, fh: io.IOBase, in_data: bytes, ssh_process: subprocess.Popen) -> None:
'''
Writes initial data to the stdin filehandle of the subprocess and closes
it. (The handle must be closed; otherwise, for example, "sftp -b -" will
@ -845,7 +866,7 @@ class Connection(ConnectionBase):
# Used by _run() to kill processes on failures
@staticmethod
def _terminate_process(p):
def _terminate_process(p: subprocess.Popen) -> None:
""" Terminate a process, ignoring errors """
try:
p.terminate()
@ -854,7 +875,7 @@ class Connection(ConnectionBase):
# This is separate from _run() because we need to do the same thing for stdout
# and stderr.
def _examine_output(self, source, state, b_chunk, sudoable):
def _examine_output(self, source: str, state: str, b_chunk: bytes, sudoable: bool) -> tuple[bytes, bytes]:
'''
Takes a string, extracts complete lines from it, tests to see if they
are a prompt, error message, etc., and sets appropriate flags in self.
@ -903,7 +924,7 @@ class Connection(ConnectionBase):
return b''.join(output), remainder
def _bare_run(self, cmd, in_data, sudoable=True, checkrc=True):
def _bare_run(self, cmd: list[bytes], in_data: bytes | None, sudoable: bool = True, checkrc: bool = True) -> tuple[int, bytes, bytes]:
'''
Starts the command and communicates with it until it ends.
'''
@ -949,7 +970,7 @@ class Connection(ConnectionBase):
else:
p = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
stdin = p.stdin
stdin = p.stdin # type: ignore[assignment] # stdin will be set and not None due to the calls above
except (OSError, IOError) as e:
raise AnsibleError('Unable to execute ssh command line on a controller due to: %s' % to_native(e))
@ -1199,13 +1220,13 @@ class Connection(ConnectionBase):
return (p.returncode, b_stdout, b_stderr)
@_ssh_retry
def _run(self, cmd, in_data, sudoable=True, checkrc=True):
def _run(self, cmd: list[bytes], in_data: bytes | None, sudoable: bool = True, checkrc: bool = True) -> tuple[int, bytes, bytes]:
"""Wrapper around _bare_run that retries the connection
"""
return self._bare_run(cmd, in_data, sudoable=sudoable, checkrc=checkrc)
@_ssh_retry
def _file_transport_command(self, in_path, out_path, sftp_action):
def _file_transport_command(self, in_path: str, out_path: str, sftp_action: str) -> tuple[int, bytes, bytes]:
# scp and sftp require square brackets for IPv6 addresses, but
# accept them for hostnames and IPv4 addresses too.
host = '[%s]' % self.host
@ -1293,7 +1314,7 @@ class Connection(ConnectionBase):
raise AnsibleError("failed to transfer file to %s %s:\n%s\n%s" %
(to_native(in_path), to_native(out_path), to_native(stdout), to_native(stderr)))
def _escape_win_path(self, path):
def _escape_win_path(self, path: str) -> str:
""" converts a Windows path to one that's supported by SFTP and SCP """
# If using a root path then we need to start with /
prefix = ""
@ -1306,7 +1327,7 @@ class Connection(ConnectionBase):
#
# Main public methods
#
def exec_command(self, cmd, in_data=None, sudoable=True):
def exec_command(self, cmd: str, in_data: bytes | None = None, sudoable: bool = True) -> tuple[int, bytes, bytes]:
''' run a command on the remote host '''
super(Connection, self).exec_command(cmd, in_data=in_data, sudoable=sudoable)
@ -1323,8 +1344,10 @@ class Connection(ConnectionBase):
# Make sure our first command is to set the console encoding to
# utf-8, this must be done via chcp to get utf-8 (65001)
cmd_parts = ["chcp.com", "65001", self._shell._SHELL_REDIRECT_ALLNULL, self._shell._SHELL_AND]
cmd_parts.extend(self._shell._encode_script(cmd, as_list=True, strict_mode=False, preserve_rc=False))
# union-attr ignores rely on internal powershell shell plugin details,
# this should be fixed at a future point in time.
cmd_parts = ["chcp.com", "65001", self._shell._SHELL_REDIRECT_ALLNULL, self._shell._SHELL_AND] # type: ignore[union-attr]
cmd_parts.extend(self._shell._encode_script(cmd, as_list=True, strict_mode=False, preserve_rc=False)) # type: ignore[union-attr]
cmd = ' '.join(cmd_parts)
# we can only use tty when we are not pipelining the modules. piping
@ -1338,6 +1361,7 @@ class Connection(ConnectionBase):
# to disable it as a troubleshooting method.
use_tty = self.get_option('use_tty')
args: tuple[str, ...]
if not in_data and sudoable and use_tty:
args = ('-tt', self.host, cmd)
else:
@ -1352,7 +1376,7 @@ class Connection(ConnectionBase):
return (returncode, stdout, stderr)
def put_file(self, in_path, out_path):
def put_file(self, in_path: str, out_path: str) -> tuple[int, bytes, bytes]: # type: ignore[override] # Used by tests and would break API
''' transfer a file from local to remote '''
super(Connection, self).put_file(in_path, out_path)
@ -1368,7 +1392,7 @@ class Connection(ConnectionBase):
return self._file_transport_command(in_path, out_path, 'put')
def fetch_file(self, in_path, out_path):
def fetch_file(self, in_path: str, out_path: str) -> tuple[int, bytes, bytes]: # type: ignore[override] # Used by tests and would break API
''' fetch a file from remote to local '''
super(Connection, self).fetch_file(in_path, out_path)
@ -1383,7 +1407,7 @@ class Connection(ConnectionBase):
return self._file_transport_command(in_path, out_path, 'get')
def reset(self):
def reset(self) -> None:
run_reset = False
self.host = self.get_option('host') or self._play_context.remote_addr
@ -1412,5 +1436,5 @@ class Connection(ConnectionBase):
self.close()
def close(self):
def close(self) -> None:
self._connected = False

@ -2,7 +2,7 @@
# Copyright (c) 2017 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import (absolute_import, division, print_function)
from __future__ import (annotations, absolute_import, division, print_function)
__metaclass__ = type
DOCUMENTATION = """
@ -170,6 +170,7 @@ import json
import tempfile
import shlex
import subprocess
import typing as t
from inspect import getfullargspec
from urllib.parse import urlunsplit
@ -190,6 +191,7 @@ from ansible.module_utils.common.text.converters import to_bytes, to_native, to_
from ansible.module_utils.six import binary_type
from ansible.plugins.connection import ConnectionBase
from ansible.plugins.shell.powershell import _parse_clixml
from ansible.plugins.shell.powershell import ShellBase as PowerShellBase
from ansible.utils.hashing import secure_hash
from ansible.utils.display import Display
@ -245,14 +247,15 @@ class Connection(ConnectionBase):
has_pipelining = True
allow_extras = True
def __init__(self, *args, **kwargs):
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
self.always_pipeline_modules = True
self.has_native_async = True
self.protocol = None
self.shell_id = None
self.protocol: winrm.Protocol | None = None
self.shell_id: str | None = None
self.delegate = None
self._shell: PowerShellBase
self._shell_type = 'powershell'
super(Connection, self).__init__(*args, **kwargs)
@ -262,7 +265,7 @@ class Connection(ConnectionBase):
logging.getLogger('requests_kerberos').setLevel(logging.INFO)
logging.getLogger('urllib3').setLevel(logging.INFO)
def _build_winrm_kwargs(self):
def _build_winrm_kwargs(self) -> None:
# this used to be in set_options, as win_reboot needs to be able to
# override the conn timeout, we need to be able to build the args
# after setting individual options. This is called by _connect before
@ -336,7 +339,7 @@ class Connection(ConnectionBase):
# Until pykerberos has enough goodies to implement a rudimentary kinit/klist, simplest way is to let each connection
# auth itself with a private CCACHE.
def _kerb_auth(self, principal, password):
def _kerb_auth(self, principal: str, password: str) -> None:
if password is None:
password = ""
@ -401,8 +404,8 @@ class Connection(ConnectionBase):
rc = child.exitstatus
else:
proc_mechanism = "subprocess"
password = to_bytes(password, encoding='utf-8',
errors='surrogate_or_strict')
b_password = to_bytes(password, encoding='utf-8',
errors='surrogate_or_strict')
display.vvvv("calling kinit with subprocess for principal %s"
% principal)
@ -417,7 +420,7 @@ class Connection(ConnectionBase):
"'%s': %s" % (self._kinit_cmd, to_native(err))
raise AnsibleConnectionFailure(err_msg)
stdout, stderr = p.communicate(password + b'\n')
stdout, stderr = p.communicate(b_password + b'\n')
rc = p.returncode != 0
if rc != 0:
@ -432,7 +435,7 @@ class Connection(ConnectionBase):
display.vvvvv("kinit succeeded for principal %s" % principal)
def _winrm_connect(self):
def _winrm_connect(self) -> winrm.Protocol:
'''
Establish a WinRM connection over HTTP/HTTPS.
'''
@ -491,7 +494,7 @@ class Connection(ConnectionBase):
else:
raise AnsibleError('No transport found for WinRM connection')
def _winrm_send_input(self, protocol, shell_id, command_id, stdin, eof=False):
def _winrm_send_input(self, protocol: winrm.Protocol, shell_id: str, command_id: str, stdin: bytes, eof: bool = False) -> None:
rq = {'env:Envelope': protocol._get_soap_header(
resource_uri='http://schemas.microsoft.com/wbem/wsman/1/windows/shell/cmd',
action='http://schemas.microsoft.com/wbem/wsman/1/windows/shell/Send',
@ -505,7 +508,13 @@ class Connection(ConnectionBase):
stream['@End'] = 'true'
protocol.send_message(xmltodict.unparse(rq))
def _winrm_exec(self, command, args=(), from_exec=False, stdin_iterator=None):
def _winrm_exec(
self,
command: str,
args: t.Iterable[bytes] = (),
from_exec: bool = False,
stdin_iterator: t.Iterable[tuple[bytes, bool]] = None,
) -> winrm.Response:
if not self.protocol:
self.protocol = self._winrm_connect()
self._connected = True
@ -567,7 +576,7 @@ class Connection(ConnectionBase):
if command_id:
self.protocol.cleanup_command(self.shell_id, command_id)
def _connect(self):
def _connect(self) -> Connection:
if not HAS_WINRM:
raise AnsibleError("winrm or requests is not installed: %s" % to_native(WINRM_IMPORT_ERR))
@ -581,20 +590,20 @@ class Connection(ConnectionBase):
self._connected = True
return self
def reset(self):
def reset(self) -> None:
if not self._connected:
return
self.protocol = None
self.shell_id = None
self._connect()
def _wrapper_payload_stream(self, payload, buffer_size=200000):
def _wrapper_payload_stream(self, payload: bytes, buffer_size: int = 200000) -> t.Iterable[tuple[bytes, bool]]:
payload_bytes = to_bytes(payload)
byte_count = len(payload_bytes)
for i in range(0, byte_count, buffer_size):
yield payload_bytes[i:i + buffer_size], i + buffer_size >= byte_count
def exec_command(self, cmd, in_data=None, sudoable=True):
def exec_command(self, cmd: str, in_data: bytes | None = None, sudoable: bool = True) -> tuple[int, bytes, bytes]:
super(Connection, self).exec_command(cmd, in_data=in_data, sudoable=sudoable)
cmd_parts = self._shell._encode_script(cmd, as_list=True, strict_mode=False, preserve_rc=False)
@ -622,7 +631,7 @@ class Connection(ConnectionBase):
return (result.status_code, result.std_out, result.std_err)
# FUTURE: determine buffer size at runtime via remote winrm config?
def _put_file_stdin_iterator(self, in_path, out_path, buffer_size=250000):
def _put_file_stdin_iterator(self, in_path: str, out_path: str, buffer_size: int = 250000) -> t.Iterable[tuple[bytes, bool]]:
in_size = os.path.getsize(to_bytes(in_path, errors='surrogate_or_strict'))
offset = 0
with open(to_bytes(in_path, errors='surrogate_or_strict'), 'rb') as in_file:
@ -635,9 +644,9 @@ class Connection(ConnectionBase):
yield b64_data, (in_file.tell() == in_size)
if offset == 0: # empty file, return an empty buffer + eof to close it
yield "", True
yield b"", True
def put_file(self, in_path, out_path):
def put_file(self, in_path: str, out_path: str) -> None:
super(Connection, self).put_file(in_path, out_path)
out_path = self._shell._unquote(out_path)
display.vvv('PUT "%s" TO "%s"' % (in_path, out_path), host=self._winrm_host)
@ -700,7 +709,7 @@ class Connection(ConnectionBase):
if not remote_sha1 == local_sha1:
raise AnsibleError("Remote sha1 hash {0} does not match local hash {1}".format(to_native(remote_sha1), to_native(local_sha1)))
def fetch_file(self, in_path, out_path):
def fetch_file(self, in_path: str, out_path: str) -> None:
super(Connection, self).fetch_file(in_path, out_path)
in_path = self._shell._unquote(in_path)
out_path = out_path.replace('\\', '/')
@ -767,7 +776,7 @@ class Connection(ConnectionBase):
if out_file:
out_file.close()
def close(self):
def close(self) -> None:
if self.protocol and self.shell_id:
display.vvvvv('WINRM CLOSE SHELL: %s' % self.shell_id, host=self._winrm_host)
self.protocol.close_shell(self.shell_id)

Loading…
Cancel
Save