Misc ssh agent fixes (#85238)

* Misc ssh-agent fixes

* Replace manual SIGALRM handling with new alarm_timeout context manager
* Misc error handling fixes to ssh-agent startup
* Add SSH_AGENT_EXECUTABLE config to ease failure mode testing
* 100% test coverage on agent startup failure code

Co-authored-by: Matt Clay <matt@mystile.com>

* make SSH Agent support internal

---------

Co-authored-by: Matt Clay <matt@mystile.com>
(cherry picked from commit 2a24633964)
pull/85255/head
Matt Davis 6 months ago committed by Matt Davis
parent 0576ff3e65
commit d63f9aa38d

@ -0,0 +1,4 @@
bugfixes:
- ssh agent - Fixed several potential startup hangs for badly-behaved or overloaded ssh agents.
minor_changes:
- ssh agent - Added ``SSH_AGENT_EXECUTABLE`` config to allow override of ssh-agent.

@ -0,0 +1,91 @@
from __future__ import annotations
import atexit
import os
import subprocess
from ansible import constants as C
from ansible._internal._errors import _alarm_timeout
from ansible._internal._ssh._ssh_agent import SshAgentClient
from ansible.cli import display
from ansible.errors import AnsibleError
from ansible.module_utils.common.process import get_bin_path
_SSH_AGENT_STDOUT_READ_TIMEOUT = 5 # seconds
def launch_ssh_agent() -> None:
"""If configured via `SSH_AGENT`, launch an ssh-agent for Ansible's use and/or verify access to an existing one."""
try:
_launch_ssh_agent()
except Exception as ex:
raise AnsibleError("Failed to launch ssh agent.") from ex
def _launch_ssh_agent() -> None:
ssh_agent_cfg = C.config.get_config_value('SSH_AGENT')
match ssh_agent_cfg:
case 'none':
display.debug('SSH_AGENT set to none')
return
case 'auto':
try:
ssh_agent_bin = get_bin_path(C.config.get_config_value('SSH_AGENT_EXECUTABLE'))
except ValueError as e:
raise AnsibleError('SSH_AGENT set to auto, but cannot find ssh-agent binary.') from e
ssh_agent_dir = os.path.join(C.DEFAULT_LOCAL_TMP, 'ssh_agent')
os.mkdir(ssh_agent_dir, 0o700)
sock = os.path.join(ssh_agent_dir, 'agent.sock')
display.vvv('SSH_AGENT: starting...')
try:
p = subprocess.Popen(
[ssh_agent_bin, '-D', '-s', '-a', sock],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
except OSError as e:
raise AnsibleError('Could not start ssh-agent.') from e
atexit.register(p.terminate)
help_text = f'The ssh-agent {ssh_agent_bin!r} might be an incompatible agent.'
expected_stdout = 'SSH_AUTH_SOCK'
try:
with _alarm_timeout.AnsibleTimeoutError.alarm_timeout(_SSH_AGENT_STDOUT_READ_TIMEOUT):
stdout = p.stdout.read(len(expected_stdout))
except _alarm_timeout.AnsibleTimeoutError as e:
display.error_as_warning(
msg=f'Timed out waiting for expected stdout {expected_stdout!r} from ssh-agent.',
exception=e,
help_text=help_text,
)
else:
if stdout != expected_stdout:
display.warning(
msg=f'The ssh-agent output {stdout!r} did not match expected {expected_stdout!r}.',
help_text=help_text,
)
if p.poll() is not None:
raise AnsibleError(
message='The ssh-agent terminated prematurely.',
help_text=f'{help_text}\n\nReturn Code: {p.returncode}\nStandard Error:\n{p.stderr.read()}',
)
display.vvv(f'SSH_AGENT: ssh-agent[{p.pid}] started and bound to {sock}')
case _:
sock = ssh_agent_cfg
try:
with SshAgentClient(sock) as client:
client.list()
except Exception as e:
raise AnsibleError(f'Could not communicate with ssh-agent using auth sock {sock!r}.') from e
os.environ['SSH_AUTH_SOCK'] = os.environ['ANSIBLE_SSH_AGENT'] = sock

@ -106,21 +106,19 @@ class SshAgentFailure(RuntimeError):
# NOTE: Classes below somewhat represent "Data Type Representations Used in the SSH Protocols" # NOTE: Classes below somewhat represent "Data Type Representations Used in the SSH Protocols"
# as specified by RFC4251 # as specified by RFC4251
@t.runtime_checkable @t.runtime_checkable
class SupportsToBlob(t.Protocol): class SupportsToBlob(t.Protocol):
def to_blob(self) -> bytes: def to_blob(self) -> bytes: ...
...
@t.runtime_checkable @t.runtime_checkable
class SupportsFromBlob(t.Protocol): class SupportsFromBlob(t.Protocol):
@classmethod @classmethod
def from_blob(cls, blob: memoryview | bytes) -> t.Self: def from_blob(cls, blob: memoryview | bytes) -> t.Self: ...
...
@classmethod @classmethod
def consume_from_blob(cls, blob: memoryview | bytes) -> tuple[t.Self, memoryview | bytes]: def consume_from_blob(cls, blob: memoryview | bytes) -> tuple[t.Self, memoryview | bytes]: ...
...
def _split_blob(blob: memoryview | bytes, length: int) -> tuple[memoryview | bytes, memoryview | bytes]: def _split_blob(blob: memoryview | bytes, length: int) -> tuple[memoryview | bytes, memoryview | bytes]:
@ -304,10 +302,12 @@ class PrivateKeyMsg(Msg):
return EcdsaPrivateKeyMsg( return EcdsaPrivateKeyMsg(
getattr(KeyAlgo, f'ECDSA{key_size}'), getattr(KeyAlgo, f'ECDSA{key_size}'),
unicode_string(f'nistp{key_size}'), unicode_string(f'nistp{key_size}'),
binary_string(private_key.public_key().public_bytes( binary_string(
private_key.public_key().public_bytes(
encoding=serialization.Encoding.X962, encoding=serialization.Encoding.X962,
format=serialization.PublicFormat.UncompressedPoint format=serialization.PublicFormat.UncompressedPoint,
)), )
),
mpint(ecdsa_pn.private_value), mpint(ecdsa_pn.private_value),
) )
case Ed25519PrivateKey(): case Ed25519PrivateKey():
@ -318,7 +318,7 @@ class PrivateKeyMsg(Msg):
private_bytes = private_key.private_bytes( private_bytes = private_key.private_bytes(
encoding=serialization.Encoding.Raw, encoding=serialization.Encoding.Raw,
format=serialization.PrivateFormat.Raw, format=serialization.PrivateFormat.Raw,
encryption_algorithm=serialization.NoEncryption() encryption_algorithm=serialization.NoEncryption(),
) )
return Ed25519PrivateKeyMsg( return Ed25519PrivateKeyMsg(
KeyAlgo.ED25519, KeyAlgo.ED25519,
@ -376,14 +376,14 @@ class Ed25519PrivateKeyMsg(PrivateKeyMsg):
@dataclasses.dataclass @dataclasses.dataclass
class PublicKeyMsg(Msg): class PublicKeyMsg(Msg):
@staticmethod @staticmethod
def get_dataclass( def get_dataclass(type: KeyAlgo) -> type[
type: KeyAlgo t.Union[
) -> type[t.Union[
RSAPublicKeyMsg, RSAPublicKeyMsg,
EcdsaPublicKeyMsg, EcdsaPublicKeyMsg,
Ed25519PublicKeyMsg, Ed25519PublicKeyMsg,
DSAPublicKeyMsg DSAPublicKeyMsg,
]]: ]
]:
match type: match type:
case KeyAlgo.RSA: case KeyAlgo.RSA:
return RSAPublicKeyMsg return RSAPublicKeyMsg
@ -401,29 +401,14 @@ class PublicKeyMsg(Msg):
type: KeyAlgo = self.type type: KeyAlgo = self.type
match type: match type:
case KeyAlgo.RSA: case KeyAlgo.RSA:
return RSAPublicNumbers( return RSAPublicNumbers(self.e, self.n).public_key()
self.e,
self.n
).public_key()
case KeyAlgo.ECDSA256 | KeyAlgo.ECDSA384 | KeyAlgo.ECDSA521: case KeyAlgo.ECDSA256 | KeyAlgo.ECDSA384 | KeyAlgo.ECDSA521:
curve = _ECDSA_KEY_TYPE[KeyAlgo(type)] curve = _ECDSA_KEY_TYPE[KeyAlgo(type)]
return EllipticCurvePublicKey.from_encoded_point( return EllipticCurvePublicKey.from_encoded_point(curve(), self.Q)
curve(),
self.Q
)
case KeyAlgo.ED25519: case KeyAlgo.ED25519:
return Ed25519PublicKey.from_public_bytes( return Ed25519PublicKey.from_public_bytes(self.enc_a)
self.enc_a
)
case KeyAlgo.DSA: case KeyAlgo.DSA:
return DSAPublicNumbers( return DSAPublicNumbers(self.y, DSAParameterNumbers(self.p, self.q, self.g)).public_key()
self.y,
DSAParameterNumbers(
self.p,
self.q,
self.g
)
).public_key()
case _: case _:
raise NotImplementedError(type) raise NotImplementedError(type)
@ -437,32 +422,32 @@ class PublicKeyMsg(Msg):
mpint(dsa_pn.parameter_numbers.p), mpint(dsa_pn.parameter_numbers.p),
mpint(dsa_pn.parameter_numbers.q), mpint(dsa_pn.parameter_numbers.q),
mpint(dsa_pn.parameter_numbers.g), mpint(dsa_pn.parameter_numbers.g),
mpint(dsa_pn.y) mpint(dsa_pn.y),
) )
case EllipticCurvePublicKey(): case EllipticCurvePublicKey():
return EcdsaPublicKeyMsg( return EcdsaPublicKeyMsg(
getattr(KeyAlgo, f'ECDSA{public_key.curve.key_size}'), getattr(KeyAlgo, f'ECDSA{public_key.curve.key_size}'),
unicode_string(f'nistp{public_key.curve.key_size}'), unicode_string(f'nistp{public_key.curve.key_size}'),
binary_string(public_key.public_bytes( binary_string(
public_key.public_bytes(
encoding=serialization.Encoding.X962, encoding=serialization.Encoding.X962,
format=serialization.PublicFormat.UncompressedPoint format=serialization.PublicFormat.UncompressedPoint,
)) )
),
) )
case Ed25519PublicKey(): case Ed25519PublicKey():
return Ed25519PublicKeyMsg( return Ed25519PublicKeyMsg(
KeyAlgo.ED25519, KeyAlgo.ED25519,
binary_string(public_key.public_bytes( binary_string(
public_key.public_bytes(
encoding=serialization.Encoding.Raw, encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw, format=serialization.PublicFormat.Raw,
)) )
),
) )
case RSAPublicKey(): case RSAPublicKey():
rsa_pn: RSAPublicNumbers = public_key.public_numbers() rsa_pn: RSAPublicNumbers = public_key.public_numbers()
return RSAPublicKeyMsg( return RSAPublicKeyMsg(KeyAlgo.RSA, mpint(rsa_pn.e), mpint(rsa_pn.n))
KeyAlgo.RSA,
mpint(rsa_pn.e),
mpint(rsa_pn.n)
)
case _: case _:
raise NotImplementedError(public_key) raise NotImplementedError(public_key)
@ -473,10 +458,7 @@ class PublicKeyMsg(Msg):
msg.comments = unicode_string('') msg.comments = unicode_string('')
k = msg.to_blob() k = msg.to_blob()
digest.update(k) digest.update(k)
return binascii.b2a_base64( return binascii.b2a_base64(digest.digest(), newline=False).rstrip(b'=').decode('utf-8')
digest.digest(),
newline=False
).rstrip(b'=').decode('utf-8')
@dataclasses.dataclass(order=True, slots=True) @dataclasses.dataclass(order=True, slots=True)
@ -519,9 +501,7 @@ class KeyList(Msg):
def __post_init__(self) -> None: def __post_init__(self) -> None:
if self.nkeys != len(self.keys): if self.nkeys != len(self.keys):
raise SshAgentFailure( raise SshAgentFailure("agent: invalid number of keys received for identities list")
"agent: invalid number of keys received for identities list"
)
@dataclasses.dataclass(order=True, slots=True) @dataclasses.dataclass(order=True, slots=True)
@ -535,8 +515,7 @@ class PublicKeyMsgList(Msg):
return len(self.keys) return len(self.keys)
@classmethod @classmethod
def from_blob(cls, blob: memoryview | bytes) -> t.Self: def from_blob(cls, blob: memoryview | bytes) -> t.Self: ...
...
@classmethod @classmethod
def consume_from_blob(cls, blob: memoryview | bytes) -> tuple[t.Self, memoryview | bytes]: def consume_from_blob(cls, blob: memoryview | bytes) -> tuple[t.Self, memoryview | bytes]:
@ -546,22 +525,16 @@ class PublicKeyMsgList(Msg):
key_blob, key_blob_length, comment_blob = cls._consume_field(blob) key_blob, key_blob_length, comment_blob = cls._consume_field(blob)
peek_key_algo, _length, _blob = cls._consume_field(key_blob) peek_key_algo, _length, _blob = cls._consume_field(key_blob)
pub_key_msg_cls = PublicKeyMsg.get_dataclass( pub_key_msg_cls = PublicKeyMsg.get_dataclass(KeyAlgo(bytes(peek_key_algo).decode('utf-8')))
KeyAlgo(bytes(peek_key_algo).decode('utf-8'))
)
_fv, comment_blob_length, blob = cls._consume_field(comment_blob) _fv, comment_blob_length, blob = cls._consume_field(comment_blob)
key_plus_comment = ( key_plus_comment = prev_blob[4 : (4 + key_blob_length) + (4 + comment_blob_length)]
prev_blob[4: (4 + key_blob_length) + (4 + comment_blob_length)]
)
args.append(pub_key_msg_cls.from_blob(key_plus_comment)) args.append(pub_key_msg_cls.from_blob(key_plus_comment))
return cls(args), b"" return cls(args), b""
@staticmethod @staticmethod
def _consume_field( def _consume_field(blob: memoryview | bytes) -> tuple[memoryview | bytes, uint32, memoryview | bytes]:
blob: memoryview | bytes
) -> tuple[memoryview | bytes, uint32, memoryview | bytes]:
length = uint32.from_blob(blob[:4]) length = uint32.from_blob(blob[:4])
blob = blob[4:] blob = blob[4:]
data, rest = _split_blob(blob, length) data, rest = _split_blob(blob, length)
@ -584,7 +557,7 @@ class SshAgentClient:
self, self,
exc_type: type[BaseException] | None, exc_type: type[BaseException] | None,
exc_value: BaseException | None, exc_value: BaseException | None,
traceback: types.TracebackType | None traceback: types.TracebackType | None,
) -> None: ) -> None:
self.close() self.close()
@ -598,16 +571,11 @@ class SshAgentClient:
return resp return resp
def remove_all(self) -> None: def remove_all(self) -> None:
self.send( self.send(ProtocolMsgNumbers.SSH_AGENTC_REMOVE_ALL_IDENTITIES.to_blob())
ProtocolMsgNumbers.SSH_AGENTC_REMOVE_ALL_IDENTITIES.to_blob()
)
def remove(self, public_key: CryptoPublicKey) -> None: def remove(self, public_key: CryptoPublicKey) -> None:
key_blob = PublicKeyMsg.from_public_key(public_key).to_blob() key_blob = PublicKeyMsg.from_public_key(public_key).to_blob()
self.send( self.send(ProtocolMsgNumbers.SSH_AGENTC_REMOVE_IDENTITY.to_blob() + uint32(len(key_blob)).to_blob() + key_blob)
ProtocolMsgNumbers.SSH_AGENTC_REMOVE_IDENTITY.to_blob() +
uint32(len(key_blob)).to_blob() + key_blob
)
def add( def add(
self, self,
@ -619,13 +587,9 @@ class SshAgentClient:
key_msg = PrivateKeyMsg.from_private_key(private_key) key_msg = PrivateKeyMsg.from_private_key(private_key)
key_msg.comments = unicode_string(comments or '') key_msg.comments = unicode_string(comments or '')
if lifetime: if lifetime:
key_msg.constraints += constraints( key_msg.constraints += constraints([ProtocolMsgNumbers.SSH_AGENT_CONSTRAIN_LIFETIME]).to_blob() + uint32(lifetime).to_blob()
[ProtocolMsgNumbers.SSH_AGENT_CONSTRAIN_LIFETIME]
).to_blob() + uint32(lifetime).to_blob()
if confirm: if confirm:
key_msg.constraints += constraints( key_msg.constraints += constraints([ProtocolMsgNumbers.SSH_AGENT_CONSTRAIN_CONFIRM]).to_blob()
[ProtocolMsgNumbers.SSH_AGENT_CONSTRAIN_CONFIRM]
).to_blob()
if key_msg.constraints: if key_msg.constraints:
msg = ProtocolMsgNumbers.SSH_AGENTC_ADD_ID_CONSTRAINED.to_blob() msg = ProtocolMsgNumbers.SSH_AGENTC_ADD_ID_CONSTRAINED.to_blob()
@ -638,9 +602,7 @@ class SshAgentClient:
req = ProtocolMsgNumbers.SSH_AGENTC_REQUEST_IDENTITIES.to_blob() req = ProtocolMsgNumbers.SSH_AGENTC_REQUEST_IDENTITIES.to_blob()
r = memoryview(bytearray(self.send(req))) r = memoryview(bytearray(self.send(req)))
if r[0] != ProtocolMsgNumbers.SSH_AGENT_IDENTITIES_ANSWER: if r[0] != ProtocolMsgNumbers.SSH_AGENT_IDENTITIES_ANSWER:
raise SshAgentFailure( raise SshAgentFailure('agent: non-identities answer received for identities list')
'agent: non-identities answer received for identities list'
)
return KeyList.from_blob(r[1:]) return KeyList.from_blob(r[1:])
def __contains__(self, public_key: CryptoPublicKey) -> bool: def __contains__(self, public_key: CryptoPublicKey) -> bool:
@ -649,7 +611,7 @@ class SshAgentClient:
@functools.cache @functools.cache
def _key_data_into_crypto_objects(key_data: bytes, passphrase: bytes | None) -> tuple[CryptoPrivateKey, CryptoPublicKey, str]: def key_data_into_crypto_objects(key_data: bytes, passphrase: bytes | None) -> tuple[CryptoPrivateKey, CryptoPublicKey, str]:
private_key = serialization.ssh.load_ssh_private_key(key_data, passphrase) private_key = serialization.ssh.load_ssh_private_key(key_data, passphrase)
public_key = private_key.public_key() public_key = private_key.public_key()
fingerprint = PublicKeyMsg.from_public_key(public_key).fingerprint fingerprint = PublicKeyMsg.from_public_key(public_key).fingerprint

@ -7,7 +7,6 @@ from __future__ import annotations
import locale import locale
import os import os
import signal
import sys import sys
# We overload the ``ansible`` adhoc command to provide the functionality for # We overload the ``ansible`` adhoc command to provide the functionality for
@ -75,8 +74,6 @@ def initialize_locale():
initialize_locale() initialize_locale()
import atexit
import errno import errno
import getpass import getpass
import subprocess import subprocess
@ -112,17 +109,17 @@ from ansible.module_utils.six import string_types
from ansible.module_utils.common.text.converters import to_bytes, to_text from ansible.module_utils.common.text.converters import to_bytes, to_text
from ansible.module_utils.common.collections import is_sequence from ansible.module_utils.common.collections import is_sequence
from ansible.module_utils.common.file import is_executable from ansible.module_utils.common.file import is_executable
from ansible.module_utils.common.process import get_bin_path
from ansible.parsing.dataloader import DataLoader from ansible.parsing.dataloader import DataLoader
from ansible.parsing.vault import PromptVaultSecret, get_file_vault_secret, VaultSecretsContext from ansible.parsing.vault import PromptVaultSecret, get_file_vault_secret, VaultSecretsContext
from ansible.plugins.loader import add_all_plugin_dirs, init_plugin_loader from ansible.plugins.loader import add_all_plugin_dirs, init_plugin_loader
from ansible.release import __version__ from ansible.release import __version__
from ansible.utils._ssh_agent import SshAgentClient
from ansible.utils.collection_loader import AnsibleCollectionConfig from ansible.utils.collection_loader import AnsibleCollectionConfig
from ansible.utils.collection_loader._collection_finder import _get_collection_name_from_path from ansible.utils.collection_loader._collection_finder import _get_collection_name_from_path
from ansible.utils.path import unfrackpath from ansible.utils.path import unfrackpath
from ansible.vars.manager import VariableManager from ansible.vars.manager import VariableManager
from ansible.module_utils._internal import _deprecator from ansible.module_utils._internal import _deprecator
from ansible._internal._ssh import _agent_launch
try: try:
import argcomplete import argcomplete
@ -131,77 +128,6 @@ except ImportError:
HAS_ARGCOMPLETE = False HAS_ARGCOMPLETE = False
_SSH_AGENT_STDOUT_READ_TIMEOUT = 5 # seconds
def _ssh_agent_timeout_handler(signum, frame):
raise TimeoutError
def _launch_ssh_agent() -> None:
ssh_agent_cfg = C.config.get_config_value('SSH_AGENT')
match ssh_agent_cfg:
case 'none':
display.debug('SSH_AGENT set to none')
return
case 'auto':
try:
ssh_agent_bin = get_bin_path('ssh-agent')
except ValueError as e:
raise AnsibleError('SSH_AGENT set to auto, but cannot find ssh-agent binary') from e
ssh_agent_dir = os.path.join(C.DEFAULT_LOCAL_TMP, 'ssh_agent')
os.mkdir(ssh_agent_dir, 0o700)
sock = os.path.join(ssh_agent_dir, 'agent.sock')
display.vvv('SSH_AGENT: starting...')
try:
p = subprocess.Popen(
[ssh_agent_bin, '-D', '-s', '-a', sock],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
except OSError as e:
raise AnsibleError(
f'Could not start ssh-agent: {e}'
) from e
if p.poll() is not None:
raise AnsibleError(
f'Could not start ssh-agent: rc={p.returncode} stderr="{p.stderr.read().decode()}"'
)
old_sigalrm_handler = signal.signal(signal.SIGALRM, _ssh_agent_timeout_handler)
signal.alarm(_SSH_AGENT_STDOUT_READ_TIMEOUT)
try:
stdout = p.stdout.read(13)
except TimeoutError:
stdout = b''
finally:
signal.alarm(0)
signal.signal(signal.SIGALRM, old_sigalrm_handler)
if stdout != b'SSH_AUTH_SOCK':
display.warning(
f'The first 13 characters of stdout did not match the '
f'expected SSH_AUTH_SOCK. This may not be the right binary, '
f'or an incompatible agent: {stdout.decode()}'
)
display.vvv(f'SSH_AGENT: ssh-agent[{p.pid}] started and bound to {sock}')
atexit.register(p.terminate)
case _:
sock = ssh_agent_cfg
try:
with SshAgentClient(sock) as client:
client.list()
except Exception as e:
raise AnsibleError(
f'Could not communicate with ssh-agent using auth sock {sock}: {e}'
) from e
os.environ['SSH_AUTH_SOCK'] = os.environ['ANSIBLE_SSH_AGENT'] = sock
class CLI(ABC): class CLI(ABC):
""" code behind bin/ansible* programs """ """ code behind bin/ansible* programs """
@ -636,10 +562,7 @@ class CLI(ABC):
loader.set_vault_secrets(vault_secrets) loader.set_vault_secrets(vault_secrets)
if self.USES_CONNECTION: if self.USES_CONNECTION:
try: _agent_launch.launch_ssh_agent()
_launch_ssh_agent()
except Exception as e:
raise AnsibleError('Failed to launch ssh agent.') from e
# create the inventory, and filter it based on the subset specified (if any) # create the inventory, and filter it based on the subset specified (if any)
inventory = InventoryManager(loader=loader, sources=options['inventory'], cache=(not options.get('flush_cache'))) inventory = InventoryManager(loader=loader, sources=options['inventory'], cache=(not options.get('flush_cache')))

@ -1962,6 +1962,14 @@ SSH_AGENT:
env: [{name: ANSIBLE_SSH_AGENT}] env: [{name: ANSIBLE_SSH_AGENT}]
ini: [{key: ssh_agent, section: connection}] ini: [{key: ssh_agent, section: connection}]
version_added: '2.19' version_added: '2.19'
SSH_AGENT_EXECUTABLE:
name: Executable to start for the ansible-managed SSH agent
description: When ``SSH_AGENT`` is ``auto``, the path or name of the ssh agent executable to start.
default: ssh-agent
type: str
env: [ { name: ANSIBLE_SSH_AGENT_EXECUTABLE } ]
ini: [ { key: ssh_agent_executable, section: connection } ]
version_added: '2.19'
SSH_AGENT_KEY_LIFETIME: SSH_AGENT_KEY_LIFETIME:
name: Set a maximum lifetime when adding identities to an agent name: Set a maximum lifetime when adding identities to an agent
description: For keys inserted into an agent defined by ``SSH_AGENT``, define a lifetime, in seconds, that the key may remain description: For keys inserted into an agent defined by ``SSH_AGENT``, define a lifetime, in seconds, that the key may remain

@ -447,7 +447,7 @@ from ansible.plugins.connection import ConnectionBase, BUFSIZE
from ansible.plugins.shell.powershell import _replace_stderr_clixml from ansible.plugins.shell.powershell import _replace_stderr_clixml
from ansible.utils.display import Display from ansible.utils.display import Display
from ansible.utils.path import unfrackpath, makedirs_safe from ansible.utils.path import unfrackpath, makedirs_safe
from ansible.utils._ssh_agent import SshAgentClient, _key_data_into_crypto_objects from ansible._internal._ssh import _ssh_agent
try: try:
from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives import serialization
@ -766,12 +766,12 @@ class Connection(ConnectionBase):
key_data = self.get_option('private_key') key_data = self.get_option('private_key')
passphrase = self.get_option('private_key_passphrase') passphrase = self.get_option('private_key_passphrase')
private_key, public_key, fingerprint = _key_data_into_crypto_objects( private_key, public_key, fingerprint = _ssh_agent.key_data_into_crypto_objects(
to_bytes(key_data), to_bytes(key_data),
to_bytes(passphrase) if passphrase else None, to_bytes(passphrase) if passphrase else None,
) )
with SshAgentClient(auth_sock) as client: with _ssh_agent.SshAgentClient(auth_sock) as client:
if public_key not in client: if public_key not in client:
display.vvv(f'SSH: SSH_AGENT adding {fingerprint} to agent', host=self.host) display.vvv(f'SSH: SSH_AGENT adding {fingerprint} to agent', host=self.host)
client.add( client.add(

@ -3,7 +3,7 @@ from __future__ import annotations
import os import os
from ansible.plugins.action import ActionBase from ansible.plugins.action import ActionBase
from ansible.utils._ssh_agent import SshAgentClient from ansible._internal._ssh._ssh_agent import SshAgentClient
from cryptography.hazmat.primitives.serialization import ssh from cryptography.hazmat.primitives.serialization import ssh

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from ansible.plugins.action import ActionBase from ansible.plugins.action import ActionBase
from ansible.utils._ssh_agent import PublicKeyMsg from ansible._internal._ssh._ssh_agent import PublicKeyMsg
from ansible.module_utils.common.text.converters import to_bytes, to_text from ansible.module_utils.common.text.converters import to_bytes, to_text

@ -0,0 +1,8 @@
#!/usr/bin/env bash
# write > 13 chars to satisfy the check
echo bogusbogusbogus
# wait long enough for the parent process to fail accessing the socket file we didn't create
# this ensures consistent failure on fast/slow test hosts
sleep 3

@ -0,0 +1,6 @@
#!/usr/bin/env bash
echo 'bogus stderr output' >&2
echo 'SSH_AUTH_S'
exit 42

@ -1,23 +1,17 @@
- delegate_to: localhost - delegate_to: localhost
block: block:
# bcrypt is required for the ssh_keygen action
- name: install bcrypt - name: install bcrypt
pip: pip:
name: bcrypt name: bcrypt
register: bcrypt register: bcrypt
- tempfile:
path: "{{ lookup('env', 'OUTPUT_DIR') }}"
state: directory
register: tmpdir
- import_tasks: tests.yml - import_tasks: tests.yml
environment:
ANSIBLE_FORCE_COLOR: no
always: always:
- name: uninstall bcrypt - name: uninstall bcrypt
pip: pip:
name: bcrypt name: bcrypt
state: absent state: absent
when: bcrypt is changed when: bcrypt is changed
- file:
path: tmpdir.path
state: absent

@ -29,14 +29,14 @@
vars: vars:
pid: '{{ auto.stdout|regex_findall("ssh-agent\[(\d+)\]")|first }}' pid: '{{ auto.stdout|regex_findall("ssh-agent\[(\d+)\]")|first }}'
- command: ssh-agent -D -s -a '{{ tmpdir.path }}/agent.sock' - command: ssh-agent -D -s -a '{{ output_dir }}/agent.sock'
async: 30 async: 30
poll: 0 poll: 0
- command: ansible-playbook -i {{ ansible_inventory_sources|first|quote }} -vvv {{ role_path }}/auto.yml - command: ansible-playbook -i {{ ansible_inventory_sources|first|quote }} -vvv {{ role_path }}/auto.yml
environment: environment:
ANSIBLE_CALLBACK_RESULT_FORMAT: yaml ANSIBLE_CALLBACK_RESULT_FORMAT: yaml
ANSIBLE_SSH_AGENT: '{{ tmpdir.path }}/agent.sock' ANSIBLE_SSH_AGENT: '{{ output_dir }}/agent.sock'
register: existing register: existing
- assert: - assert:
@ -47,3 +47,21 @@
'SSH: SSH_AGENT adding' in existing.stdout 'SSH: SSH_AGENT adding' in existing.stdout
- >- - >-
'exists in agent' in existing.stdout 'exists in agent' in existing.stdout
- name: test various agent failure modes
shell: ansible localhost -m ping
environment:
ANSIBLE_SSH_AGENT: auto
ANSIBLE_SSH_AGENT_EXECUTABLE: "{{ role_path }}/fake_agents/ssh-agent-{{ item }}"
ignore_errors: true
register: failures
loop: [not-found, hangs, incompatible, truncated-early-exit, bad-shebang]
- assert:
that:
- failures.results | select('success') | length == 0
- failures.results[0].stderr is search 'SSH_AGENT set to auto, but cannot find ssh-agent binary'
- failures.results[1].stderr is search 'Timed out waiting for expected stdout .* from ssh-agent'
- failures.results[2].stderr is search 'The ssh-agent output .* did not match expected'
- failures.results[3].stderr is search 'The ssh-agent terminated prematurely'
- failures.results[4].stderr is search 'Could not start ssh-agent'

@ -111,6 +111,7 @@ test/integration/targets/win_script/files/test_script.ps1 pslint:PSAvoidUsingWri
test/integration/targets/win_script/files/test_script_removes_file.ps1 pslint:PSCustomUseLiteralPath test/integration/targets/win_script/files/test_script_removes_file.ps1 pslint:PSCustomUseLiteralPath
test/integration/targets/win_script/files/test_script_with_args.ps1 pslint:PSAvoidUsingWriteHost # Keep test/integration/targets/win_script/files/test_script_with_args.ps1 pslint:PSAvoidUsingWriteHost # Keep
test/integration/targets/win_script/files/test_script_with_splatting.ps1 pslint:PSAvoidUsingWriteHost # Keep test/integration/targets/win_script/files/test_script_with_splatting.ps1 pslint:PSAvoidUsingWriteHost # Keep
test/integration/targets/ssh_agent/fake_agents/ssh-agent-bad-shebang shebang # required for test
test/lib/ansible_test/_data/requirements/sanity.pslint.ps1 pslint:PSCustomUseLiteralPath # Uses wildcards on purpose test/lib/ansible_test/_data/requirements/sanity.pslint.ps1 pslint:PSCustomUseLiteralPath # Uses wildcards on purpose
test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/cliconf/ios.py pylint:arguments-renamed test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/cliconf/ios.py pylint:arguments-renamed
test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/module_utils/WebRequest.psm1 pslint!skip test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/module_utils/WebRequest.psm1 pslint!skip

Loading…
Cancel
Save