ansible-test - Replace TypeVar usage (#85603)

pull/85613/head
Matt Clay 4 months ago committed by GitHub
parent 5083eaffc6
commit dc5209a3fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -3,14 +3,11 @@
from __future__ import annotations
import collections.abc as c
import typing as t
from .config import (
CommonConfig,
)
TValue = t.TypeVar('TValue')
class CommonCache:
"""Common cache."""
@ -18,14 +15,14 @@ class CommonCache:
def __init__(self, args: CommonConfig) -> None:
self.args = args
def get(self, key: str, factory: c.Callable[[], TValue]) -> TValue:
def get[TValue](self, key: str, factory: c.Callable[[], TValue]) -> TValue:
"""Return the value from the cache identified by the given key, using the specified factory method if it is not found."""
if key not in self.args.cache:
self.args.cache[key] = factory()
return self.args.cache[key]
def get_with_args(self, key: str, factory: c.Callable[[CommonConfig], TValue]) -> TValue:
def get_with_args[TValue](self, key: str, factory: c.Callable[[CommonConfig], TValue]) -> TValue:
"""Return the value from the cache identified by the given key, using the specified factory method (which accepts args) if it is not found."""
if key not in self.args.cache:
self.args.cache[key] = factory(self.args)

@ -63,8 +63,6 @@ from . import (
PathChecker,
)
TValue = t.TypeVar('TValue')
def command_coverage_combine(args: CoverageCombineConfig) -> None:
"""Patch paths in coverage files and merge into a single file."""
@ -287,7 +285,7 @@ def _get_coverage_targets(args: CoverageCombineConfig, walk_func: c.Callable) ->
return sources
def _build_stub_groups(
def _build_stub_groups[TValue](
args: CoverageCombineConfig,
sources: list[tuple[str, int]],
default_stub_value: c.Callable[[list[str]], dict[str, TValue]],

@ -41,7 +41,6 @@ from ...target import (
walk_integration_targets,
IntegrationTarget,
walk_internal_targets,
TIntegrationTarget,
IntegrationTargetType,
)
@ -50,7 +49,6 @@ from ...config import (
NetworkIntegrationConfig,
PosixIntegrationConfig,
WindowsIntegrationConfig,
TIntegrationConfig,
)
from ...io import (
@ -132,8 +130,6 @@ from .coverage import (
CoverageManager,
)
THostProfile = t.TypeVar('THostProfile', bound=HostProfile)
def generate_dependency_map(integration_targets: list[IntegrationTarget]) -> dict[str, set[IntegrationTarget]]:
"""Analyze the given list of integration test targets and return a dictionary expressing target names and the targets on which they depend."""
@ -856,7 +852,7 @@ class IntegrationCache(CommonCache):
return self.get('dependency_map', lambda: generate_dependency_map(self.integration_targets))
def filter_profiles_for_target(args: IntegrationConfig, profiles: list[THostProfile], target: IntegrationTarget) -> list[THostProfile]:
def filter_profiles_for_target[T: HostProfile](args: IntegrationConfig, profiles: list[T], target: IntegrationTarget) -> list[T]:
"""Return a list of profiles after applying target filters."""
if target.target_type == IntegrationTargetType.CONTROLLER:
profile_filter = get_target_filter(args, [args.controller], True)
@ -912,7 +908,7 @@ If necessary, context can be controlled by adding entries to the "aliases" file
return exclude
def command_integration_filter(
def command_integration_filter[TIntegrationTarget: IntegrationTarget, TIntegrationConfig: IntegrationConfig](
args: TIntegrationConfig,
targets: c.Iterable[TIntegrationTarget],
) -> tuple[HostState, tuple[TIntegrationTarget, ...]]:

@ -79,10 +79,8 @@ from ...inventory import (
create_posix_inventory,
)
THostConfig = t.TypeVar('THostConfig', bound=HostConfig)
class CoverageHandler(t.Generic[THostConfig], metaclass=abc.ABCMeta):
class CoverageHandler[THostConfig: HostConfig](metaclass=abc.ABCMeta):
"""Base class for configuring hosts for integration test code coverage."""
def __init__(self, args: IntegrationConfig, host_state: HostState, inventory_path: str) -> None:

@ -40,13 +40,8 @@ from ...host_profiles import (
HostProfile,
)
THostConfig = t.TypeVar('THostConfig', bound=HostConfig)
TPosixConfig = t.TypeVar('TPosixConfig', bound=PosixConfig)
TRemoteConfig = t.TypeVar('TRemoteConfig', bound=RemoteConfig)
THostProfile = t.TypeVar('THostProfile', bound=HostProfile)
class TargetFilter(t.Generic[THostConfig], metaclass=abc.ABCMeta):
class TargetFilter[THostConfig: HostConfig](metaclass=abc.ABCMeta):
"""Base class for target filters."""
def __init__(self, args: IntegrationConfig, configs: list[THostConfig], controller: bool) -> None:
@ -92,7 +87,7 @@ class TargetFilter(t.Generic[THostConfig], metaclass=abc.ABCMeta):
exclude.update(skipped)
display.warning(f'Excluding {self.host_type} tests marked {marked} {reason}: {", ".join(skipped)}')
def filter_profiles(self, profiles: list[THostProfile], target: IntegrationTarget) -> list[THostProfile]:
def filter_profiles[THostProfile: HostProfile](self, profiles: list[THostProfile], target: IntegrationTarget) -> list[THostProfile]:
"""Filter the list of profiles, returning only those which are not skipped for the given target."""
del target
return profiles
@ -138,7 +133,7 @@ class TargetFilter(t.Generic[THostConfig], metaclass=abc.ABCMeta):
self.skip('unstable', 'which require --allow-unstable or prefixing with "unstable/"', targets, exclude, override)
class PosixTargetFilter(TargetFilter[TPosixConfig]):
class PosixTargetFilter[TPosixConfig: PosixConfig](TargetFilter[TPosixConfig]):
"""Target filter for POSIX hosts."""
def filter_targets(self, targets: list[IntegrationTarget], exclude: set[str]) -> None:
@ -169,10 +164,10 @@ class PosixSshTargetFilter(PosixTargetFilter[PosixSshConfig]):
"""Target filter for POSIX SSH hosts."""
class RemoteTargetFilter(TargetFilter[TRemoteConfig]):
class RemoteTargetFilter[TRemoteConfig: RemoteConfig](TargetFilter[TRemoteConfig]):
"""Target filter for remote Ansible Core CI managed hosts."""
def filter_profiles(self, profiles: list[THostProfile], target: IntegrationTarget) -> list[THostProfile]:
def filter_profiles[THostProfile: HostProfile](self, profiles: list[THostProfile], target: IntegrationTarget) -> list[THostProfile]:
"""Filter the list of profiles, returning only those which are not skipped for the given target."""
profiles = super().filter_profiles(profiles, target)

@ -250,10 +250,7 @@ class WindowsRemoteCompletionConfig(RemoteCompletionConfig):
connection: str = ''
TCompletionConfig = t.TypeVar('TCompletionConfig', bound=CompletionConfig)
def load_completion(name: str, completion_type: t.Type[TCompletionConfig]) -> dict[str, TCompletionConfig]:
def load_completion[TCompletionConfig: CompletionConfig](name: str, completion_type: t.Type[TCompletionConfig]) -> dict[str, TCompletionConfig]:
"""Load the named completion entries, returning them in dictionary form using the specified completion type."""
lines = read_lines_without_comments(os.path.join(ANSIBLE_TEST_DATA_ROOT, 'completion', '%s.txt' % name), remove_blank_lines=True)
@ -283,7 +280,7 @@ def parse_completion_entry(value: str) -> tuple[str, dict[str, str]]:
return name, data
def filter_completion(
def filter_completion[TCompletionConfig: CompletionConfig](
completion: dict[str, TCompletionConfig],
controller_only: bool = False,
include_defaults: bool = False,

@ -38,8 +38,6 @@ from .host_configs import (
VirtualPythonConfig,
)
THostConfig = t.TypeVar('THostConfig', bound=HostConfig)
class TerminateMode(enum.Enum):
"""When to terminate instances."""
@ -166,7 +164,7 @@ class EnvironmentConfig(CommonConfig):
"""Host configuration for the targets."""
return self.host_settings.targets
def only_target(self, target_type: t.Type[THostConfig]) -> THostConfig:
def only_target[THostConfig: HostConfig](self, target_type: t.Type[THostConfig]) -> THostConfig:
"""
Return the host configuration for the target.
Requires that there is exactly one target of the specified type.
@ -183,7 +181,7 @@ class EnvironmentConfig(CommonConfig):
return target
def only_targets(self, target_type: t.Type[THostConfig]) -> list[THostConfig]:
def only_targets[THostConfig: HostConfig](self, target_type: t.Type[THostConfig]) -> list[THostConfig]:
"""
Return a list of target host configurations.
Requires that there are one or more targets, all the specified type.
@ -318,9 +316,6 @@ class IntegrationConfig(TestConfig):
return ansible_config_path
TIntegrationConfig = t.TypeVar('TIntegrationConfig', bound=IntegrationConfig)
class PosixIntegrationConfig(IntegrationConfig):
"""Configuration for the posix integration command."""

@ -144,11 +144,6 @@ from .debugging import (
DebuggerSettings,
)
TControllerHostConfig = t.TypeVar('TControllerHostConfig', bound=ControllerHostConfig)
THostConfig = t.TypeVar('THostConfig', bound=HostConfig)
TPosixConfig = t.TypeVar('TPosixConfig', bound=PosixConfig)
TRemoteConfig = t.TypeVar('TRemoteConfig', bound=RemoteConfig)
class ControlGroupError(ApplicationError):
"""Raised when the container host does not have the necessary cgroup support to run a container."""
@ -239,7 +234,7 @@ class Inventory:
display.info(f'>>> Inventory\n{inventory_text}', verbosity=3)
class HostProfile(t.Generic[THostConfig], metaclass=abc.ABCMeta):
class HostProfile[THostConfig: HostConfig](metaclass=abc.ABCMeta):
"""Base class for host profiles."""
def __init__(
@ -296,7 +291,7 @@ class HostProfile(t.Generic[THostConfig], metaclass=abc.ABCMeta):
return f'{self.__class__.__name__}: {self.name}'
class DebuggableProfile(HostProfile[THostConfig], DebuggerProfile, metaclass=abc.ABCMeta):
class DebuggableProfile[THostConfig: HostConfig](HostProfile[THostConfig], DebuggerProfile, metaclass=abc.ABCMeta):
"""Base class for profiles remote debugging."""
__DEBUGGING_PORT_KEY = 'debugging_port'
@ -462,7 +457,7 @@ class DebuggableProfile(HostProfile[THostConfig], DebuggerProfile, metaclass=abc
)
class PosixProfile(HostProfile[TPosixConfig], metaclass=abc.ABCMeta):
class PosixProfile[TPosixConfig: PosixConfig](HostProfile[TPosixConfig], metaclass=abc.ABCMeta):
"""Base class for POSIX host profiles."""
@property
@ -484,7 +479,7 @@ class PosixProfile(HostProfile[TPosixConfig], metaclass=abc.ABCMeta):
return python
class ControllerHostProfile(PosixProfile[TControllerHostConfig], DebuggableProfile[TControllerHostConfig], metaclass=abc.ABCMeta):
class ControllerHostProfile[T: ControllerHostConfig](PosixProfile[T], DebuggableProfile[T], metaclass=abc.ABCMeta):
"""Base class for profiles usable as a controller."""
@abc.abstractmethod
@ -496,7 +491,7 @@ class ControllerHostProfile(PosixProfile[TControllerHostConfig], DebuggableProfi
"""Return the working directory for the host."""
class SshTargetHostProfile(HostProfile[THostConfig], metaclass=abc.ABCMeta):
class SshTargetHostProfile[THostConfig: HostConfig](HostProfile[THostConfig], metaclass=abc.ABCMeta):
"""Base class for profiles offering SSH connectivity."""
@abc.abstractmethod
@ -504,7 +499,7 @@ class SshTargetHostProfile(HostProfile[THostConfig], metaclass=abc.ABCMeta):
"""Return SSH connection(s) for accessing the host as a target from the controller."""
class RemoteProfile(SshTargetHostProfile[TRemoteConfig], metaclass=abc.ABCMeta):
class RemoteProfile[TRemoteConfig: RemoteConfig](SshTargetHostProfile[TRemoteConfig], metaclass=abc.ABCMeta):
"""Base class for remote instance profiles."""
@property

@ -12,12 +12,12 @@ from ..util import (
)
def get_path_provider_classes(provider_type: t.Type[TPathProvider]) -> list[t.Type[TPathProvider]]:
def get_path_provider_classes[TPathProvider: PathProvider](provider_type: t.Type[TPathProvider]) -> list[t.Type[TPathProvider]]:
"""Return a list of path provider classes of the given type."""
return sorted(get_subclasses(provider_type), key=lambda subclass: (subclass.priority, subclass.__name__))
def find_path_provider(
def find_path_provider[TPathProvider: PathProvider](
provider_type: t.Type[TPathProvider],
provider_classes: list[t.Type[TPathProvider]],
path: str,
@ -71,6 +71,3 @@ class PathProvider(metaclass=abc.ABCMeta):
@abc.abstractmethod
def is_content_root(path: str) -> bool:
"""Return True if the given path is a content root for this provider."""
TPathProvider = t.TypeVar('TPathProvider', bound=PathProvider)

@ -48,9 +48,6 @@ from .pypi_proxy import (
run_pypi_proxy,
)
THostProfile = t.TypeVar('THostProfile', bound=HostProfile)
TEnvironmentConfig = t.TypeVar('TEnvironmentConfig', bound=EnvironmentConfig)
class PrimeContainers(ApplicationError):
"""Exception raised to end execution early after priming containers."""
@ -91,7 +88,7 @@ class HostState:
return list(itertools.chain.from_iterable([target.get_controller_target_connections() for
target in self.target_profiles if isinstance(target, SshTargetHostProfile)]))
def targets(self, profile_type: t.Type[THostProfile]) -> list[THostProfile]:
def targets[THostProfile: HostProfile](self, profile_type: t.Type[THostProfile]) -> list[THostProfile]:
"""The list of target(s), verified to be of the specified type."""
if not self.target_profiles:
raise Exception('No target profiles found.')
@ -101,7 +98,7 @@ class HostState:
return t.cast(list[THostProfile], self.target_profiles)
def prepare_profiles(
def prepare_profiles[TEnvironmentConfig: EnvironmentConfig](
args: TEnvironmentConfig,
targets_use_pypi: bool = False,
skip_setup: bool = False,

@ -65,7 +65,7 @@ def walk_completion_targets(targets: c.Iterable[CompletionTarget], prefix: str,
return tuple(sorted(matches))
def walk_internal_targets(
def walk_internal_targets[TCompletionTarget: CompletionTarget](
targets: c.Iterable[TCompletionTarget],
includes: t.Optional[list[str]] = None,
excludes: t.Optional[list[str]] = None,
@ -87,7 +87,7 @@ def walk_internal_targets(
return tuple(sorted(internal_targets, key=lambda sort_target: sort_target.name))
def filter_targets(
def filter_targets[TCompletionTarget: CompletionTarget](
targets: c.Iterable[TCompletionTarget],
patterns: list[str],
include: bool = True,
@ -711,7 +711,3 @@ class TargetPatternsNotMatched(ApplicationError):
message = 'Target pattern not matched: %s' % self.patterns[0]
super().__init__(message)
TCompletionTarget = t.TypeVar('TCompletionTarget', bound=CompletionTarget)
TIntegrationTarget = t.TypeVar('TIntegrationTarget', bound=IntegrationTarget)

@ -11,9 +11,6 @@ import queue
import typing as t
TCallable = t.TypeVar('TCallable', bound=t.Callable[..., t.Any])
class WrappedThread(threading.Thread):
"""Wrapper around Thread which captures results and exceptions."""
@ -50,7 +47,7 @@ class WrappedThread(threading.Thread):
return result
def mutex(func: TCallable) -> TCallable:
def mutex[TCallable: t.Callable[..., t.Any]](func: TCallable) -> TCallable:
"""Enforce exclusive access on a decorated function."""
lock = threading.Lock()

@ -57,11 +57,6 @@ from .constants import (
SUPPORTED_PYTHON_VERSIONS,
)
C = t.TypeVar('C')
TBase = t.TypeVar('TBase')
TKey = t.TypeVar('TKey')
TValue = t.TypeVar('TValue')
PYTHON_PATHS: dict[str, str] = {}
COVERAGE_CONFIG_NAME = 'coveragerc'
@ -180,7 +175,7 @@ def is_valid_identifier(value: str) -> bool:
return value.isidentifier() and not keyword.iskeyword(value)
def cache(func: c.Callable[[], TValue]) -> c.Callable[[], TValue]:
def cache[TValue](func: c.Callable[[], TValue]) -> c.Callable[[], TValue]:
"""Enforce exclusive access on a decorated function and cache the result."""
storage: dict[None, TValue] = {}
sentinel = object()
@ -313,7 +308,7 @@ def read_lines_without_comments(path: str, remove_blank_lines: bool = False, opt
return lines
def exclude_none_values(data: dict[TKey, t.Optional[TValue]]) -> dict[TKey, TValue]:
def exclude_none_values[TKey, TValue](data: dict[TKey, t.Optional[TValue]]) -> dict[TKey, TValue]:
"""Return the provided dictionary with any None values excluded."""
return dict((key, value) for key, value in data.items() if value is not None)
@ -1058,7 +1053,7 @@ def format_command_output(stdout: str | None, stderr: str | None) -> str:
return message
def retry(func: t.Callable[..., TValue], ex_type: t.Type[BaseException] = SubprocessError, sleep: int = 10, attempts: int = 10, warn: bool = True) -> TValue:
def retry[T](func: t.Callable[..., T], ex_type: t.Type[BaseException] = SubprocessError, sleep: int = 10, attempts: int = 10, warn: bool = True) -> T:
"""Retry the specified function on failure."""
for dummy in range(1, attempts):
try:
@ -1091,7 +1086,7 @@ def parse_to_list_of_dict(pattern: str, value: str) -> list[dict[str, str]]:
return matched
def get_subclasses(class_type: t.Type[C]) -> list[t.Type[C]]:
def get_subclasses[C](class_type: t.Type[C]) -> list[t.Type[C]]:
"""Returns a list of types that are concrete subclasses of the given type."""
subclasses: set[t.Type[C]] = set()
queue: list[t.Type[C]] = [class_type]
@ -1167,7 +1162,7 @@ def import_plugins(directory: str, root: t.Optional[str] = None) -> None:
load_module(module_path, name)
def load_plugins(base_type: t.Type[C], database: dict[str, t.Type[C]]) -> None:
def load_plugins[C](base_type: t.Type[C], database: dict[str, t.Type[C]]) -> None:
"""
Load plugins of the specified type and track them in the specified database.
Only plugins which have already been imported will be loaded.
@ -1194,19 +1189,19 @@ def sanitize_host_name(name: str) -> str:
return re.sub('[^A-Za-z0-9]+', '-', name)[:63].strip('-')
def get_generic_type(base_type: t.Type, generic_base_type: t.Type[TValue]) -> t.Optional[t.Type[TValue]]:
def get_generic_type[TValue](base_type: t.Type, generic_base_type: t.Type[TValue]) -> t.Optional[t.Type[TValue]]:
"""Return the generic type arg derived from the generic_base_type type that is associated with the base_type type, if any, otherwise return None."""
# noinspection PyUnresolvedReferences
type_arg = t.get_args(base_type.__orig_bases__[0])[0]
return None if isinstance(type_arg, generic_base_type) else type_arg
def get_type_associations(base_type: t.Type[TBase], generic_base_type: t.Type[TValue]) -> list[tuple[t.Type[TValue], t.Type[TBase]]]:
def get_type_associations[TBase, TValue](base_type: t.Type[TBase], generic_base_type: t.Type[TValue]) -> list[tuple[t.Type[TValue], t.Type[TBase]]]:
"""Create and return a list of tuples associating generic_base_type derived types with a corresponding base_type derived type."""
return [item for item in [(get_generic_type(sc_type, generic_base_type), sc_type) for sc_type in get_subclasses(base_type)] if item[1]]
def get_type_map(base_type: t.Type[TBase], generic_base_type: t.Type[TValue]) -> dict[t.Type[TValue], t.Type[TBase]]:
def get_type_map[TBase, TValue](base_type: t.Type[TBase], generic_base_type: t.Type[TValue]) -> dict[t.Type[TValue], t.Type[TBase]]:
"""Create and return a mapping of generic_base_type derived types to base_type derived types."""
return {item[0]: item[1] for item in get_type_associations(base_type, generic_base_type)}
@ -1227,7 +1222,7 @@ def verify_sys_executable(path: str) -> t.Optional[str]:
return expected_executable
def type_guard(sequence: c.Sequence[t.Any], guard_type: t.Type[C]) -> t.TypeGuard[c.Sequence[C]]:
def type_guard[C](sequence: c.Sequence[t.Any], guard_type: t.Type[C]) -> t.TypeGuard[c.Sequence[C]]:
"""
Raises an exception if any item in the given sequence does not match the specified guard type.
Use with assert so that type checkers are aware of the type guard.

@ -6,6 +6,7 @@
from __future__ import annotations
import importlib
import os
import sys
@ -29,8 +30,9 @@ def main(args=None):
if any(not os.get_blocking(handle.fileno()) for handle in (sys.stdin, sys.stdout, sys.stderr)):
raise SystemExit('Standard input, output and error file handles must be blocking to run ansible-test.')
# noinspection PyProtectedMember
from ansible_test._internal import main as cli_main
# avoid using import to hide it from mypy
internal = importlib.import_module('ansible_test._internal')
cli_main = getattr(internal, 'main')
cli_main(args)

Loading…
Cancel
Save