From 79ab013f2677ef765569a20a1dd061b431491809 Mon Sep 17 00:00:00 2001 From: Matt Martz Date: Wed, 8 Nov 2023 15:29:56 -0600 Subject: [PATCH] Add ssh-agent launching, and ssh-agent python client --- changelogs/fragments/ssh-agent.yml | 5 + lib/ansible/config/base.yml | 18 + lib/ansible/inventory/manager.py | 50 ++ lib/ansible/plugins/connection/ssh.py | 89 ++- lib/ansible/utils/path.py | 4 +- lib/ansible/utils/ssh_agent.py | 645 ++++++++++++++++++ licenses/BSD-3-Clause.txt | 28 + .../ssh_agent/action_plugins/ssh_agent.py | 52 ++ .../ssh_agent/action_plugins/ssh_keygen.py | 48 ++ test/integration/targets/ssh_agent/aliases | 3 + test/integration/targets/ssh_agent/auto.yml | 65 ++ .../targets/ssh_agent/tasks/main.yml | 23 + .../targets/ssh_agent/tasks/tests.yml | 49 ++ 13 files changed, 1074 insertions(+), 5 deletions(-) create mode 100644 changelogs/fragments/ssh-agent.yml create mode 100644 lib/ansible/utils/ssh_agent.py create mode 100644 licenses/BSD-3-Clause.txt create mode 100644 test/integration/targets/ssh_agent/action_plugins/ssh_agent.py create mode 100644 test/integration/targets/ssh_agent/action_plugins/ssh_keygen.py create mode 100644 test/integration/targets/ssh_agent/aliases create mode 100644 test/integration/targets/ssh_agent/auto.yml create mode 100644 test/integration/targets/ssh_agent/tasks/main.yml create mode 100644 test/integration/targets/ssh_agent/tasks/tests.yml diff --git a/changelogs/fragments/ssh-agent.yml b/changelogs/fragments/ssh-agent.yml new file mode 100644 index 00000000000..c87deb93e24 --- /dev/null +++ b/changelogs/fragments/ssh-agent.yml @@ -0,0 +1,5 @@ +minor_changes: +- ssh-agent - InventoryManager is capable of spawning or reusing an ssh-agent, allowing plugins to interact with the ssh-agent. + Additionally a pure python ssh-client has been added, enabling easy interaction with the agent. The ssh connection plugin contains + new functionality via ``ansible_ssh_private_key`` and ``ansible_ssh_private_key_passphrase``, for loading an SSH private key into + the agent from a variable. diff --git a/lib/ansible/config/base.yml b/lib/ansible/config/base.yml index 24f9464d0a3..db5f7d93755 100644 --- a/lib/ansible/config/base.yml +++ b/lib/ansible/config/base.yml @@ -1897,6 +1897,24 @@ SHOW_CUSTOM_STATS: ini: - {key: show_custom_stats, section: defaults} type: bool +SSH_AGENT: + name: Manage an SSH Agent + description: Manage an SSH Agent via Ansible. A configuration of ``none`` will not interact with an agent, + ``auto`` will start and destroy an agent during the run, and a path to an SSH_AUTH_SOCK will + allow interaction with a pre-existing agent. + default: none + type: string + env: [{name: ANSIBLE_SSH_AGENT}] + ini: [{key: ssh_agent, section: connection}] + version_added: '2.17' +SSH_AGENT_KEY_LIFETIME: + name: Set a maximum lifetime when adding identities to an agent + description: For keys inserted into an agent defined by ``SSH_AGENT``, define a lifetime, in seconds, that the key may remain + in the agent. + type: int + env: [{name: ANSIBLE_SSH_AGENT_KEY_LIFETIME}] + ini: [{key: ssh_agent_key_lifetime, section: connection}] + version_added: '2.17' STRING_TYPE_FILTERS: name: Filters to preserve strings default: [string, to_json, to_nice_json, to_yaml, to_nice_yaml, ppretty, json] diff --git a/lib/ansible/inventory/manager.py b/lib/ansible/inventory/manager.py index ba6397f1787..0323068b00a 100644 --- a/lib/ansible/inventory/manager.py +++ b/lib/ansible/inventory/manager.py @@ -18,11 +18,13 @@ ############################################# from __future__ import annotations +import atexit import fnmatch import os import sys import re import itertools +import subprocess import traceback from operator import attrgetter @@ -32,12 +34,14 @@ from ansible import constants as C from ansible.errors import AnsibleError, AnsibleOptionsError, AnsibleParserError from ansible.inventory.data import InventoryData from ansible.module_utils.six import string_types +from ansible.module_utils.common.process import get_bin_path from ansible.module_utils.common.text.converters import to_bytes, to_text from ansible.parsing.utils.addresses import parse_address from ansible.plugins.loader import inventory_loader from ansible.utils.helpers import deduplicate_list from ansible.utils.path import unfrackpath from ansible.utils.display import Display +from ansible.utils.ssh_agent import Client from ansible.utils.vars import combine_vars from ansible.vars.plugins import get_vars_from_inventory_sources @@ -161,6 +165,8 @@ class InventoryManager(object): else: self._sources = sources + self._launch_ssh_agent() + # get to work! if parse: self.parse_sources(cache=cache) @@ -168,6 +174,50 @@ class InventoryManager(object): self._cached_dynamic_hosts = [] self._cached_dynamic_grouping = [] + def _launch_ssh_agent(self): + ssh_agent_cfg = C.config.get_config_value('SSH_AGENT') + match ssh_agent_cfg: + case 'none': + return + case 'auto': + ssh_agent_dir = os.path.join(C.DEFAULT_LOCAL_TMP, 'ssh_agent') + os.mkdir(ssh_agent_dir, 0o700) + try: + ssh_agent_bin = get_bin_path('ssh-agent', required=True) + except ValueError as e: + raise AnsibleError('SSH_AGENT set to auto, but cannot find ssh-agent binary') from e + sock = os.path.join(ssh_agent_dir, 'agent.sock') + p = subprocess.Popen( + [ssh_agent_bin, '-D', '-s', '-a', sock], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + if p.poll() is not None: + raise AnsibleError( + f'Could not start ssh-agent: (rc={p.returncode}) {p.stderr}' + ) + if (stdout := p.stdout.read(13)) != b'SSH_AUTH_SOCK': + display.warn( + f'The first 13 characters of stdout did not match the ' + f'expected SSH_AUTH_SOCK. This may not be the right binary, ' + f'or an incompatible agent: {stdout.decode()}' + ) + display.vvv(f'SSH_AGENT: ssh-agent[{p.pid}] started and bound to {sock}') + atexit.register(p.terminate) + case _: + sock = ssh_agent_cfg + + try: + with Client(sock) as client: + client.list() + except Exception as e: + raise AnsibleError( + f'Could not communicate with ssh-agent using auth sock {sock}: {e}' + ) + + os.environ['SSH_AUTH_SOCK'] = os.environ['ANSIBLE_SSH_AGENT'] = sock + @property def localhost(self): return self._inventory.get_host('localhost') diff --git a/lib/ansible/plugins/connection/ssh.py b/lib/ansible/plugins/connection/ssh.py index 299039faa5b..89c891ff0b6 100644 --- a/lib/ansible/plugins/connection/ssh.py +++ b/lib/ansible/plugins/connection/ssh.py @@ -264,7 +264,30 @@ DOCUMENTATION = """ cli: - name: private_key_file option: '--private-key' - + private_key: + description: + - private key contents in PEM format. Requires the SSH_AGENT configuration to be enabled. + type: string + ini: + - section: defaults + key: private_key + env: + - name: ANSIBLE_PRIVATE_KEY + vars: + - name: ansible_private_key + version_added: '2.17' + private_key_passphrase: + description: + - private key passphrase, dependent on ``private_key``. + type: string + ini: + - section: defaults + key: private_key_passphrase + env: + - name: ANSIBLE_PRIVATE_KEY_PASSPHRASE + vars: + - name: ansible_private_key_passphrase + version_added: '2.17' control_path: description: - This is the location to save SSH's ControlPath sockets, it uses SSH's variable substitution. @@ -380,6 +403,8 @@ import time import typing as t from functools import wraps + +from ansible import constants as C from ansible.errors import ( AnsibleAuthenticationFailure, AnsibleConnectionFailure, @@ -392,6 +417,15 @@ from ansible.plugins.connection import ConnectionBase, BUFSIZE from ansible.plugins.shell.powershell import _parse_clixml from ansible.utils.display import Display from ansible.utils.path import unfrackpath, makedirs_safe +from ansible.utils.ssh_agent import load_private_key, Client, PublicKeyMsg + +try: + from cryptography.hazmat.primitives import serialization +except ImportError: + HAS_CRYPTOGRAPHY = False +else: + HAS_CRYPTOGRAPHY = True + display = Display() @@ -583,6 +617,8 @@ class Connection(ConnectionBase): self.module_implementation_preferences = ('.ps1', '.exe', '') self.allow_executable = False + self._populated_agent = None + # The connection is created by running ssh/scp/sftp from the exec_command, # put_file, and fetch_file methods, so we don't need to do any connection # management here. @@ -664,6 +700,49 @@ class Connection(ConnectionBase): display.vvvvv(u'SSH: %s: (%s)' % (explanation, ')('.join(to_text(a) for a in b_args)), host=self.host) b_command += b_args + def _populate_agent(self): + if self._populated_agent: + return self._populated_agent + + if (auth_sock := C.config.get_config_value('SSH_AGENT')) == 'none': + raise AnsibleError('Cannot utilize private_key with SSH_AGENT disabled') + + key_data = self.get_option('private_key') + passphrase = self.get_option('private_key_passphrase') + private_key = load_private_key( + to_bytes(key_data), + to_bytes(passphrase) if passphrase else None + ) + public_key = private_key.public_key() + pubkey_msg = PublicKeyMsg.from_public_key(public_key) + fingerprint = pubkey_msg.fingerprint() + with Client(auth_sock) as client: + if (public_key := private_key.public_key()) not in client: + display.vvv(f'SSH: SSH_AGENT adding {fingerprint} to agent', host=self.host) + client.add( + private_key, + '[added by ansible]', + C.config.get_config_value('SSH_AGENT_KEY_LIFETIME'), + ) + else: + display.vvv(f'SSH: SSH_AGENT {fingerprint} exists in agent', host=self.host) + # Write the public key to disk, to be provided as IdentityFile. + # This allows ssh to pick an explicit key in the agent to use, + # preventing ssh from attempting all keys in the agent. + pubkey_path = self._populated_agent = os.path.join( + C.DEFAULT_LOCAL_TMP, + fingerprint.replace('/', '-') + '.pub' + ) + if os.path.exists(pubkey_path): + return pubkey_path + fd = os.open(pubkey_path, os.O_CREAT | os.O_WRONLY, mode=0o400) + with os.fdopen(fd, 'wb') as f: + f.write(public_key.public_bytes( + encoding=serialization.Encoding.OpenSSH, + format=serialization.PublicFormat.OpenSSH + )) + return self._populated_agent + def _build_command(self, binary: str, subsystem: str, *other_args: bytes | str) -> list[bytes]: """ Takes a executable (ssh, scp, sftp or wrapper) and optional extra arguments and returns the remote command @@ -748,8 +827,12 @@ class Connection(ConnectionBase): b_args = (b"-o", b"Port=" + to_bytes(self.port, nonstring='simplerepr', errors='surrogate_or_strict')) self._add_args(b_command, b_args, u"ANSIBLE_REMOTE_PORT/remote_port/ansible_port set") - key = self.get_option('private_key_file') - if key: + if self.get_option('private_key'): + key = self._populate_agent() + b_args = (b'-o', b'IdentitiesOnly=yes', b'-o', b'IdentityFile="' + to_bytes(key, errors='surrogate_or_strict') + b'"') + self._add_args(b_command, b_args, u"ANSIBLE_PRIVATE_KEY/private_key set") + + elif (key := self.get_option('private_key_file')): b_args = (b"-o", b'IdentityFile="' + to_bytes(os.path.expanduser(key), errors='surrogate_or_strict') + b'"') self._add_args(b_command, b_args, u"ANSIBLE_PRIVATE_KEY_FILE/private_key_file/ansible_ssh_private_key_file set") diff --git a/lib/ansible/utils/path.py b/lib/ansible/utils/path.py index 202a4f42592..fe6c29dc97d 100644 --- a/lib/ansible/utils/path.py +++ b/lib/ansible/utils/path.py @@ -19,7 +19,7 @@ from __future__ import annotations import os import shutil -from errno import EEXIST +from errno import EEXIST, ENOENT from ansible.errors import AnsibleError from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text @@ -127,7 +127,7 @@ def cleanup_tmp_file(path, warn=False): elif os.path.isfile(path): os.unlink(path) except Exception as e: - if warn: + if warn and getattr(e, 'errno', None) != ENOENT: # Importing here to avoid circular import from ansible.utils.display import Display display = Display() diff --git a/lib/ansible/utils/ssh_agent.py b/lib/ansible/utils/ssh_agent.py new file mode 100644 index 00000000000..b2ae2455f8f --- /dev/null +++ b/lib/ansible/utils/ssh_agent.py @@ -0,0 +1,645 @@ +# Copyright: Contributors to the Ansible project +# BSD 3 Clause License (see licenses/BSD-3-Clause.txt or https://opensource.org/license/bsd-3-clause/) + +from __future__ import annotations + +import binascii +import collections.abc +import copy +import dataclasses +import enum +import hashlib +import socket +import typing as t + +try: + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric.dsa import ( + DSAParameterNumbers, + DSAPrivateKey, + DSAPublicKey, + DSAPublicNumbers, + ) + from cryptography.hazmat.primitives.asymmetric.ec import ( + EllipticCurve, + EllipticCurvePrivateKey, + EllipticCurvePublicKey, + SECP256R1, + SECP384R1, + SECP521R1, + ) + from cryptography.hazmat.primitives.asymmetric.ed25519 import ( + Ed25519PrivateKey, + Ed25519PublicKey, + ) + from cryptography.hazmat.primitives.asymmetric.rsa import ( + RSAPrivateKey, + RSAPublicKey, + RSAPublicNumbers, + ) + from cryptography.hazmat.primitives.serialization import ssh + + CryptoPublicKey = t.Union[ + DSAPublicKey, + EllipticCurvePublicKey, + Ed25519PublicKey, + RSAPublicKey, + ] + + CryptoPrivateKey = t.Union[ + DSAPrivateKey, + EllipticCurvePrivateKey, + Ed25519PrivateKey, + RSAPrivateKey, + ] +except ImportError: + HAS_CRYPTOGRAPHY = False +else: + HAS_CRYPTOGRAPHY = True + + +class SshAgentFailure(Exception): + ... + + +class mpint(int): + ... + + +class byte(int): + ... + + +class constraints(bytes): + ... + + +class Protocol(enum.IntEnum): + # Responses + SSH_AGENT_FAILURE = 5 + SSH_AGENT_SUCCESS = 6 + SSH_AGENT_EXTENSION_FAILURE = 28 + SSH_AGENT_IDENTITIES_ANSWER = 12 + SSH_AGENT_SIGN_RESPONSE = 14 + + # Constraints + SSH_AGENT_CONSTRAIN_LIFETIME = 1 + SSH_AGENT_CONSTRAIN_CONFIRM = 2 + SSH_AGENT_CONSTRAIN_EXTENSION = 255 + + # Requests + SSH_AGENTC_REQUEST_IDENTITIES = 11 + SSH_AGENTC_SIGN_REQUEST = 13 + SSH_AGENTC_ADD_IDENTITY = 17 + SSH_AGENTC_REMOVE_IDENTITY = 18 + SSH_AGENTC_REMOVE_ALL_IDENTITIES = 19 + SSH_AGENTC_ADD_SMARTCARD_KEY = 20 + SSH_AGENTC_REMOVE_SMARTCARD_KEY = 21 + SSH_AGENTC_LOCK = 22 + SSH_AGENTC_UNLOCK = 23 + SSH_AGENTC_ADD_ID_CONSTRAINED = 25 + SSH_AGENTC_ADD_SMARTCARD_KEY_CONSTRAINED = 26 + SSH_AGENTC_EXTENSION = 27 + + +class KeyAlgo(str, enum.Enum): + RSA = "ssh-rsa" + DSA = "ssh-dss" + ECDSA256 = "ecdsa-sha2-nistp256" + SKECDSA256 = "sk-ecdsa-sha2-nistp256@openssh.com" + ECDSA384 = "ecdsa-sha2-nistp384" + ECDSA521 = "ecdsa-sha2-nistp521" + ED25519 = "ssh-ed25519" + SKED25519 = "sk-ssh-ed25519@openssh.com" + RSASHA256 = "rsa-sha2-256" + RSASHA512 = "rsa-sha2-512" + + @property + def main_type(self): + match self: + case self.RSA: + return 'RSA' + case self.DSA: + return 'DSA' + case self.ECDSA256 | self.ECDSA384 | self.ECDSA521: + return 'ECDSA' + case self.ED25519: + return 'ED25519' + case _: + raise NotImplementedError(self.name) + + +if HAS_CRYPTOGRAPHY: + _ECDSA_KEY_TYPE: dict[KeyAlgo, type[EllipticCurve]] = { + KeyAlgo.ECDSA256: SECP256R1, + KeyAlgo.ECDSA384: SECP384R1, + KeyAlgo.ECDSA521: SECP521R1, + } + + +@dataclasses.dataclass +class Msg: + ... + + +@dataclasses.dataclass(order=True, slots=True) +class AgentLockMsg(Msg): + passphrase: bytes + + +@dataclasses.dataclass +class PrivateKeyMsg(Msg): + @staticmethod + def from_private_key(private_key): + match private_key: + case RSAPrivateKey(): + pn = private_key.private_numbers() + return RSAPrivateKeyMsg( + KeyAlgo.RSA, + pn.public_numbers.n, + pn.public_numbers.e, + pn.d, + pn.iqmp, + pn.p, + pn.q, + ) + case DSAPrivateKey(): + pn = private_key.private_numbers() + return DSAPrivateKeyMsg( + KeyAlgo.DSA, + pn.public_numbers.parameter_numbers.p, + pn.public_numbers.parameter_numbers.q, + pn.public_numbers.parameter_numbers.g, + pn.public_numbers.y, + pn.x, + ) + case EllipticCurvePrivateKey(): + pn = private_key.private_numbers() + key_size = private_key.key_size + return EcdsaPrivateKeyMsg( + getattr(KeyAlgo, f'ECDSA{key_size}'), + f'nistp{key_size}', + private_key.public_key().public_bytes( + encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.UncompressedPoint + ), + pn.private_value, + ) + case Ed25519PrivateKey(): + public_bytes = private_key.public_key().public_bytes( + encoding=serialization.Encoding.Raw, + format=serialization.PublicFormat.Raw, + ) + private_bytes = private_key.private_bytes( + encoding=serialization.Encoding.Raw, + format=serialization.PrivateFormat.Raw, + encryption_algorithm=serialization.NoEncryption() + ) + return Ed25519PrivateKeyMsg( + KeyAlgo.ED25519, + public_bytes, + private_bytes + public_bytes, + ) + case _: + raise NotImplementedError(private_key) + + +@dataclasses.dataclass(order=True, slots=True) +class RSAPrivateKeyMsg(PrivateKeyMsg): + type: KeyAlgo + n: mpint + e: mpint + d: mpint + iqmp: mpint + p: mpint + q: mpint + comments: str = dataclasses.field(default='', compare=False) + constraints: constraints = dataclasses.field(default=constraints(b'')) + + +@dataclasses.dataclass(order=True, slots=True) +class DSAPrivateKeyMsg(PrivateKeyMsg): + type: KeyAlgo + p: mpint + q: mpint + g: mpint + y: mpint + x: mpint + comments: str = dataclasses.field(default='', compare=False) + constraints: constraints = dataclasses.field(default=constraints(b'')) + + +@dataclasses.dataclass(order=True, slots=True) +class EcdsaPrivateKeyMsg(PrivateKeyMsg): + type: KeyAlgo + ecdsa_curve_name: str + Q: bytes + d: mpint + comments: str = dataclasses.field(default='', compare=False) + constraints: constraints = dataclasses.field(default=constraints(b'')) + + +@dataclasses.dataclass(order=True, slots=True) +class Ed25519PrivateKeyMsg(PrivateKeyMsg): + type: KeyAlgo + enc_a: bytes + k_env_a: bytes + comments: str = dataclasses.field(default='', compare=False) + constraints: constraints = dataclasses.field(default=constraints(b'')) + + +@dataclasses.dataclass +class PublicKeyMsg(Msg): + @staticmethod + def get_dataclass( + type: KeyAlgo + ) -> type[t.Union[ + RSAPublicKeyMsg, + EcdsaPublicKeyMsg, + Ed25519PublicKeyMsg, + DSAPublicKeyMsg + ]]: + match type: + case KeyAlgo.RSA: + return RSAPublicKeyMsg + case KeyAlgo.ECDSA256 | KeyAlgo.ECDSA384 | KeyAlgo.ECDSA521: + return EcdsaPublicKeyMsg + case KeyAlgo.ED25519: + return Ed25519PublicKeyMsg + case KeyAlgo.DSA: + return DSAPublicKeyMsg + case _: + raise NotImplementedError(type) + + def public_key(self) -> CryptoPublicKey: + type = self.type # type: ignore[attr-defined] + match type: + case KeyAlgo.RSA: + return RSAPublicNumbers( + self.e, # type: ignore[attr-defined] + self.n # type: ignore[attr-defined] + ).public_key() + case KeyAlgo.ECDSA256 | KeyAlgo.ECDSA384 | KeyAlgo.ECDSA521: + curve = _ECDSA_KEY_TYPE[KeyAlgo(type)] + return EllipticCurvePublicKey.from_encoded_point( + curve(), + self.Q # type: ignore[attr-defined] + ) + case KeyAlgo.ED25519: + return Ed25519PublicKey.from_public_bytes( + self.enc_a # type: ignore[attr-defined] + ) + case KeyAlgo.DSA: + return DSAPublicNumbers( + self.y, # type: ignore[attr-defined] + DSAParameterNumbers( + self.p, # type: ignore[attr-defined] + self.q, # type: ignore[attr-defined] + self.g # type: ignore[attr-defined] + ) + ).public_key() + case _: + raise NotImplementedError(type) + + @staticmethod + def from_public_key(public_key): + match public_key: + case DSAPublicKey(): + pn = public_key.public_numbers() + return DSAPublicKeyMsg( + KeyAlgo.DSA, + pn.parameter_numbers.p, + pn.parameter_numbers.q, + pn.parameter_numbers.g, + pn.y + ) + case EllipticCurvePublicKey(): + return EcdsaPublicKeyMsg( + getattr(KeyAlgo, f'ECDSA{public_key.curve.key_size}'), + f'nistp{public_key.curve.key_size}', + public_key.public_bytes( + encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.UncompressedPoint + ) + ) + case Ed25519PublicKey(): + return Ed25519PublicKeyMsg( + KeyAlgo.ED25519, + public_key.public_bytes( + encoding=serialization.Encoding.Raw, + format=serialization.PublicFormat.Raw, + ) + ) + case RSAPublicKey(): + pn = public_key.public_numbers() + return RSAPublicKeyMsg( + KeyAlgo.RSA, + pn.e, + pn.n + ) + case _: + raise NotImplementedError(public_key) + + def fingerprint(self): + digest = hashlib.sha256() + msg = copy.copy(self) + msg.comments = '' + k = encode(msg) + digest.update(k) + return binascii.b2a_base64( + digest.digest(), + newline=False + ).rstrip(b'=').decode('utf-8') + + +@dataclasses.dataclass(order=True, slots=True) +class RSAPublicKeyMsg(PublicKeyMsg): + type: KeyAlgo + e: mpint + n: mpint + comments: str = dataclasses.field(default='', compare=False) + + +@dataclasses.dataclass(order=True, slots=True) +class DSAPublicKeyMsg(PublicKeyMsg): + type: KeyAlgo + p: mpint + q: mpint + g: mpint + y: mpint + comments: str = dataclasses.field(default='', compare=False) + + +@dataclasses.dataclass(order=True, slots=True) +class EcdsaPublicKeyMsg(PublicKeyMsg): + type: KeyAlgo + ecdsa_curve_name: str + Q: bytes + comments: str = dataclasses.field(default='', compare=False) + + +@dataclasses.dataclass(order=True, slots=True) +class Ed25519PublicKeyMsg(PublicKeyMsg): + type: KeyAlgo + enc_a: bytes + comments: str = dataclasses.field(default='', compare=False) + + +@dataclasses.dataclass(order=True, slots=True) +class KeyList(Msg): + nkeys: int + keys: list[PublicKeyMsg] + + def __init__(self, nkeys, *args): + self.nkeys = nkeys + self.keys = args + + +def _to_bytes(val: int, length: int) -> bytes: + return val.to_bytes(length=length, byteorder='big') + + +def _from_bytes(val: bytes) -> int: + return int.from_bytes(val, byteorder='big') + + +def _to_mpint(val: int) -> bytes: + if val < 0: + raise ValueError("negative mpint not allowed") + if not val: + return b"" + nbytes = (val.bit_length() + 8) // 8 + ret = bytearray(_to_bytes(val, nbytes)) + ret[:0] = _to_bytes(len(ret), 4) + return ret + + +def _from_mpint(val: bytes) -> int: + if val and val[0] > 127: + raise ValueError("Invalid data") + return _from_bytes(val) + + +def encode_dataclass(msg: Msg) -> collections.abc.Iterator[bytes]: + for field in dataclasses.fields(msg): + fv = getattr(msg, field.name) + match field.type: + case 'mpint': + yield _to_mpint(fv) + case 'int': # uint32 + yield _to_bytes(fv, 4) + case 'bool' | 'byte' | 'Protocol': + yield _to_bytes(fv, 1) + case 'str' | 'KeyAlgo': + if fv: + fv = fv.encode('utf-8') + yield _to_bytes(len(fv), 4) + yield fv + case 'bytes': + if fv: + yield _to_bytes(len(fv), 4) + yield fv + case 'constraints': + yield fv + case _: + raise NotImplementedError(field.type) + + +def parse_annotation(type: t.Any) -> tuple[str, str]: + if type.count('[') > 1: + raise NotImplementedError() + main, _sep, sub = type.removesuffix(']').partition('[') + return main, sub + + +def _consume_field( + blob: memoryview, + type: t.Any | None = None +) -> tuple[memoryview, int, memoryview]: + match type: + case 'int': + length = 4 + case 'bool' | 'byte' | 'Protocol': + length = 1 + case _: + length = _from_bytes(blob[:4]) + blob = blob[4:] + return blob[:length], length, blob[length:] + + +def decode_dataclass(blob: memoryview, dataclass: type[Msg]) -> Msg: + fi = 0 + args: list[t.Any] = [] + fields = dataclasses.fields(dataclass) + while blob: + field = fields[fi] + prev_blob = blob + fv, length, blob = _consume_field(blob, type=field.type) + main_type, sub_type = parse_annotation(field.type) + match main_type: + case 'mpint': + args.append(_from_mpint(fv)) + case 'int': # uint32 + args.append(_from_bytes(fv)) + case 'Protocol': + args.append(Protocol(_from_bytes(fv))) + case 'bool' | 'byte': + args.append(_from_bytes(fv)) + case 'KeyAlgo': + args.append(KeyAlgo(fv.tobytes().decode('utf-8'))) + case 'str': + args.append(fv.tobytes().decode('utf-8')) + case 'bytes': + args.append(bytes(fv)) + case 'list': + # Lists should always be last + match sub_type: + case 'PublicKeyMsg': + peek, _length, _blob = _consume_field(fv) + sub = PublicKeyMsg.get_dataclass( + KeyAlgo(peek.tobytes().decode('utf-8')) + ) + + _fv, cl, blob = _consume_field(blob) + key_plus_comment = ( + prev_blob[4:length + cl + 8] + ) + case _: + raise NotImplementedError(sub_type) + + args.append(decode_dataclass(key_plus_comment, sub)) + fi -= 1 # We are in a list, don't move to the next field + case _: + raise NotImplementedError(field.type) + + fi += 1 + return dataclass(*args) + + +def encode(msg: Protocol | Msg) -> bytes: + if isinstance(msg, Protocol): + payload = bytes([msg]) + else: + payload = b''.join(encode_dataclass(msg)) + return payload + + +class Client: + def __init__(self, auth_sock: str): + self._auth_sock = auth_sock + self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self._sock.connect(auth_sock) + + def terminate(self): + self._ssh_agent.terminate() + + def close(self): + self._sock.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def send(self, msg: bytes) -> bytes: + length = _to_bytes(len(msg), 4) + self._sock.sendall(length + msg) + bufsize = _from_bytes(self._sock.recv(4)) + resp = self._sock.recv(bufsize) + if resp[0] == Protocol.SSH_AGENT_FAILURE: + raise SshAgentFailure('agent: failure') + return resp + + def remove_all(self): + msg = encode(Protocol.SSH_AGENTC_REMOVE_ALL_IDENTITIES) + self.send(msg) + return True + + def remove(self, public_key: CryptoPublicKey): + msg = encode(Protocol.SSH_AGENTC_REMOVE_IDENTITY) + key_blob = encode( + PublicKeyMsg.from_public_key(public_key) + ) + msg += _to_bytes(len(key_blob), 4) + msg += key_blob + self.send(msg) + return True + + def add( + self, + private_key: CryptoPrivateKey, + comments: str | None = None, + lifetime: int | None = None, + confirm: bool | None = None, + ): + key_msg = PrivateKeyMsg.from_private_key(private_key) + key_msg.comments = comments or '' + if lifetime: + key_msg.constraints += constraints( + [Protocol.SSH_AGENT_CONSTRAIN_LIFETIME] + ) + _to_bytes(lifetime, 4) + if confirm: + key_msg.constraints += constraints( + [Protocol.SSH_AGENT_CONSTRAIN_CONFIRM] + ) + + if key_msg.constraints: + msg = encode(Protocol.SSH_AGENTC_ADD_ID_CONSTRAINED) + else: + msg = encode(Protocol.SSH_AGENTC_ADD_IDENTITY) + msg += encode(key_msg) + self.send(msg) + return True + + def list(self) -> KeyList: + req = encode(Protocol.SSH_AGENTC_REQUEST_IDENTITIES) + r = memoryview(bytearray(self.send(req))) + if r[0] != Protocol.SSH_AGENT_IDENTITIES_ANSWER: + raise SshAgentFailure( + 'agent: non-identities answer received for identities list' + ) + return t.cast(KeyList, decode_dataclass(r[1:], KeyList)) + + def lock(self, passphrase: bytes): + msg = encode(Protocol.SSH_AGENTC_LOCK) + msg += encode(AgentLockMsg(passphrase)) + self.send(msg) + return True + + def unlock(self, passphrase: bytes): + msg = encode(Protocol.SSH_AGENTC_UNLOCK) + msg += encode(AgentLockMsg(passphrase)) + self.send(msg) + return True + + def __contains__(self, public_key: CryptoPublicKey) -> bool: + msg = PublicKeyMsg.from_public_key(public_key) + for key in self.list().keys: + if key == msg: + return True + return False + + +def load_private_key(key_data: bytes, passphrase: bytes) -> CryptoPrivateKey: + try: + private_key = ssh.load_ssh_private_key( + key_data, + password=passphrase, + ) + except ValueError: + # Old keys generated by ssh-agent may not adhere to the strict + # definition of what ``load_ssh_private_key`` expects, fall + # back to generic PEM private key loading + private_key = serialization.load_pem_private_key( + key_data, + password=passphrase, + ) # type: CryptoPrivateKey # type: ignore[no-redef] + allowed_types = t.get_args(CryptoPrivateKey) + if not isinstance(private_key, allowed_types): + type_names = (o.__name__ for o in allowed_types) + raise ValueError( + f'key_data must be one of {", ".join(type_names)} not, ' + f'{private_key.__class__.__name__}' + ) + return private_key diff --git a/licenses/BSD-3-Clause.txt b/licenses/BSD-3-Clause.txt new file mode 100644 index 00000000000..0101e7b2a20 --- /dev/null +++ b/licenses/BSD-3-Clause.txt @@ -0,0 +1,28 @@ +Copyright (c) Contributors to the Ansible project. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS +OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF +SUCH DAMAGE. diff --git a/test/integration/targets/ssh_agent/action_plugins/ssh_agent.py b/test/integration/targets/ssh_agent/action_plugins/ssh_agent.py new file mode 100644 index 00000000000..f41ed699560 --- /dev/null +++ b/test/integration/targets/ssh_agent/action_plugins/ssh_agent.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import os + +from ansible.plugins.action import ActionBase +from ansible.utils.ssh_agent import Client +from ansible.module_utils.common.text.converters import to_bytes + + +class ActionModule(ActionBase): + + def run(self, tmp=None, task_vars=None): + results = super(ActionModule, self).run(tmp, task_vars) + del tmp # tmp no longer has any effect + match self._task.args['action']: + case 'list': + return self.list() + case 'lock': + return self.lock(self._task.args['password']) + case 'unlock': + return self.unlock(self._task.args['password']) + case _: + return {'failed': True, 'msg': 'not implemented'} + + def lock(self, password): + with Client(os.environ['SSH_AUTH_SOCK']) as client: + client.lock(to_bytes(password)) + return {'changed': True} + + def unlock(self, password): + with Client(os.environ['SSH_AUTH_SOCK']) as client: + client.unlock(to_bytes(password)) + return {'changed': True} + + def list(self): + result = {'keys': [], 'nkeys': 0} + with Client(os.environ['SSH_AUTH_SOCK']) as client: + key_list = client.list() + result['nkeys'] = key_list.nkeys + for key in key_list.keys: + public_key = key.public_key() + key_size = getattr(public_key, 'key_size', 256) + fingerprint = key.fingerprint() + key_type = key.type.main_type + result['keys'].append({ + 'type': key_type, + 'key_size': key_size, + 'fingerprint': f'SHA256:{fingerprint}', + 'comments': key.comments, + }) + + return result diff --git a/test/integration/targets/ssh_agent/action_plugins/ssh_keygen.py b/test/integration/targets/ssh_agent/action_plugins/ssh_keygen.py new file mode 100644 index 00000000000..9079a514f5d --- /dev/null +++ b/test/integration/targets/ssh_agent/action_plugins/ssh_keygen.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from ansible.plugins.action import ActionBase +from ansible.utils.ssh_agent import PublicKeyMsg +from ansible.module_utils.common.text.converters import to_bytes, to_text + + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric.rsa import generate_private_key +from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey + + +class ActionModule(ActionBase): + + def run(self, tmp=None, task_vars=None): + results = super(ActionModule, self).run(tmp, task_vars) + del tmp # tmp no longer has any effect + match self._task.args.get('type'): + case 'ed25519': + private_key = Ed25519PrivateKey.generate() + case 'rsa': + private_key = generate_private_key(65537, 4096) + case _: + return {'failed': True, 'msg': 'not implemented'} + + public_key = private_key.public_key() + public_key_msg = PublicKeyMsg.from_public_key(public_key) + + if not (passphrase := self._task.args.get('passphrase')): + encryption_algorithm = serialization.NoEncryption() + else: + encryption_algorithm = serialization.BestAvailableEncryption( + to_bytes(passphrase) + ) + + return { + 'changed': True, + 'private_key': to_text(private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.OpenSSH, + encryption_algorithm=encryption_algorithm, + )), + 'public_key': to_text(public_key.public_bytes( + encoding=serialization.Encoding.OpenSSH, + format=serialization.PublicFormat.OpenSSH, + )), + 'fingerprint': f'SHA256:{public_key_msg.fingerprint()}', + } diff --git a/test/integration/targets/ssh_agent/aliases b/test/integration/targets/ssh_agent/aliases new file mode 100644 index 00000000000..dba9e76e597 --- /dev/null +++ b/test/integration/targets/ssh_agent/aliases @@ -0,0 +1,3 @@ +needs/ssh +shippable/posix/group2 +context/target diff --git a/test/integration/targets/ssh_agent/auto.yml b/test/integration/targets/ssh_agent/auto.yml new file mode 100644 index 00000000000..c73777e3d43 --- /dev/null +++ b/test/integration/targets/ssh_agent/auto.yml @@ -0,0 +1,65 @@ +- hosts: localhost + gather_facts: false + tasks: + - ssh_keygen: + type: ed25519 + passphrase: passphrase + register: sshkey + + - delegate_to: testhost + block: + - slurp: + path: ~/.ssh/authorized_keys + register: akeys + + - debug: + msg: '{{ akeys.content|b64decode }}' + + - copy: + content: | + {{ sshkey.public_key }} + {{ akeys.content|b64decode }} + dest: ~/.ssh/authorized_keys + mode: '0400' + + - add_host: + name: testhost + ansible_password: ~ + ansible_ssh_password: ~ + ansible_ssh_private_key_file: ~ + ansible_private_key: '{{ sshkey.private_key }}' + ansible_private_key_passphrase: passphrase + fingerprint: '{{ sshkey.fingerprint }}' + +- hosts: testhost + gather_facts: false + tasks: + - ping: + + - name: list keys from agent + ssh_agent: + action: list + register: keys + + - assert: + that: + - keys.nkeys == 1 + - keys['keys'][0].fingerprint == fingerprint + + - name: lock the agent + ssh_agent: + action: lock + password: pancakes + + - name: this will fail because the agent is locked + ping: + ignore_errors: true + register: _ + failed_when: _ is not failed + + - name: unlock the agent + ssh_agent: + action: unlock + password: pancakes + + - ping: diff --git a/test/integration/targets/ssh_agent/tasks/main.yml b/test/integration/targets/ssh_agent/tasks/main.yml new file mode 100644 index 00000000000..003407970c8 --- /dev/null +++ b/test/integration/targets/ssh_agent/tasks/main.yml @@ -0,0 +1,23 @@ +- delegate_to: localhost + block: + - name: install bcrypt + pip: + name: bcrypt + register: bcrypt + + - tempfile: + path: "{{ lookup('env', 'OUTPUT_DIR') }}" + state: directory + register: tmpdir + + - import_tasks: tests.yml + always: + - name: uninstall bcrypt + pip: + name: bcrypt + state: absent + when: bcrypt is changed + + - file: + path: tmpdir.path + state: absent diff --git a/test/integration/targets/ssh_agent/tasks/tests.yml b/test/integration/targets/ssh_agent/tasks/tests.yml new file mode 100644 index 00000000000..aad20d55025 --- /dev/null +++ b/test/integration/targets/ssh_agent/tasks/tests.yml @@ -0,0 +1,49 @@ +- slurp: + path: ~/.ssh/authorized_keys + register: akeys + +- debug: + msg: '{{ akeys.content|b64decode }}' + +- command: ansible-playbook -i {{ ansible_inventory_sources|first|quote }} -vvv {{ role_path }}/auto.yml + environment: + ANSIBLE_CALLBACK_RESULT_FORMAT: yaml + ANSIBLE_SSH_AGENT: auto + register: auto + +- command: ps {{ ps_flags }} -opid + register: pids + # Some distros will exit with rc=1 if no processes were returned + vars: + ps_flags: '{{ "" if ansible_distribution == "Alpine" else "-x" }}' + +- assert: + that: + - >- + 'started and bound to' in auto.stdout + - >- + 'SSH: SSH_AGENT adding' in auto.stdout + - >- + 'exists in agent' in auto.stdout + - pids|map('trim')|select('eq', pid) == [] + vars: + pid: '{{ auto.stdout|regex_findall("ssh-agent\[(\d+)\]")|first }}' + +- command: ssh-agent -D -s -a '{{ tmpdir.path }}/agent.sock' + async: 30 + poll: 0 + +- command: ansible-playbook -i {{ ansible_inventory_sources|first|quote }} -vvv {{ role_path }}/auto.yml + environment: + ANSIBLE_CALLBACK_RESULT_FORMAT: yaml + ANSIBLE_SSH_AGENT: '{{ tmpdir.path }}/agent.sock' + register: existing + +- assert: + that: + - >- + 'started and bound to' not in existing.stdout + - >- + 'SSH: SSH_AGENT adding' in existing.stdout + - >- + 'exists in agent' in existing.stdout