ansible-test - Replace TypeVar usage (#85603)

pull/85613/head
Matt Clay 4 months ago committed by GitHub
parent 5083eaffc6
commit dc5209a3fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -3,14 +3,11 @@
from __future__ import annotations from __future__ import annotations
import collections.abc as c import collections.abc as c
import typing as t
from .config import ( from .config import (
CommonConfig, CommonConfig,
) )
TValue = t.TypeVar('TValue')
class CommonCache: class CommonCache:
"""Common cache.""" """Common cache."""
@ -18,14 +15,14 @@ class CommonCache:
def __init__(self, args: CommonConfig) -> None: def __init__(self, args: CommonConfig) -> None:
self.args = args self.args = args
def get(self, key: str, factory: c.Callable[[], TValue]) -> TValue: def get[TValue](self, key: str, factory: c.Callable[[], TValue]) -> TValue:
"""Return the value from the cache identified by the given key, using the specified factory method if it is not found.""" """Return the value from the cache identified by the given key, using the specified factory method if it is not found."""
if key not in self.args.cache: if key not in self.args.cache:
self.args.cache[key] = factory() self.args.cache[key] = factory()
return self.args.cache[key] return self.args.cache[key]
def get_with_args(self, key: str, factory: c.Callable[[CommonConfig], TValue]) -> TValue: def get_with_args[TValue](self, key: str, factory: c.Callable[[CommonConfig], TValue]) -> TValue:
"""Return the value from the cache identified by the given key, using the specified factory method (which accepts args) if it is not found.""" """Return the value from the cache identified by the given key, using the specified factory method (which accepts args) if it is not found."""
if key not in self.args.cache: if key not in self.args.cache:
self.args.cache[key] = factory(self.args) self.args.cache[key] = factory(self.args)

@ -63,8 +63,6 @@ from . import (
PathChecker, PathChecker,
) )
TValue = t.TypeVar('TValue')
def command_coverage_combine(args: CoverageCombineConfig) -> None: def command_coverage_combine(args: CoverageCombineConfig) -> None:
"""Patch paths in coverage files and merge into a single file.""" """Patch paths in coverage files and merge into a single file."""
@ -287,7 +285,7 @@ def _get_coverage_targets(args: CoverageCombineConfig, walk_func: c.Callable) ->
return sources return sources
def _build_stub_groups( def _build_stub_groups[TValue](
args: CoverageCombineConfig, args: CoverageCombineConfig,
sources: list[tuple[str, int]], sources: list[tuple[str, int]],
default_stub_value: c.Callable[[list[str]], dict[str, TValue]], default_stub_value: c.Callable[[list[str]], dict[str, TValue]],

@ -41,7 +41,6 @@ from ...target import (
walk_integration_targets, walk_integration_targets,
IntegrationTarget, IntegrationTarget,
walk_internal_targets, walk_internal_targets,
TIntegrationTarget,
IntegrationTargetType, IntegrationTargetType,
) )
@ -50,7 +49,6 @@ from ...config import (
NetworkIntegrationConfig, NetworkIntegrationConfig,
PosixIntegrationConfig, PosixIntegrationConfig,
WindowsIntegrationConfig, WindowsIntegrationConfig,
TIntegrationConfig,
) )
from ...io import ( from ...io import (
@ -132,8 +130,6 @@ from .coverage import (
CoverageManager, CoverageManager,
) )
THostProfile = t.TypeVar('THostProfile', bound=HostProfile)
def generate_dependency_map(integration_targets: list[IntegrationTarget]) -> dict[str, set[IntegrationTarget]]: def generate_dependency_map(integration_targets: list[IntegrationTarget]) -> dict[str, set[IntegrationTarget]]:
"""Analyze the given list of integration test targets and return a dictionary expressing target names and the targets on which they depend.""" """Analyze the given list of integration test targets and return a dictionary expressing target names and the targets on which they depend."""
@ -856,7 +852,7 @@ class IntegrationCache(CommonCache):
return self.get('dependency_map', lambda: generate_dependency_map(self.integration_targets)) return self.get('dependency_map', lambda: generate_dependency_map(self.integration_targets))
def filter_profiles_for_target(args: IntegrationConfig, profiles: list[THostProfile], target: IntegrationTarget) -> list[THostProfile]: def filter_profiles_for_target[T: HostProfile](args: IntegrationConfig, profiles: list[T], target: IntegrationTarget) -> list[T]:
"""Return a list of profiles after applying target filters.""" """Return a list of profiles after applying target filters."""
if target.target_type == IntegrationTargetType.CONTROLLER: if target.target_type == IntegrationTargetType.CONTROLLER:
profile_filter = get_target_filter(args, [args.controller], True) profile_filter = get_target_filter(args, [args.controller], True)
@ -912,7 +908,7 @@ If necessary, context can be controlled by adding entries to the "aliases" file
return exclude return exclude
def command_integration_filter( def command_integration_filter[TIntegrationTarget: IntegrationTarget, TIntegrationConfig: IntegrationConfig](
args: TIntegrationConfig, args: TIntegrationConfig,
targets: c.Iterable[TIntegrationTarget], targets: c.Iterable[TIntegrationTarget],
) -> tuple[HostState, tuple[TIntegrationTarget, ...]]: ) -> tuple[HostState, tuple[TIntegrationTarget, ...]]:

@ -79,10 +79,8 @@ from ...inventory import (
create_posix_inventory, create_posix_inventory,
) )
THostConfig = t.TypeVar('THostConfig', bound=HostConfig)
class CoverageHandler[THostConfig: HostConfig](metaclass=abc.ABCMeta):
class CoverageHandler(t.Generic[THostConfig], metaclass=abc.ABCMeta):
"""Base class for configuring hosts for integration test code coverage.""" """Base class for configuring hosts for integration test code coverage."""
def __init__(self, args: IntegrationConfig, host_state: HostState, inventory_path: str) -> None: def __init__(self, args: IntegrationConfig, host_state: HostState, inventory_path: str) -> None:

@ -40,13 +40,8 @@ from ...host_profiles import (
HostProfile, HostProfile,
) )
THostConfig = t.TypeVar('THostConfig', bound=HostConfig)
TPosixConfig = t.TypeVar('TPosixConfig', bound=PosixConfig)
TRemoteConfig = t.TypeVar('TRemoteConfig', bound=RemoteConfig)
THostProfile = t.TypeVar('THostProfile', bound=HostProfile)
class TargetFilter[THostConfig: HostConfig](metaclass=abc.ABCMeta):
class TargetFilter(t.Generic[THostConfig], metaclass=abc.ABCMeta):
"""Base class for target filters.""" """Base class for target filters."""
def __init__(self, args: IntegrationConfig, configs: list[THostConfig], controller: bool) -> None: def __init__(self, args: IntegrationConfig, configs: list[THostConfig], controller: bool) -> None:
@ -92,7 +87,7 @@ class TargetFilter(t.Generic[THostConfig], metaclass=abc.ABCMeta):
exclude.update(skipped) exclude.update(skipped)
display.warning(f'Excluding {self.host_type} tests marked {marked} {reason}: {", ".join(skipped)}') display.warning(f'Excluding {self.host_type} tests marked {marked} {reason}: {", ".join(skipped)}')
def filter_profiles(self, profiles: list[THostProfile], target: IntegrationTarget) -> list[THostProfile]: def filter_profiles[THostProfile: HostProfile](self, profiles: list[THostProfile], target: IntegrationTarget) -> list[THostProfile]:
"""Filter the list of profiles, returning only those which are not skipped for the given target.""" """Filter the list of profiles, returning only those which are not skipped for the given target."""
del target del target
return profiles return profiles
@ -138,7 +133,7 @@ class TargetFilter(t.Generic[THostConfig], metaclass=abc.ABCMeta):
self.skip('unstable', 'which require --allow-unstable or prefixing with "unstable/"', targets, exclude, override) self.skip('unstable', 'which require --allow-unstable or prefixing with "unstable/"', targets, exclude, override)
class PosixTargetFilter(TargetFilter[TPosixConfig]): class PosixTargetFilter[TPosixConfig: PosixConfig](TargetFilter[TPosixConfig]):
"""Target filter for POSIX hosts.""" """Target filter for POSIX hosts."""
def filter_targets(self, targets: list[IntegrationTarget], exclude: set[str]) -> None: def filter_targets(self, targets: list[IntegrationTarget], exclude: set[str]) -> None:
@ -169,10 +164,10 @@ class PosixSshTargetFilter(PosixTargetFilter[PosixSshConfig]):
"""Target filter for POSIX SSH hosts.""" """Target filter for POSIX SSH hosts."""
class RemoteTargetFilter(TargetFilter[TRemoteConfig]): class RemoteTargetFilter[TRemoteConfig: RemoteConfig](TargetFilter[TRemoteConfig]):
"""Target filter for remote Ansible Core CI managed hosts.""" """Target filter for remote Ansible Core CI managed hosts."""
def filter_profiles(self, profiles: list[THostProfile], target: IntegrationTarget) -> list[THostProfile]: def filter_profiles[THostProfile: HostProfile](self, profiles: list[THostProfile], target: IntegrationTarget) -> list[THostProfile]:
"""Filter the list of profiles, returning only those which are not skipped for the given target.""" """Filter the list of profiles, returning only those which are not skipped for the given target."""
profiles = super().filter_profiles(profiles, target) profiles = super().filter_profiles(profiles, target)

@ -250,10 +250,7 @@ class WindowsRemoteCompletionConfig(RemoteCompletionConfig):
connection: str = '' connection: str = ''
TCompletionConfig = t.TypeVar('TCompletionConfig', bound=CompletionConfig) def load_completion[TCompletionConfig: CompletionConfig](name: str, completion_type: t.Type[TCompletionConfig]) -> dict[str, TCompletionConfig]:
def load_completion(name: str, completion_type: t.Type[TCompletionConfig]) -> dict[str, TCompletionConfig]:
"""Load the named completion entries, returning them in dictionary form using the specified completion type.""" """Load the named completion entries, returning them in dictionary form using the specified completion type."""
lines = read_lines_without_comments(os.path.join(ANSIBLE_TEST_DATA_ROOT, 'completion', '%s.txt' % name), remove_blank_lines=True) lines = read_lines_without_comments(os.path.join(ANSIBLE_TEST_DATA_ROOT, 'completion', '%s.txt' % name), remove_blank_lines=True)
@ -283,7 +280,7 @@ def parse_completion_entry(value: str) -> tuple[str, dict[str, str]]:
return name, data return name, data
def filter_completion( def filter_completion[TCompletionConfig: CompletionConfig](
completion: dict[str, TCompletionConfig], completion: dict[str, TCompletionConfig],
controller_only: bool = False, controller_only: bool = False,
include_defaults: bool = False, include_defaults: bool = False,

@ -38,8 +38,6 @@ from .host_configs import (
VirtualPythonConfig, VirtualPythonConfig,
) )
THostConfig = t.TypeVar('THostConfig', bound=HostConfig)
class TerminateMode(enum.Enum): class TerminateMode(enum.Enum):
"""When to terminate instances.""" """When to terminate instances."""
@ -166,7 +164,7 @@ class EnvironmentConfig(CommonConfig):
"""Host configuration for the targets.""" """Host configuration for the targets."""
return self.host_settings.targets return self.host_settings.targets
def only_target(self, target_type: t.Type[THostConfig]) -> THostConfig: def only_target[THostConfig: HostConfig](self, target_type: t.Type[THostConfig]) -> THostConfig:
""" """
Return the host configuration for the target. Return the host configuration for the target.
Requires that there is exactly one target of the specified type. Requires that there is exactly one target of the specified type.
@ -183,7 +181,7 @@ class EnvironmentConfig(CommonConfig):
return target return target
def only_targets(self, target_type: t.Type[THostConfig]) -> list[THostConfig]: def only_targets[THostConfig: HostConfig](self, target_type: t.Type[THostConfig]) -> list[THostConfig]:
""" """
Return a list of target host configurations. Return a list of target host configurations.
Requires that there are one or more targets, all the specified type. Requires that there are one or more targets, all the specified type.
@ -318,9 +316,6 @@ class IntegrationConfig(TestConfig):
return ansible_config_path return ansible_config_path
TIntegrationConfig = t.TypeVar('TIntegrationConfig', bound=IntegrationConfig)
class PosixIntegrationConfig(IntegrationConfig): class PosixIntegrationConfig(IntegrationConfig):
"""Configuration for the posix integration command.""" """Configuration for the posix integration command."""

@ -144,11 +144,6 @@ from .debugging import (
DebuggerSettings, DebuggerSettings,
) )
TControllerHostConfig = t.TypeVar('TControllerHostConfig', bound=ControllerHostConfig)
THostConfig = t.TypeVar('THostConfig', bound=HostConfig)
TPosixConfig = t.TypeVar('TPosixConfig', bound=PosixConfig)
TRemoteConfig = t.TypeVar('TRemoteConfig', bound=RemoteConfig)
class ControlGroupError(ApplicationError): class ControlGroupError(ApplicationError):
"""Raised when the container host does not have the necessary cgroup support to run a container.""" """Raised when the container host does not have the necessary cgroup support to run a container."""
@ -239,7 +234,7 @@ class Inventory:
display.info(f'>>> Inventory\n{inventory_text}', verbosity=3) display.info(f'>>> Inventory\n{inventory_text}', verbosity=3)
class HostProfile(t.Generic[THostConfig], metaclass=abc.ABCMeta): class HostProfile[THostConfig: HostConfig](metaclass=abc.ABCMeta):
"""Base class for host profiles.""" """Base class for host profiles."""
def __init__( def __init__(
@ -296,7 +291,7 @@ class HostProfile(t.Generic[THostConfig], metaclass=abc.ABCMeta):
return f'{self.__class__.__name__}: {self.name}' return f'{self.__class__.__name__}: {self.name}'
class DebuggableProfile(HostProfile[THostConfig], DebuggerProfile, metaclass=abc.ABCMeta): class DebuggableProfile[THostConfig: HostConfig](HostProfile[THostConfig], DebuggerProfile, metaclass=abc.ABCMeta):
"""Base class for profiles remote debugging.""" """Base class for profiles remote debugging."""
__DEBUGGING_PORT_KEY = 'debugging_port' __DEBUGGING_PORT_KEY = 'debugging_port'
@ -462,7 +457,7 @@ class DebuggableProfile(HostProfile[THostConfig], DebuggerProfile, metaclass=abc
) )
class PosixProfile(HostProfile[TPosixConfig], metaclass=abc.ABCMeta): class PosixProfile[TPosixConfig: PosixConfig](HostProfile[TPosixConfig], metaclass=abc.ABCMeta):
"""Base class for POSIX host profiles.""" """Base class for POSIX host profiles."""
@property @property
@ -484,7 +479,7 @@ class PosixProfile(HostProfile[TPosixConfig], metaclass=abc.ABCMeta):
return python return python
class ControllerHostProfile(PosixProfile[TControllerHostConfig], DebuggableProfile[TControllerHostConfig], metaclass=abc.ABCMeta): class ControllerHostProfile[T: ControllerHostConfig](PosixProfile[T], DebuggableProfile[T], metaclass=abc.ABCMeta):
"""Base class for profiles usable as a controller.""" """Base class for profiles usable as a controller."""
@abc.abstractmethod @abc.abstractmethod
@ -496,7 +491,7 @@ class ControllerHostProfile(PosixProfile[TControllerHostConfig], DebuggableProfi
"""Return the working directory for the host.""" """Return the working directory for the host."""
class SshTargetHostProfile(HostProfile[THostConfig], metaclass=abc.ABCMeta): class SshTargetHostProfile[THostConfig: HostConfig](HostProfile[THostConfig], metaclass=abc.ABCMeta):
"""Base class for profiles offering SSH connectivity.""" """Base class for profiles offering SSH connectivity."""
@abc.abstractmethod @abc.abstractmethod
@ -504,7 +499,7 @@ class SshTargetHostProfile(HostProfile[THostConfig], metaclass=abc.ABCMeta):
"""Return SSH connection(s) for accessing the host as a target from the controller.""" """Return SSH connection(s) for accessing the host as a target from the controller."""
class RemoteProfile(SshTargetHostProfile[TRemoteConfig], metaclass=abc.ABCMeta): class RemoteProfile[TRemoteConfig: RemoteConfig](SshTargetHostProfile[TRemoteConfig], metaclass=abc.ABCMeta):
"""Base class for remote instance profiles.""" """Base class for remote instance profiles."""
@property @property

@ -12,12 +12,12 @@ from ..util import (
) )
def get_path_provider_classes(provider_type: t.Type[TPathProvider]) -> list[t.Type[TPathProvider]]: def get_path_provider_classes[TPathProvider: PathProvider](provider_type: t.Type[TPathProvider]) -> list[t.Type[TPathProvider]]:
"""Return a list of path provider classes of the given type.""" """Return a list of path provider classes of the given type."""
return sorted(get_subclasses(provider_type), key=lambda subclass: (subclass.priority, subclass.__name__)) return sorted(get_subclasses(provider_type), key=lambda subclass: (subclass.priority, subclass.__name__))
def find_path_provider( def find_path_provider[TPathProvider: PathProvider](
provider_type: t.Type[TPathProvider], provider_type: t.Type[TPathProvider],
provider_classes: list[t.Type[TPathProvider]], provider_classes: list[t.Type[TPathProvider]],
path: str, path: str,
@ -71,6 +71,3 @@ class PathProvider(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
def is_content_root(path: str) -> bool: def is_content_root(path: str) -> bool:
"""Return True if the given path is a content root for this provider.""" """Return True if the given path is a content root for this provider."""
TPathProvider = t.TypeVar('TPathProvider', bound=PathProvider)

@ -48,9 +48,6 @@ from .pypi_proxy import (
run_pypi_proxy, run_pypi_proxy,
) )
THostProfile = t.TypeVar('THostProfile', bound=HostProfile)
TEnvironmentConfig = t.TypeVar('TEnvironmentConfig', bound=EnvironmentConfig)
class PrimeContainers(ApplicationError): class PrimeContainers(ApplicationError):
"""Exception raised to end execution early after priming containers.""" """Exception raised to end execution early after priming containers."""
@ -91,7 +88,7 @@ class HostState:
return list(itertools.chain.from_iterable([target.get_controller_target_connections() for return list(itertools.chain.from_iterable([target.get_controller_target_connections() for
target in self.target_profiles if isinstance(target, SshTargetHostProfile)])) target in self.target_profiles if isinstance(target, SshTargetHostProfile)]))
def targets(self, profile_type: t.Type[THostProfile]) -> list[THostProfile]: def targets[THostProfile: HostProfile](self, profile_type: t.Type[THostProfile]) -> list[THostProfile]:
"""The list of target(s), verified to be of the specified type.""" """The list of target(s), verified to be of the specified type."""
if not self.target_profiles: if not self.target_profiles:
raise Exception('No target profiles found.') raise Exception('No target profiles found.')
@ -101,7 +98,7 @@ class HostState:
return t.cast(list[THostProfile], self.target_profiles) return t.cast(list[THostProfile], self.target_profiles)
def prepare_profiles( def prepare_profiles[TEnvironmentConfig: EnvironmentConfig](
args: TEnvironmentConfig, args: TEnvironmentConfig,
targets_use_pypi: bool = False, targets_use_pypi: bool = False,
skip_setup: bool = False, skip_setup: bool = False,

@ -65,7 +65,7 @@ def walk_completion_targets(targets: c.Iterable[CompletionTarget], prefix: str,
return tuple(sorted(matches)) return tuple(sorted(matches))
def walk_internal_targets( def walk_internal_targets[TCompletionTarget: CompletionTarget](
targets: c.Iterable[TCompletionTarget], targets: c.Iterable[TCompletionTarget],
includes: t.Optional[list[str]] = None, includes: t.Optional[list[str]] = None,
excludes: t.Optional[list[str]] = None, excludes: t.Optional[list[str]] = None,
@ -87,7 +87,7 @@ def walk_internal_targets(
return tuple(sorted(internal_targets, key=lambda sort_target: sort_target.name)) return tuple(sorted(internal_targets, key=lambda sort_target: sort_target.name))
def filter_targets( def filter_targets[TCompletionTarget: CompletionTarget](
targets: c.Iterable[TCompletionTarget], targets: c.Iterable[TCompletionTarget],
patterns: list[str], patterns: list[str],
include: bool = True, include: bool = True,
@ -711,7 +711,3 @@ class TargetPatternsNotMatched(ApplicationError):
message = 'Target pattern not matched: %s' % self.patterns[0] message = 'Target pattern not matched: %s' % self.patterns[0]
super().__init__(message) super().__init__(message)
TCompletionTarget = t.TypeVar('TCompletionTarget', bound=CompletionTarget)
TIntegrationTarget = t.TypeVar('TIntegrationTarget', bound=IntegrationTarget)

@ -11,9 +11,6 @@ import queue
import typing as t import typing as t
TCallable = t.TypeVar('TCallable', bound=t.Callable[..., t.Any])
class WrappedThread(threading.Thread): class WrappedThread(threading.Thread):
"""Wrapper around Thread which captures results and exceptions.""" """Wrapper around Thread which captures results and exceptions."""
@ -50,7 +47,7 @@ class WrappedThread(threading.Thread):
return result return result
def mutex(func: TCallable) -> TCallable: def mutex[TCallable: t.Callable[..., t.Any]](func: TCallable) -> TCallable:
"""Enforce exclusive access on a decorated function.""" """Enforce exclusive access on a decorated function."""
lock = threading.Lock() lock = threading.Lock()

@ -57,11 +57,6 @@ from .constants import (
SUPPORTED_PYTHON_VERSIONS, SUPPORTED_PYTHON_VERSIONS,
) )
C = t.TypeVar('C')
TBase = t.TypeVar('TBase')
TKey = t.TypeVar('TKey')
TValue = t.TypeVar('TValue')
PYTHON_PATHS: dict[str, str] = {} PYTHON_PATHS: dict[str, str] = {}
COVERAGE_CONFIG_NAME = 'coveragerc' COVERAGE_CONFIG_NAME = 'coveragerc'
@ -180,7 +175,7 @@ def is_valid_identifier(value: str) -> bool:
return value.isidentifier() and not keyword.iskeyword(value) return value.isidentifier() and not keyword.iskeyword(value)
def cache(func: c.Callable[[], TValue]) -> c.Callable[[], TValue]: def cache[TValue](func: c.Callable[[], TValue]) -> c.Callable[[], TValue]:
"""Enforce exclusive access on a decorated function and cache the result.""" """Enforce exclusive access on a decorated function and cache the result."""
storage: dict[None, TValue] = {} storage: dict[None, TValue] = {}
sentinel = object() sentinel = object()
@ -313,7 +308,7 @@ def read_lines_without_comments(path: str, remove_blank_lines: bool = False, opt
return lines return lines
def exclude_none_values(data: dict[TKey, t.Optional[TValue]]) -> dict[TKey, TValue]: def exclude_none_values[TKey, TValue](data: dict[TKey, t.Optional[TValue]]) -> dict[TKey, TValue]:
"""Return the provided dictionary with any None values excluded.""" """Return the provided dictionary with any None values excluded."""
return dict((key, value) for key, value in data.items() if value is not None) return dict((key, value) for key, value in data.items() if value is not None)
@ -1058,7 +1053,7 @@ def format_command_output(stdout: str | None, stderr: str | None) -> str:
return message return message
def retry(func: t.Callable[..., TValue], ex_type: t.Type[BaseException] = SubprocessError, sleep: int = 10, attempts: int = 10, warn: bool = True) -> TValue: def retry[T](func: t.Callable[..., T], ex_type: t.Type[BaseException] = SubprocessError, sleep: int = 10, attempts: int = 10, warn: bool = True) -> T:
"""Retry the specified function on failure.""" """Retry the specified function on failure."""
for dummy in range(1, attempts): for dummy in range(1, attempts):
try: try:
@ -1091,7 +1086,7 @@ def parse_to_list_of_dict(pattern: str, value: str) -> list[dict[str, str]]:
return matched return matched
def get_subclasses(class_type: t.Type[C]) -> list[t.Type[C]]: def get_subclasses[C](class_type: t.Type[C]) -> list[t.Type[C]]:
"""Returns a list of types that are concrete subclasses of the given type.""" """Returns a list of types that are concrete subclasses of the given type."""
subclasses: set[t.Type[C]] = set() subclasses: set[t.Type[C]] = set()
queue: list[t.Type[C]] = [class_type] queue: list[t.Type[C]] = [class_type]
@ -1167,7 +1162,7 @@ def import_plugins(directory: str, root: t.Optional[str] = None) -> None:
load_module(module_path, name) load_module(module_path, name)
def load_plugins(base_type: t.Type[C], database: dict[str, t.Type[C]]) -> None: def load_plugins[C](base_type: t.Type[C], database: dict[str, t.Type[C]]) -> None:
""" """
Load plugins of the specified type and track them in the specified database. Load plugins of the specified type and track them in the specified database.
Only plugins which have already been imported will be loaded. Only plugins which have already been imported will be loaded.
@ -1194,19 +1189,19 @@ def sanitize_host_name(name: str) -> str:
return re.sub('[^A-Za-z0-9]+', '-', name)[:63].strip('-') return re.sub('[^A-Za-z0-9]+', '-', name)[:63].strip('-')
def get_generic_type(base_type: t.Type, generic_base_type: t.Type[TValue]) -> t.Optional[t.Type[TValue]]: def get_generic_type[TValue](base_type: t.Type, generic_base_type: t.Type[TValue]) -> t.Optional[t.Type[TValue]]:
"""Return the generic type arg derived from the generic_base_type type that is associated with the base_type type, if any, otherwise return None.""" """Return the generic type arg derived from the generic_base_type type that is associated with the base_type type, if any, otherwise return None."""
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
type_arg = t.get_args(base_type.__orig_bases__[0])[0] type_arg = t.get_args(base_type.__orig_bases__[0])[0]
return None if isinstance(type_arg, generic_base_type) else type_arg return None if isinstance(type_arg, generic_base_type) else type_arg
def get_type_associations(base_type: t.Type[TBase], generic_base_type: t.Type[TValue]) -> list[tuple[t.Type[TValue], t.Type[TBase]]]: def get_type_associations[TBase, TValue](base_type: t.Type[TBase], generic_base_type: t.Type[TValue]) -> list[tuple[t.Type[TValue], t.Type[TBase]]]:
"""Create and return a list of tuples associating generic_base_type derived types with a corresponding base_type derived type.""" """Create and return a list of tuples associating generic_base_type derived types with a corresponding base_type derived type."""
return [item for item in [(get_generic_type(sc_type, generic_base_type), sc_type) for sc_type in get_subclasses(base_type)] if item[1]] return [item for item in [(get_generic_type(sc_type, generic_base_type), sc_type) for sc_type in get_subclasses(base_type)] if item[1]]
def get_type_map(base_type: t.Type[TBase], generic_base_type: t.Type[TValue]) -> dict[t.Type[TValue], t.Type[TBase]]: def get_type_map[TBase, TValue](base_type: t.Type[TBase], generic_base_type: t.Type[TValue]) -> dict[t.Type[TValue], t.Type[TBase]]:
"""Create and return a mapping of generic_base_type derived types to base_type derived types.""" """Create and return a mapping of generic_base_type derived types to base_type derived types."""
return {item[0]: item[1] for item in get_type_associations(base_type, generic_base_type)} return {item[0]: item[1] for item in get_type_associations(base_type, generic_base_type)}
@ -1227,7 +1222,7 @@ def verify_sys_executable(path: str) -> t.Optional[str]:
return expected_executable return expected_executable
def type_guard(sequence: c.Sequence[t.Any], guard_type: t.Type[C]) -> t.TypeGuard[c.Sequence[C]]: def type_guard[C](sequence: c.Sequence[t.Any], guard_type: t.Type[C]) -> t.TypeGuard[c.Sequence[C]]:
""" """
Raises an exception if any item in the given sequence does not match the specified guard type. Raises an exception if any item in the given sequence does not match the specified guard type.
Use with assert so that type checkers are aware of the type guard. Use with assert so that type checkers are aware of the type guard.

@ -6,6 +6,7 @@
from __future__ import annotations from __future__ import annotations
import importlib
import os import os
import sys import sys
@ -29,8 +30,9 @@ def main(args=None):
if any(not os.get_blocking(handle.fileno()) for handle in (sys.stdin, sys.stdout, sys.stderr)): if any(not os.get_blocking(handle.fileno()) for handle in (sys.stdin, sys.stdout, sys.stderr)):
raise SystemExit('Standard input, output and error file handles must be blocking to run ansible-test.') raise SystemExit('Standard input, output and error file handles must be blocking to run ansible-test.')
# noinspection PyProtectedMember # avoid using import to hide it from mypy
from ansible_test._internal import main as cli_main internal = importlib.import_module('ansible_test._internal')
cli_main = getattr(internal, 'main')
cli_main(args) cli_main(args)

Loading…
Cancel
Save