Add ssh-agent launching, and ssh-agent python client

pull/82181/head
Matt Martz 1 year ago
parent 3e82ed307b
commit 79ab013f26
No known key found for this signature in database
GPG Key ID: 40832D88E9FC91D8

@ -0,0 +1,5 @@
minor_changes:
- ssh-agent - InventoryManager is capable of spawning or reusing an ssh-agent, allowing plugins to interact with the ssh-agent.
Additionally a pure python ssh-client has been added, enabling easy interaction with the agent. The ssh connection plugin contains
new functionality via ``ansible_ssh_private_key`` and ``ansible_ssh_private_key_passphrase``, for loading an SSH private key into
the agent from a variable.

@ -1897,6 +1897,24 @@ SHOW_CUSTOM_STATS:
ini:
- {key: show_custom_stats, section: defaults}
type: bool
SSH_AGENT:
name: Manage an SSH Agent
description: Manage an SSH Agent via Ansible. A configuration of ``none`` will not interact with an agent,
``auto`` will start and destroy an agent during the run, and a path to an SSH_AUTH_SOCK will
allow interaction with a pre-existing agent.
default: none
type: string
env: [{name: ANSIBLE_SSH_AGENT}]
ini: [{key: ssh_agent, section: connection}]
version_added: '2.17'
SSH_AGENT_KEY_LIFETIME:
name: Set a maximum lifetime when adding identities to an agent
description: For keys inserted into an agent defined by ``SSH_AGENT``, define a lifetime, in seconds, that the key may remain
in the agent.
type: int
env: [{name: ANSIBLE_SSH_AGENT_KEY_LIFETIME}]
ini: [{key: ssh_agent_key_lifetime, section: connection}]
version_added: '2.17'
STRING_TYPE_FILTERS:
name: Filters to preserve strings
default: [string, to_json, to_nice_json, to_yaml, to_nice_yaml, ppretty, json]

@ -18,11 +18,13 @@
#############################################
from __future__ import annotations
import atexit
import fnmatch
import os
import sys
import re
import itertools
import subprocess
import traceback
from operator import attrgetter
@ -32,12 +34,14 @@ from ansible import constants as C
from ansible.errors import AnsibleError, AnsibleOptionsError, AnsibleParserError
from ansible.inventory.data import InventoryData
from ansible.module_utils.six import string_types
from ansible.module_utils.common.process import get_bin_path
from ansible.module_utils.common.text.converters import to_bytes, to_text
from ansible.parsing.utils.addresses import parse_address
from ansible.plugins.loader import inventory_loader
from ansible.utils.helpers import deduplicate_list
from ansible.utils.path import unfrackpath
from ansible.utils.display import Display
from ansible.utils.ssh_agent import Client
from ansible.utils.vars import combine_vars
from ansible.vars.plugins import get_vars_from_inventory_sources
@ -161,6 +165,8 @@ class InventoryManager(object):
else:
self._sources = sources
self._launch_ssh_agent()
# get to work!
if parse:
self.parse_sources(cache=cache)
@ -168,6 +174,50 @@ class InventoryManager(object):
self._cached_dynamic_hosts = []
self._cached_dynamic_grouping = []
def _launch_ssh_agent(self):
ssh_agent_cfg = C.config.get_config_value('SSH_AGENT')
match ssh_agent_cfg:
case 'none':
return
case 'auto':
ssh_agent_dir = os.path.join(C.DEFAULT_LOCAL_TMP, 'ssh_agent')
os.mkdir(ssh_agent_dir, 0o700)
try:
ssh_agent_bin = get_bin_path('ssh-agent', required=True)
except ValueError as e:
raise AnsibleError('SSH_AGENT set to auto, but cannot find ssh-agent binary') from e
sock = os.path.join(ssh_agent_dir, 'agent.sock')
p = subprocess.Popen(
[ssh_agent_bin, '-D', '-s', '-a', sock],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
if p.poll() is not None:
raise AnsibleError(
f'Could not start ssh-agent: (rc={p.returncode}) {p.stderr}'
)
if (stdout := p.stdout.read(13)) != b'SSH_AUTH_SOCK':
display.warn(
f'The first 13 characters of stdout did not match the '
f'expected SSH_AUTH_SOCK. This may not be the right binary, '
f'or an incompatible agent: {stdout.decode()}'
)
display.vvv(f'SSH_AGENT: ssh-agent[{p.pid}] started and bound to {sock}')
atexit.register(p.terminate)
case _:
sock = ssh_agent_cfg
try:
with Client(sock) as client:
client.list()
except Exception as e:
raise AnsibleError(
f'Could not communicate with ssh-agent using auth sock {sock}: {e}'
)
os.environ['SSH_AUTH_SOCK'] = os.environ['ANSIBLE_SSH_AGENT'] = sock
@property
def localhost(self):
return self._inventory.get_host('localhost')

@ -264,7 +264,30 @@ DOCUMENTATION = """
cli:
- name: private_key_file
option: '--private-key'
private_key:
description:
- private key contents in PEM format. Requires the SSH_AGENT configuration to be enabled.
type: string
ini:
- section: defaults
key: private_key
env:
- name: ANSIBLE_PRIVATE_KEY
vars:
- name: ansible_private_key
version_added: '2.17'
private_key_passphrase:
description:
- private key passphrase, dependent on ``private_key``.
type: string
ini:
- section: defaults
key: private_key_passphrase
env:
- name: ANSIBLE_PRIVATE_KEY_PASSPHRASE
vars:
- name: ansible_private_key_passphrase
version_added: '2.17'
control_path:
description:
- This is the location to save SSH's ControlPath sockets, it uses SSH's variable substitution.
@ -380,6 +403,8 @@ import time
import typing as t
from functools import wraps
from ansible import constants as C
from ansible.errors import (
AnsibleAuthenticationFailure,
AnsibleConnectionFailure,
@ -392,6 +417,15 @@ from ansible.plugins.connection import ConnectionBase, BUFSIZE
from ansible.plugins.shell.powershell import _parse_clixml
from ansible.utils.display import Display
from ansible.utils.path import unfrackpath, makedirs_safe
from ansible.utils.ssh_agent import load_private_key, Client, PublicKeyMsg
try:
from cryptography.hazmat.primitives import serialization
except ImportError:
HAS_CRYPTOGRAPHY = False
else:
HAS_CRYPTOGRAPHY = True
display = Display()
@ -583,6 +617,8 @@ class Connection(ConnectionBase):
self.module_implementation_preferences = ('.ps1', '.exe', '')
self.allow_executable = False
self._populated_agent = None
# The connection is created by running ssh/scp/sftp from the exec_command,
# put_file, and fetch_file methods, so we don't need to do any connection
# management here.
@ -664,6 +700,49 @@ class Connection(ConnectionBase):
display.vvvvv(u'SSH: %s: (%s)' % (explanation, ')('.join(to_text(a) for a in b_args)), host=self.host)
b_command += b_args
def _populate_agent(self):
if self._populated_agent:
return self._populated_agent
if (auth_sock := C.config.get_config_value('SSH_AGENT')) == 'none':
raise AnsibleError('Cannot utilize private_key with SSH_AGENT disabled')
key_data = self.get_option('private_key')
passphrase = self.get_option('private_key_passphrase')
private_key = load_private_key(
to_bytes(key_data),
to_bytes(passphrase) if passphrase else None
)
public_key = private_key.public_key()
pubkey_msg = PublicKeyMsg.from_public_key(public_key)
fingerprint = pubkey_msg.fingerprint()
with Client(auth_sock) as client:
if (public_key := private_key.public_key()) not in client:
display.vvv(f'SSH: SSH_AGENT adding {fingerprint} to agent', host=self.host)
client.add(
private_key,
'[added by ansible]',
C.config.get_config_value('SSH_AGENT_KEY_LIFETIME'),
)
else:
display.vvv(f'SSH: SSH_AGENT {fingerprint} exists in agent', host=self.host)
# Write the public key to disk, to be provided as IdentityFile.
# This allows ssh to pick an explicit key in the agent to use,
# preventing ssh from attempting all keys in the agent.
pubkey_path = self._populated_agent = os.path.join(
C.DEFAULT_LOCAL_TMP,
fingerprint.replace('/', '-') + '.pub'
)
if os.path.exists(pubkey_path):
return pubkey_path
fd = os.open(pubkey_path, os.O_CREAT | os.O_WRONLY, mode=0o400)
with os.fdopen(fd, 'wb') as f:
f.write(public_key.public_bytes(
encoding=serialization.Encoding.OpenSSH,
format=serialization.PublicFormat.OpenSSH
))
return self._populated_agent
def _build_command(self, binary: str, subsystem: str, *other_args: bytes | str) -> list[bytes]:
"""
Takes a executable (ssh, scp, sftp or wrapper) and optional extra arguments and returns the remote command
@ -748,8 +827,12 @@ class Connection(ConnectionBase):
b_args = (b"-o", b"Port=" + to_bytes(self.port, nonstring='simplerepr', errors='surrogate_or_strict'))
self._add_args(b_command, b_args, u"ANSIBLE_REMOTE_PORT/remote_port/ansible_port set")
key = self.get_option('private_key_file')
if key:
if self.get_option('private_key'):
key = self._populate_agent()
b_args = (b'-o', b'IdentitiesOnly=yes', b'-o', b'IdentityFile="' + to_bytes(key, errors='surrogate_or_strict') + b'"')
self._add_args(b_command, b_args, u"ANSIBLE_PRIVATE_KEY/private_key set")
elif (key := self.get_option('private_key_file')):
b_args = (b"-o", b'IdentityFile="' + to_bytes(os.path.expanduser(key), errors='surrogate_or_strict') + b'"')
self._add_args(b_command, b_args, u"ANSIBLE_PRIVATE_KEY_FILE/private_key_file/ansible_ssh_private_key_file set")

@ -19,7 +19,7 @@ from __future__ import annotations
import os
import shutil
from errno import EEXIST
from errno import EEXIST, ENOENT
from ansible.errors import AnsibleError
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
@ -127,7 +127,7 @@ def cleanup_tmp_file(path, warn=False):
elif os.path.isfile(path):
os.unlink(path)
except Exception as e:
if warn:
if warn and getattr(e, 'errno', None) != ENOENT:
# Importing here to avoid circular import
from ansible.utils.display import Display
display = Display()

@ -0,0 +1,645 @@
# Copyright: Contributors to the Ansible project
# BSD 3 Clause License (see licenses/BSD-3-Clause.txt or https://opensource.org/license/bsd-3-clause/)
from __future__ import annotations
import binascii
import collections.abc
import copy
import dataclasses
import enum
import hashlib
import socket
import typing as t
try:
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.dsa import (
DSAParameterNumbers,
DSAPrivateKey,
DSAPublicKey,
DSAPublicNumbers,
)
from cryptography.hazmat.primitives.asymmetric.ec import (
EllipticCurve,
EllipticCurvePrivateKey,
EllipticCurvePublicKey,
SECP256R1,
SECP384R1,
SECP521R1,
)
from cryptography.hazmat.primitives.asymmetric.ed25519 import (
Ed25519PrivateKey,
Ed25519PublicKey,
)
from cryptography.hazmat.primitives.asymmetric.rsa import (
RSAPrivateKey,
RSAPublicKey,
RSAPublicNumbers,
)
from cryptography.hazmat.primitives.serialization import ssh
CryptoPublicKey = t.Union[
DSAPublicKey,
EllipticCurvePublicKey,
Ed25519PublicKey,
RSAPublicKey,
]
CryptoPrivateKey = t.Union[
DSAPrivateKey,
EllipticCurvePrivateKey,
Ed25519PrivateKey,
RSAPrivateKey,
]
except ImportError:
HAS_CRYPTOGRAPHY = False
else:
HAS_CRYPTOGRAPHY = True
class SshAgentFailure(Exception):
...
class mpint(int):
...
class byte(int):
...
class constraints(bytes):
...
class Protocol(enum.IntEnum):
# Responses
SSH_AGENT_FAILURE = 5
SSH_AGENT_SUCCESS = 6
SSH_AGENT_EXTENSION_FAILURE = 28
SSH_AGENT_IDENTITIES_ANSWER = 12
SSH_AGENT_SIGN_RESPONSE = 14
# Constraints
SSH_AGENT_CONSTRAIN_LIFETIME = 1
SSH_AGENT_CONSTRAIN_CONFIRM = 2
SSH_AGENT_CONSTRAIN_EXTENSION = 255
# Requests
SSH_AGENTC_REQUEST_IDENTITIES = 11
SSH_AGENTC_SIGN_REQUEST = 13
SSH_AGENTC_ADD_IDENTITY = 17
SSH_AGENTC_REMOVE_IDENTITY = 18
SSH_AGENTC_REMOVE_ALL_IDENTITIES = 19
SSH_AGENTC_ADD_SMARTCARD_KEY = 20
SSH_AGENTC_REMOVE_SMARTCARD_KEY = 21
SSH_AGENTC_LOCK = 22
SSH_AGENTC_UNLOCK = 23
SSH_AGENTC_ADD_ID_CONSTRAINED = 25
SSH_AGENTC_ADD_SMARTCARD_KEY_CONSTRAINED = 26
SSH_AGENTC_EXTENSION = 27
class KeyAlgo(str, enum.Enum):
RSA = "ssh-rsa"
DSA = "ssh-dss"
ECDSA256 = "ecdsa-sha2-nistp256"
SKECDSA256 = "sk-ecdsa-sha2-nistp256@openssh.com"
ECDSA384 = "ecdsa-sha2-nistp384"
ECDSA521 = "ecdsa-sha2-nistp521"
ED25519 = "ssh-ed25519"
SKED25519 = "sk-ssh-ed25519@openssh.com"
RSASHA256 = "rsa-sha2-256"
RSASHA512 = "rsa-sha2-512"
@property
def main_type(self):
match self:
case self.RSA:
return 'RSA'
case self.DSA:
return 'DSA'
case self.ECDSA256 | self.ECDSA384 | self.ECDSA521:
return 'ECDSA'
case self.ED25519:
return 'ED25519'
case _:
raise NotImplementedError(self.name)
if HAS_CRYPTOGRAPHY:
_ECDSA_KEY_TYPE: dict[KeyAlgo, type[EllipticCurve]] = {
KeyAlgo.ECDSA256: SECP256R1,
KeyAlgo.ECDSA384: SECP384R1,
KeyAlgo.ECDSA521: SECP521R1,
}
@dataclasses.dataclass
class Msg:
...
@dataclasses.dataclass(order=True, slots=True)
class AgentLockMsg(Msg):
passphrase: bytes
@dataclasses.dataclass
class PrivateKeyMsg(Msg):
@staticmethod
def from_private_key(private_key):
match private_key:
case RSAPrivateKey():
pn = private_key.private_numbers()
return RSAPrivateKeyMsg(
KeyAlgo.RSA,
pn.public_numbers.n,
pn.public_numbers.e,
pn.d,
pn.iqmp,
pn.p,
pn.q,
)
case DSAPrivateKey():
pn = private_key.private_numbers()
return DSAPrivateKeyMsg(
KeyAlgo.DSA,
pn.public_numbers.parameter_numbers.p,
pn.public_numbers.parameter_numbers.q,
pn.public_numbers.parameter_numbers.g,
pn.public_numbers.y,
pn.x,
)
case EllipticCurvePrivateKey():
pn = private_key.private_numbers()
key_size = private_key.key_size
return EcdsaPrivateKeyMsg(
getattr(KeyAlgo, f'ECDSA{key_size}'),
f'nistp{key_size}',
private_key.public_key().public_bytes(
encoding=serialization.Encoding.X962,
format=serialization.PublicFormat.UncompressedPoint
),
pn.private_value,
)
case Ed25519PrivateKey():
public_bytes = private_key.public_key().public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw,
)
private_bytes = private_key.private_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PrivateFormat.Raw,
encryption_algorithm=serialization.NoEncryption()
)
return Ed25519PrivateKeyMsg(
KeyAlgo.ED25519,
public_bytes,
private_bytes + public_bytes,
)
case _:
raise NotImplementedError(private_key)
@dataclasses.dataclass(order=True, slots=True)
class RSAPrivateKeyMsg(PrivateKeyMsg):
type: KeyAlgo
n: mpint
e: mpint
d: mpint
iqmp: mpint
p: mpint
q: mpint
comments: str = dataclasses.field(default='', compare=False)
constraints: constraints = dataclasses.field(default=constraints(b''))
@dataclasses.dataclass(order=True, slots=True)
class DSAPrivateKeyMsg(PrivateKeyMsg):
type: KeyAlgo
p: mpint
q: mpint
g: mpint
y: mpint
x: mpint
comments: str = dataclasses.field(default='', compare=False)
constraints: constraints = dataclasses.field(default=constraints(b''))
@dataclasses.dataclass(order=True, slots=True)
class EcdsaPrivateKeyMsg(PrivateKeyMsg):
type: KeyAlgo
ecdsa_curve_name: str
Q: bytes
d: mpint
comments: str = dataclasses.field(default='', compare=False)
constraints: constraints = dataclasses.field(default=constraints(b''))
@dataclasses.dataclass(order=True, slots=True)
class Ed25519PrivateKeyMsg(PrivateKeyMsg):
type: KeyAlgo
enc_a: bytes
k_env_a: bytes
comments: str = dataclasses.field(default='', compare=False)
constraints: constraints = dataclasses.field(default=constraints(b''))
@dataclasses.dataclass
class PublicKeyMsg(Msg):
@staticmethod
def get_dataclass(
type: KeyAlgo
) -> type[t.Union[
RSAPublicKeyMsg,
EcdsaPublicKeyMsg,
Ed25519PublicKeyMsg,
DSAPublicKeyMsg
]]:
match type:
case KeyAlgo.RSA:
return RSAPublicKeyMsg
case KeyAlgo.ECDSA256 | KeyAlgo.ECDSA384 | KeyAlgo.ECDSA521:
return EcdsaPublicKeyMsg
case KeyAlgo.ED25519:
return Ed25519PublicKeyMsg
case KeyAlgo.DSA:
return DSAPublicKeyMsg
case _:
raise NotImplementedError(type)
def public_key(self) -> CryptoPublicKey:
type = self.type # type: ignore[attr-defined]
match type:
case KeyAlgo.RSA:
return RSAPublicNumbers(
self.e, # type: ignore[attr-defined]
self.n # type: ignore[attr-defined]
).public_key()
case KeyAlgo.ECDSA256 | KeyAlgo.ECDSA384 | KeyAlgo.ECDSA521:
curve = _ECDSA_KEY_TYPE[KeyAlgo(type)]
return EllipticCurvePublicKey.from_encoded_point(
curve(),
self.Q # type: ignore[attr-defined]
)
case KeyAlgo.ED25519:
return Ed25519PublicKey.from_public_bytes(
self.enc_a # type: ignore[attr-defined]
)
case KeyAlgo.DSA:
return DSAPublicNumbers(
self.y, # type: ignore[attr-defined]
DSAParameterNumbers(
self.p, # type: ignore[attr-defined]
self.q, # type: ignore[attr-defined]
self.g # type: ignore[attr-defined]
)
).public_key()
case _:
raise NotImplementedError(type)
@staticmethod
def from_public_key(public_key):
match public_key:
case DSAPublicKey():
pn = public_key.public_numbers()
return DSAPublicKeyMsg(
KeyAlgo.DSA,
pn.parameter_numbers.p,
pn.parameter_numbers.q,
pn.parameter_numbers.g,
pn.y
)
case EllipticCurvePublicKey():
return EcdsaPublicKeyMsg(
getattr(KeyAlgo, f'ECDSA{public_key.curve.key_size}'),
f'nistp{public_key.curve.key_size}',
public_key.public_bytes(
encoding=serialization.Encoding.X962,
format=serialization.PublicFormat.UncompressedPoint
)
)
case Ed25519PublicKey():
return Ed25519PublicKeyMsg(
KeyAlgo.ED25519,
public_key.public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw,
)
)
case RSAPublicKey():
pn = public_key.public_numbers()
return RSAPublicKeyMsg(
KeyAlgo.RSA,
pn.e,
pn.n
)
case _:
raise NotImplementedError(public_key)
def fingerprint(self):
digest = hashlib.sha256()
msg = copy.copy(self)
msg.comments = ''
k = encode(msg)
digest.update(k)
return binascii.b2a_base64(
digest.digest(),
newline=False
).rstrip(b'=').decode('utf-8')
@dataclasses.dataclass(order=True, slots=True)
class RSAPublicKeyMsg(PublicKeyMsg):
type: KeyAlgo
e: mpint
n: mpint
comments: str = dataclasses.field(default='', compare=False)
@dataclasses.dataclass(order=True, slots=True)
class DSAPublicKeyMsg(PublicKeyMsg):
type: KeyAlgo
p: mpint
q: mpint
g: mpint
y: mpint
comments: str = dataclasses.field(default='', compare=False)
@dataclasses.dataclass(order=True, slots=True)
class EcdsaPublicKeyMsg(PublicKeyMsg):
type: KeyAlgo
ecdsa_curve_name: str
Q: bytes
comments: str = dataclasses.field(default='', compare=False)
@dataclasses.dataclass(order=True, slots=True)
class Ed25519PublicKeyMsg(PublicKeyMsg):
type: KeyAlgo
enc_a: bytes
comments: str = dataclasses.field(default='', compare=False)
@dataclasses.dataclass(order=True, slots=True)
class KeyList(Msg):
nkeys: int
keys: list[PublicKeyMsg]
def __init__(self, nkeys, *args):
self.nkeys = nkeys
self.keys = args
def _to_bytes(val: int, length: int) -> bytes:
return val.to_bytes(length=length, byteorder='big')
def _from_bytes(val: bytes) -> int:
return int.from_bytes(val, byteorder='big')
def _to_mpint(val: int) -> bytes:
if val < 0:
raise ValueError("negative mpint not allowed")
if not val:
return b""
nbytes = (val.bit_length() + 8) // 8
ret = bytearray(_to_bytes(val, nbytes))
ret[:0] = _to_bytes(len(ret), 4)
return ret
def _from_mpint(val: bytes) -> int:
if val and val[0] > 127:
raise ValueError("Invalid data")
return _from_bytes(val)
def encode_dataclass(msg: Msg) -> collections.abc.Iterator[bytes]:
for field in dataclasses.fields(msg):
fv = getattr(msg, field.name)
match field.type:
case 'mpint':
yield _to_mpint(fv)
case 'int': # uint32
yield _to_bytes(fv, 4)
case 'bool' | 'byte' | 'Protocol':
yield _to_bytes(fv, 1)
case 'str' | 'KeyAlgo':
if fv:
fv = fv.encode('utf-8')
yield _to_bytes(len(fv), 4)
yield fv
case 'bytes':
if fv:
yield _to_bytes(len(fv), 4)
yield fv
case 'constraints':
yield fv
case _:
raise NotImplementedError(field.type)
def parse_annotation(type: t.Any) -> tuple[str, str]:
if type.count('[') > 1:
raise NotImplementedError()
main, _sep, sub = type.removesuffix(']').partition('[')
return main, sub
def _consume_field(
blob: memoryview,
type: t.Any | None = None
) -> tuple[memoryview, int, memoryview]:
match type:
case 'int':
length = 4
case 'bool' | 'byte' | 'Protocol':
length = 1
case _:
length = _from_bytes(blob[:4])
blob = blob[4:]
return blob[:length], length, blob[length:]
def decode_dataclass(blob: memoryview, dataclass: type[Msg]) -> Msg:
fi = 0
args: list[t.Any] = []
fields = dataclasses.fields(dataclass)
while blob:
field = fields[fi]
prev_blob = blob
fv, length, blob = _consume_field(blob, type=field.type)
main_type, sub_type = parse_annotation(field.type)
match main_type:
case 'mpint':
args.append(_from_mpint(fv))
case 'int': # uint32
args.append(_from_bytes(fv))
case 'Protocol':
args.append(Protocol(_from_bytes(fv)))
case 'bool' | 'byte':
args.append(_from_bytes(fv))
case 'KeyAlgo':
args.append(KeyAlgo(fv.tobytes().decode('utf-8')))
case 'str':
args.append(fv.tobytes().decode('utf-8'))
case 'bytes':
args.append(bytes(fv))
case 'list':
# Lists should always be last
match sub_type:
case 'PublicKeyMsg':
peek, _length, _blob = _consume_field(fv)
sub = PublicKeyMsg.get_dataclass(
KeyAlgo(peek.tobytes().decode('utf-8'))
)
_fv, cl, blob = _consume_field(blob)
key_plus_comment = (
prev_blob[4:length + cl + 8]
)
case _:
raise NotImplementedError(sub_type)
args.append(decode_dataclass(key_plus_comment, sub))
fi -= 1 # We are in a list, don't move to the next field
case _:
raise NotImplementedError(field.type)
fi += 1
return dataclass(*args)
def encode(msg: Protocol | Msg) -> bytes:
if isinstance(msg, Protocol):
payload = bytes([msg])
else:
payload = b''.join(encode_dataclass(msg))
return payload
class Client:
def __init__(self, auth_sock: str):
self._auth_sock = auth_sock
self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self._sock.connect(auth_sock)
def terminate(self):
self._ssh_agent.terminate()
def close(self):
self._sock.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
def send(self, msg: bytes) -> bytes:
length = _to_bytes(len(msg), 4)
self._sock.sendall(length + msg)
bufsize = _from_bytes(self._sock.recv(4))
resp = self._sock.recv(bufsize)
if resp[0] == Protocol.SSH_AGENT_FAILURE:
raise SshAgentFailure('agent: failure')
return resp
def remove_all(self):
msg = encode(Protocol.SSH_AGENTC_REMOVE_ALL_IDENTITIES)
self.send(msg)
return True
def remove(self, public_key: CryptoPublicKey):
msg = encode(Protocol.SSH_AGENTC_REMOVE_IDENTITY)
key_blob = encode(
PublicKeyMsg.from_public_key(public_key)
)
msg += _to_bytes(len(key_blob), 4)
msg += key_blob
self.send(msg)
return True
def add(
self,
private_key: CryptoPrivateKey,
comments: str | None = None,
lifetime: int | None = None,
confirm: bool | None = None,
):
key_msg = PrivateKeyMsg.from_private_key(private_key)
key_msg.comments = comments or ''
if lifetime:
key_msg.constraints += constraints(
[Protocol.SSH_AGENT_CONSTRAIN_LIFETIME]
) + _to_bytes(lifetime, 4)
if confirm:
key_msg.constraints += constraints(
[Protocol.SSH_AGENT_CONSTRAIN_CONFIRM]
)
if key_msg.constraints:
msg = encode(Protocol.SSH_AGENTC_ADD_ID_CONSTRAINED)
else:
msg = encode(Protocol.SSH_AGENTC_ADD_IDENTITY)
msg += encode(key_msg)
self.send(msg)
return True
def list(self) -> KeyList:
req = encode(Protocol.SSH_AGENTC_REQUEST_IDENTITIES)
r = memoryview(bytearray(self.send(req)))
if r[0] != Protocol.SSH_AGENT_IDENTITIES_ANSWER:
raise SshAgentFailure(
'agent: non-identities answer received for identities list'
)
return t.cast(KeyList, decode_dataclass(r[1:], KeyList))
def lock(self, passphrase: bytes):
msg = encode(Protocol.SSH_AGENTC_LOCK)
msg += encode(AgentLockMsg(passphrase))
self.send(msg)
return True
def unlock(self, passphrase: bytes):
msg = encode(Protocol.SSH_AGENTC_UNLOCK)
msg += encode(AgentLockMsg(passphrase))
self.send(msg)
return True
def __contains__(self, public_key: CryptoPublicKey) -> bool:
msg = PublicKeyMsg.from_public_key(public_key)
for key in self.list().keys:
if key == msg:
return True
return False
def load_private_key(key_data: bytes, passphrase: bytes) -> CryptoPrivateKey:
try:
private_key = ssh.load_ssh_private_key(
key_data,
password=passphrase,
)
except ValueError:
# Old keys generated by ssh-agent may not adhere to the strict
# definition of what ``load_ssh_private_key`` expects, fall
# back to generic PEM private key loading
private_key = serialization.load_pem_private_key(
key_data,
password=passphrase,
) # type: CryptoPrivateKey # type: ignore[no-redef]
allowed_types = t.get_args(CryptoPrivateKey)
if not isinstance(private_key, allowed_types):
type_names = (o.__name__ for o in allowed_types)
raise ValueError(
f'key_data must be one of {", ".join(type_names)} not, '
f'{private_key.__class__.__name__}'
)
return private_key

@ -0,0 +1,28 @@
Copyright (c) Contributors to the Ansible project. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its contributors
may be used to endorse or promote products derived from this software
without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
SUCH DAMAGE.

@ -0,0 +1,52 @@
from __future__ import annotations
import os
from ansible.plugins.action import ActionBase
from ansible.utils.ssh_agent import Client
from ansible.module_utils.common.text.converters import to_bytes
class ActionModule(ActionBase):
def run(self, tmp=None, task_vars=None):
results = super(ActionModule, self).run(tmp, task_vars)
del tmp # tmp no longer has any effect
match self._task.args['action']:
case 'list':
return self.list()
case 'lock':
return self.lock(self._task.args['password'])
case 'unlock':
return self.unlock(self._task.args['password'])
case _:
return {'failed': True, 'msg': 'not implemented'}
def lock(self, password):
with Client(os.environ['SSH_AUTH_SOCK']) as client:
client.lock(to_bytes(password))
return {'changed': True}
def unlock(self, password):
with Client(os.environ['SSH_AUTH_SOCK']) as client:
client.unlock(to_bytes(password))
return {'changed': True}
def list(self):
result = {'keys': [], 'nkeys': 0}
with Client(os.environ['SSH_AUTH_SOCK']) as client:
key_list = client.list()
result['nkeys'] = key_list.nkeys
for key in key_list.keys:
public_key = key.public_key()
key_size = getattr(public_key, 'key_size', 256)
fingerprint = key.fingerprint()
key_type = key.type.main_type
result['keys'].append({
'type': key_type,
'key_size': key_size,
'fingerprint': f'SHA256:{fingerprint}',
'comments': key.comments,
})
return result

@ -0,0 +1,48 @@
from __future__ import annotations
from ansible.plugins.action import ActionBase
from ansible.utils.ssh_agent import PublicKeyMsg
from ansible.module_utils.common.text.converters import to_bytes, to_text
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.rsa import generate_private_key
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
class ActionModule(ActionBase):
def run(self, tmp=None, task_vars=None):
results = super(ActionModule, self).run(tmp, task_vars)
del tmp # tmp no longer has any effect
match self._task.args.get('type'):
case 'ed25519':
private_key = Ed25519PrivateKey.generate()
case 'rsa':
private_key = generate_private_key(65537, 4096)
case _:
return {'failed': True, 'msg': 'not implemented'}
public_key = private_key.public_key()
public_key_msg = PublicKeyMsg.from_public_key(public_key)
if not (passphrase := self._task.args.get('passphrase')):
encryption_algorithm = serialization.NoEncryption()
else:
encryption_algorithm = serialization.BestAvailableEncryption(
to_bytes(passphrase)
)
return {
'changed': True,
'private_key': to_text(private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.OpenSSH,
encryption_algorithm=encryption_algorithm,
)),
'public_key': to_text(public_key.public_bytes(
encoding=serialization.Encoding.OpenSSH,
format=serialization.PublicFormat.OpenSSH,
)),
'fingerprint': f'SHA256:{public_key_msg.fingerprint()}',
}

@ -0,0 +1,3 @@
needs/ssh
shippable/posix/group2
context/target

@ -0,0 +1,65 @@
- hosts: localhost
gather_facts: false
tasks:
- ssh_keygen:
type: ed25519
passphrase: passphrase
register: sshkey
- delegate_to: testhost
block:
- slurp:
path: ~/.ssh/authorized_keys
register: akeys
- debug:
msg: '{{ akeys.content|b64decode }}'
- copy:
content: |
{{ sshkey.public_key }}
{{ akeys.content|b64decode }}
dest: ~/.ssh/authorized_keys
mode: '0400'
- add_host:
name: testhost
ansible_password: ~
ansible_ssh_password: ~
ansible_ssh_private_key_file: ~
ansible_private_key: '{{ sshkey.private_key }}'
ansible_private_key_passphrase: passphrase
fingerprint: '{{ sshkey.fingerprint }}'
- hosts: testhost
gather_facts: false
tasks:
- ping:
- name: list keys from agent
ssh_agent:
action: list
register: keys
- assert:
that:
- keys.nkeys == 1
- keys['keys'][0].fingerprint == fingerprint
- name: lock the agent
ssh_agent:
action: lock
password: pancakes
- name: this will fail because the agent is locked
ping:
ignore_errors: true
register: _
failed_when: _ is not failed
- name: unlock the agent
ssh_agent:
action: unlock
password: pancakes
- ping:

@ -0,0 +1,23 @@
- delegate_to: localhost
block:
- name: install bcrypt
pip:
name: bcrypt
register: bcrypt
- tempfile:
path: "{{ lookup('env', 'OUTPUT_DIR') }}"
state: directory
register: tmpdir
- import_tasks: tests.yml
always:
- name: uninstall bcrypt
pip:
name: bcrypt
state: absent
when: bcrypt is changed
- file:
path: tmpdir.path
state: absent

@ -0,0 +1,49 @@
- slurp:
path: ~/.ssh/authorized_keys
register: akeys
- debug:
msg: '{{ akeys.content|b64decode }}'
- command: ansible-playbook -i {{ ansible_inventory_sources|first|quote }} -vvv {{ role_path }}/auto.yml
environment:
ANSIBLE_CALLBACK_RESULT_FORMAT: yaml
ANSIBLE_SSH_AGENT: auto
register: auto
- command: ps {{ ps_flags }} -opid
register: pids
# Some distros will exit with rc=1 if no processes were returned
vars:
ps_flags: '{{ "" if ansible_distribution == "Alpine" else "-x" }}'
- assert:
that:
- >-
'started and bound to' in auto.stdout
- >-
'SSH: SSH_AGENT adding' in auto.stdout
- >-
'exists in agent' in auto.stdout
- pids|map('trim')|select('eq', pid) == []
vars:
pid: '{{ auto.stdout|regex_findall("ssh-agent\[(\d+)\]")|first }}'
- command: ssh-agent -D -s -a '{{ tmpdir.path }}/agent.sock'
async: 30
poll: 0
- command: ansible-playbook -i {{ ansible_inventory_sources|first|quote }} -vvv {{ role_path }}/auto.yml
environment:
ANSIBLE_CALLBACK_RESULT_FORMAT: yaml
ANSIBLE_SSH_AGENT: '{{ tmpdir.path }}/agent.sock'
register: existing
- assert:
that:
- >-
'started and bound to' not in existing.stdout
- >-
'SSH: SSH_AGENT adding' in existing.stdout
- >-
'exists in agent' in existing.stdout
Loading…
Cancel
Save