From 1fe0fb0e7ad280114efb77d2323d829f0fc7bf30 Mon Sep 17 00:00:00 2001 From: Matt Clay Date: Tue, 10 Jun 2025 18:30:30 -0700 Subject: [PATCH] ansible-test - Code cleanup (#85297) (cherry picked from commit bdc6c8e16afe190680f45f2c063da0adb2c4130d) --- .../ansible_test/_internal/host_profiles.py | 72 ++++++++++++++++--- .../ansible_test/_internal/provisioning.py | 14 ++-- test/lib/ansible_test/_internal/ssh.py | 6 +- test/lib/ansible_test/_internal/thread.py | 3 +- test/lib/ansible_test/_internal/timeout.py | 2 +- test/lib/ansible_test/_internal/util.py | 32 +++++---- 6 files changed, 96 insertions(+), 33 deletions(-) diff --git a/test/lib/ansible_test/_internal/host_profiles.py b/test/lib/ansible_test/_internal/host_profiles.py index b726e13a0e9..5a7fe755d8b 100644 --- a/test/lib/ansible_test/_internal/host_profiles.py +++ b/test/lib/ansible_test/_internal/host_profiles.py @@ -206,7 +206,7 @@ class Inventory: inventory_text += f'[{group}]\n' for host, variables in hosts.items(): - kvp = ' '.join(f'{key}="{value}"' for key, value in variables.items()) + kvp = ' '.join(f"{key}={value!r}" for key, value in variables.items()) inventory_text += f'{host} {kvp}\n' inventory_text += '\n' @@ -235,18 +235,24 @@ class HostProfile(t.Generic[THostConfig], metaclass=abc.ABCMeta): *, args: EnvironmentConfig, config: THostConfig, - targets: t.Optional[list[HostConfig]], + controller: ControllerHostProfile, ) -> None: self.args = args self.config = config - self.controller = bool(targets) - self.targets = targets or [] + self.controller = not controller # this profile is a controller whenever the `controller` arg was not provided + self.targets = args.targets if self.controller else [] # only keep targets if this profile is a controller + self.controller_profile = controller if isinstance(self, ControllerProfile) else None self.state: dict[str, t.Any] = {} """State that must be persisted across delegation.""" self.cache: dict[str, t.Any] = {} """Cache that must not be persisted across delegation.""" + @property + @abc.abstractmethod + def name(self) -> str: + """The name of the host profile.""" + def provision(self) -> None: """Provision the host before delegation.""" @@ -274,6 +280,9 @@ class HostProfile(t.Generic[THostConfig], metaclass=abc.ABCMeta): # args will be populated after the instances are restored self.cache = {} + def __str__(self) -> str: + return f'{self.__class__.__name__}: {self.name}' + class PosixProfile(HostProfile[TPosixConfig], metaclass=abc.ABCMeta): """Base class for POSIX host profiles.""" @@ -320,6 +329,11 @@ class SshTargetHostProfile(HostProfile[THostConfig], metaclass=abc.ABCMeta): class RemoteProfile(SshTargetHostProfile[TRemoteConfig], metaclass=abc.ABCMeta): """Base class for remote instance profiles.""" + @property + def name(self) -> str: + """The name of the host profile.""" + return self.config.name + @property def core_ci_state(self) -> t.Optional[dict[str, str]]: """The saved Ansible Core CI state.""" @@ -339,6 +353,8 @@ class RemoteProfile(SshTargetHostProfile[TRemoteConfig], metaclass=abc.ABCMeta): def deprovision(self) -> None: """Deprovision the host after delegation has completed.""" + super().deprovision() + if self.args.remote_terminate == TerminateMode.ALWAYS or (self.args.remote_terminate == TerminateMode.SUCCESS and self.args.success): self.delete_instance() @@ -397,6 +413,11 @@ class RemoteProfile(SshTargetHostProfile[TRemoteConfig], metaclass=abc.ABCMeta): class ControllerProfile(SshTargetHostProfile[ControllerConfig], PosixProfile[ControllerConfig]): """Host profile for the controller as a target.""" + @property + def name(self) -> str: + """The name of the host profile.""" + return self.controller_profile.name + def get_controller_target_connections(self) -> list[SshConnection]: """Return SSH connection(s) for accessing the host as a target from the controller.""" settings = SshConnectionDetail( @@ -425,6 +446,11 @@ class DockerProfile(ControllerHostProfile[DockerConfig], SshTargetHostProfile[Do command_privileged: bool expected_mounts: tuple[CGroupMount, ...] + @property + def name(self) -> str: + """The name of the host profile.""" + return self.config.name + @property def container_name(self) -> t.Optional[str]: """Return the stored container name, if any, otherwise None.""" @@ -976,6 +1002,8 @@ class DockerProfile(ControllerHostProfile[DockerConfig], SshTargetHostProfile[Do def deprovision(self) -> None: """Deprovision the host after delegation has completed.""" + super().deprovision() + container_exists = False if self.container_name: @@ -1025,10 +1053,10 @@ class DockerProfile(ControllerHostProfile[DockerConfig], SshTargetHostProfile[Do raise HostConnectionError(f'Timeout waiting for {self.config.name} container {self.container_name}.', callback) - def get_controller_target_connections(self) -> list[SshConnection]: - """Return SSH connection(s) for accessing the host as a target from the controller.""" + def get_ssh_connection_detail(self, host_type: str) -> SshConnectionDetail: + """Return SSH connection detail for the specified host type.""" containers = get_container_database(self.args) - access = containers.data[HostType.control]['__test_hosts__'][self.container_name] + access = containers.data[host_type]['__test_hosts__'][self.container_name] host = access.host_ip port = dict(access.port_map())[22] @@ -1046,7 +1074,11 @@ class DockerProfile(ControllerHostProfile[DockerConfig], SshTargetHostProfile[Do enable_rsa_sha1='centos6' in self.config.image, ) - return [SshConnection(self.args, settings)] + return settings + + def get_controller_target_connections(self) -> list[SshConnection]: + """Return SSH connection(s) for accessing the host as a target from the controller.""" + return [SshConnection(self.args, self.get_ssh_connection_detail(HostType.control))] def get_origin_controller_connection(self) -> DockerConnection: """Return a connection for accessing the host as a controller from the origin.""" @@ -1116,6 +1148,11 @@ class DockerProfile(ControllerHostProfile[DockerConfig], SshTargetHostProfile[Do class NetworkInventoryProfile(HostProfile[NetworkInventoryConfig]): """Host profile for a network inventory.""" + @property + def name(self) -> str: + """The name of the host profile.""" + return self.config.path + class NetworkRemoteProfile(RemoteProfile[NetworkRemoteConfig]): """Host profile for a network remote instance.""" @@ -1197,6 +1234,11 @@ class NetworkRemoteProfile(RemoteProfile[NetworkRemoteConfig]): class OriginProfile(ControllerHostProfile[OriginConfig]): """Host profile for origin.""" + @property + def name(self) -> str: + """The name of the host profile.""" + return 'origin' + def get_origin_controller_connection(self) -> LocalConnection: """Return a connection for accessing the host as a controller from the origin.""" return LocalConnection(self.args) @@ -1317,6 +1359,11 @@ class PosixRemoteProfile(ControllerHostProfile[PosixRemoteConfig], RemoteProfile class PosixSshProfile(SshTargetHostProfile[PosixSshConfig], PosixProfile[PosixSshConfig]): """Host profile for a POSIX SSH instance.""" + @property + def name(self) -> str: + """The name of the host profile.""" + return self.config.host + def get_controller_target_connections(self) -> list[SshConnection]: """Return SSH connection(s) for accessing the host as a target from the controller.""" settings = SshConnectionDetail( @@ -1334,6 +1381,11 @@ class PosixSshProfile(SshTargetHostProfile[PosixSshConfig], PosixProfile[PosixSs class WindowsInventoryProfile(SshTargetHostProfile[WindowsInventoryConfig]): """Host profile for a Windows inventory.""" + @property + def name(self) -> str: + """The name of the host profile.""" + return self.config.path + def get_controller_target_connections(self) -> list[SshConnection]: """Return SSH connection(s) for accessing the host as a target from the controller.""" inventory = parse_inventory(self.args, self.config.path) @@ -1436,9 +1488,9 @@ def get_config_profile_type_map() -> dict[t.Type[HostConfig], t.Type[HostProfile def create_host_profile( args: EnvironmentConfig, config: HostConfig, - controller: bool, + controller: ControllerHostProfile | None, ) -> HostProfile: """Create and return a host profile from the given host configuration.""" profile_type = get_config_profile_type_map()[type(config)] - profile = profile_type(args=args, config=config, targets=args.targets if controller else None) + profile = profile_type(args=args, config=config, controller=controller) return profile diff --git a/test/lib/ansible_test/_internal/provisioning.py b/test/lib/ansible_test/_internal/provisioning.py index a174e808014..e7ff02f2ce6 100644 --- a/test/lib/ansible_test/_internal/provisioning.py +++ b/test/lib/ansible_test/_internal/provisioning.py @@ -116,9 +116,11 @@ def prepare_profiles( else: run_pypi_proxy(args, targets_use_pypi) + controller_host_profile = t.cast(ControllerHostProfile, create_host_profile(args, args.controller, None)) + host_state = HostState( - controller_profile=t.cast(ControllerHostProfile, create_host_profile(args, args.controller, True)), - target_profiles=[create_host_profile(args, target, False) for target in args.targets], + controller_profile=controller_host_profile, + target_profiles=[create_host_profile(args, target, controller_host_profile) for target in args.targets], ) if args.prime_containers: @@ -137,7 +139,9 @@ def prepare_profiles( if not skip_setup: profile.setup() - dispatch_jobs([(profile, WrappedThread(functools.partial(provision, profile))) for profile in host_state.profiles]) + dispatch_jobs( + [(profile, WrappedThread(functools.partial(provision, profile), f'Provision: {profile}')) for profile in host_state.profiles] + ) host_state.controller_profile.configure() @@ -157,7 +161,9 @@ def prepare_profiles( if requirements: requirements(profile) - dispatch_jobs([(profile, WrappedThread(functools.partial(configure, profile))) for profile in host_state.target_profiles]) + dispatch_jobs( + [(profile, WrappedThread(functools.partial(configure, profile), f'Configure: {profile}')) for profile in host_state.target_profiles] + ) return host_state diff --git a/test/lib/ansible_test/_internal/ssh.py b/test/lib/ansible_test/_internal/ssh.py index 80fc354b578..9660b2f676d 100644 --- a/test/lib/ansible_test/_internal/ssh.py +++ b/test/lib/ansible_test/_internal/ssh.py @@ -13,7 +13,6 @@ import shlex import typing as t from .encoding import ( - to_bytes, to_text, ) @@ -223,13 +222,10 @@ def run_ssh_command( cmd_show = shlex.join(cmd) display.info('Run background command: %s' % cmd_show, verbosity=1, truncate=True) - cmd_bytes = [to_bytes(arg) for arg in cmd] - env_bytes = dict((to_bytes(k), to_bytes(v)) for k, v in env.items()) - if args.explain: process = SshProcess(None) else: - process = SshProcess(subprocess.Popen(cmd_bytes, env=env_bytes, bufsize=-1, # pylint: disable=consider-using-with + process = SshProcess(subprocess.Popen(cmd, env=env, bufsize=-1, # pylint: disable=consider-using-with stdin=subprocess.DEVNULL, stdout=subprocess.PIPE, stderr=subprocess.PIPE)) return process diff --git a/test/lib/ansible_test/_internal/thread.py b/test/lib/ansible_test/_internal/thread.py index 9b712370bb9..515d2c73daa 100644 --- a/test/lib/ansible_test/_internal/thread.py +++ b/test/lib/ansible_test/_internal/thread.py @@ -17,11 +17,12 @@ TCallable = t.TypeVar('TCallable', bound=t.Callable[..., t.Any]) class WrappedThread(threading.Thread): """Wrapper around Thread which captures results and exceptions.""" - def __init__(self, action: c.Callable[[], t.Any]) -> None: + def __init__(self, action: c.Callable[[], t.Any], name: str) -> None: super().__init__() self._result: queue.Queue[t.Any] = queue.Queue() self.action = action self.result = None + self.name = name def run(self) -> None: """ diff --git a/test/lib/ansible_test/_internal/timeout.py b/test/lib/ansible_test/_internal/timeout.py index 898d0aae0ed..15fc9ea9d4e 100644 --- a/test/lib/ansible_test/_internal/timeout.py +++ b/test/lib/ansible_test/_internal/timeout.py @@ -126,6 +126,6 @@ def configure_test_timeout(args: TestConfig) -> None: signal.signal(signal.SIGUSR1, timeout_handler) - instance = WrappedThread(functools.partial(timeout_waiter, timeout_remaining.total_seconds())) + instance = WrappedThread(functools.partial(timeout_waiter, timeout_remaining.total_seconds()), 'Timeout Watchdog') instance.daemon = True instance.start() diff --git a/test/lib/ansible_test/_internal/util.py b/test/lib/ansible_test/_internal/util.py index 778eef24cdc..1da631a0b5f 100644 --- a/test/lib/ansible_test/_internal/util.py +++ b/test/lib/ansible_test/_internal/util.py @@ -533,16 +533,23 @@ def raw_command( try: try: - cmd_bytes = [to_bytes(arg) for arg in cmd] - env_bytes = dict((to_bytes(k), to_bytes(v)) for k, v in env.items()) - process = subprocess.Popen(cmd_bytes, env=env_bytes, stdin=stdin, stdout=stdout, stderr=stderr, cwd=cwd) # pylint: disable=consider-using-with + process = subprocess.Popen(cmd, env=env, stdin=stdin, stdout=stdout, stderr=stderr, cwd=cwd) # pylint: disable=consider-using-with except FileNotFoundError as ex: raise ApplicationError('Required program "%s" not found.' % cmd[0]) from ex if communicate: data_bytes = to_optional_bytes(data) - stdout_bytes, stderr_bytes = communicate_with_process(process, data_bytes, stdout == subprocess.PIPE, stderr == subprocess.PIPE, capture=capture, - output_stream=output_stream) + + stdout_bytes, stderr_bytes = communicate_with_process( + name=cmd[0], + process=process, + stdin=data_bytes, + stdout=stdout == subprocess.PIPE, + stderr=stderr == subprocess.PIPE, + capture=capture, + output_stream=output_stream, + ) + stdout_text = to_optional_text(stdout_bytes, str_errors) or '' stderr_text = to_optional_text(stderr_bytes, str_errors) or '' else: @@ -566,6 +573,7 @@ def raw_command( def communicate_with_process( + name: str, process: subprocess.Popen, stdin: t.Optional[bytes], stdout: bool, @@ -583,16 +591,16 @@ def communicate_with_process( reader = OutputThread if stdin is not None: - threads.append(WriterThread(process.stdin, stdin)) + threads.append(WriterThread(process.stdin, stdin, name)) if stdout: - stdout_reader = reader(process.stdout, output_stream.get_buffer(sys.stdout.buffer)) + stdout_reader = reader(process.stdout, output_stream.get_buffer(sys.stdout.buffer), name) threads.append(stdout_reader) else: stdout_reader = None if stderr: - stderr_reader = reader(process.stderr, output_stream.get_buffer(sys.stderr.buffer)) + stderr_reader = reader(process.stderr, output_stream.get_buffer(sys.stderr.buffer), name) threads.append(stderr_reader) else: stderr_reader = None @@ -624,8 +632,8 @@ def communicate_with_process( class WriterThread(WrappedThread): """Thread to write data to stdin of a subprocess.""" - def __init__(self, handle: t.IO[bytes], data: bytes) -> None: - super().__init__(self._run) + def __init__(self, handle: t.IO[bytes], data: bytes, name: str) -> None: + super().__init__(self._run, f'{self.__class__.__name__}: {name}') self.handle = handle self.data = data @@ -642,8 +650,8 @@ class WriterThread(WrappedThread): class ReaderThread(WrappedThread, metaclass=abc.ABCMeta): """Thread to read stdout from a subprocess.""" - def __init__(self, handle: t.IO[bytes], buffer: t.BinaryIO) -> None: - super().__init__(self._run) + def __init__(self, handle: t.IO[bytes], buffer: t.BinaryIO, name: str) -> None: + super().__init__(self._run, f'{self.__class__.__name__}: {name}') self.handle = handle self.buffer = buffer