From 1f9c69b43d7ae003a788f05b0459e2621dab7ba5 Mon Sep 17 00:00:00 2001 From: Matt Martz Date: Thu, 17 Oct 2024 12:02:41 -0500 Subject: [PATCH] Add OrderedSet class --- changelogs/fragments/orderedset.yml | 2 + .../module_utils/common/collections.py | 107 ++++++++++++++---- lib/ansible/module_utils/urls.py | 8 +- .../module_common/test_recursive_finder.py | 1 + .../module_utils/common/test_collections.py | 33 +++++- 5 files changed, 126 insertions(+), 25 deletions(-) create mode 100644 changelogs/fragments/orderedset.yml diff --git a/changelogs/fragments/orderedset.yml b/changelogs/fragments/orderedset.yml new file mode 100644 index 00000000000..52dd1941d9a --- /dev/null +++ b/changelogs/fragments/orderedset.yml @@ -0,0 +1,2 @@ +minor_changes: +- Add new OrderedSet class for situations a unique ordered list is needed diff --git a/lib/ansible/module_utils/common/collections.py b/lib/ansible/module_utils/common/collections.py index 28c53e14e2c..a48a3b3fe8a 100644 --- a/lib/ansible/module_utils/common/collections.py +++ b/lib/ansible/module_utils/common/collections.py @@ -5,29 +5,30 @@ from __future__ import annotations +import collections.abc as _c +from contextlib import suppress as _suppress -from ansible.module_utils.six import binary_type, text_type -from ansible.module_utils.six.moves.collections_abc import Hashable, Mapping, MutableMapping, Sequence # pylint: disable=unused-import +import ansible.module_utils.compat.typing as _t -class ImmutableDict(Hashable, Mapping): +class ImmutableDict(_c.Hashable, _c.Mapping): """Dictionary that cannot be updated""" - def __init__(self, *args, **kwargs): - self._store = dict(*args, **kwargs) + def __init__(self, *args, **kwargs) -> None: + self._store: dict[_c.Hashable, _t.Any] = dict(*args, **kwargs) - def __getitem__(self, key): + def __getitem__(self, key: _c.Hashable) -> _t.Any: return self._store[key] - def __iter__(self): + def __iter__(self) -> _c.Iterator: return self._store.__iter__() - def __len__(self): + def __len__(self) -> int: return self._store.__len__() - def __hash__(self): + def __hash__(self) -> int: return hash(frozenset(self.items())) - def __eq__(self, other): + def __eq__(self, other: _t.Any) -> bool: try: if self.__hash__() == hash(other): return True @@ -36,10 +37,10 @@ class ImmutableDict(Hashable, Mapping): return False - def __repr__(self): + def __repr__(self) -> str: return 'ImmutableDict({0})'.format(repr(self._store)) - def union(self, overriding_mapping): + def union(self, overriding_mapping: _c.Mapping) -> ImmutableDict: """ Create an ImmutableDict as a combination of the original and overriding_mapping @@ -51,7 +52,7 @@ class ImmutableDict(Hashable, Mapping): """ return ImmutableDict(self._store, **overriding_mapping) - def difference(self, subtractive_iterable): + def difference(self, subtractive_iterable: _c.Iterable) -> ImmutableDict: """ Create an ImmutableDict as a combination of the original minus keys in subtractive_iterable @@ -64,13 +65,73 @@ class ImmutableDict(Hashable, Mapping): return ImmutableDict((k, self._store[k]) for k in keys) -def is_string(seq): +class OrderedSet(_c.MutableSet): + def __init__( + self, + iterable: _c.Iterable[_c.Hashable] | None = None, + / + ) -> None: + + self._data: dict[_c.Hashable, None] + if iterable is None: + self._data = {} + else: + self._data = dict.fromkeys(iterable) + + def __repr__(self, /) -> str: + return f'OrderedSet({list(self._data)!r})' + + def __eq__(self, other: _t.Any, /) -> bool: + if not isinstance(other, OrderedSet): + return NotImplemented + return len(self) == len(other) and tuple(self) == tuple(other) + + def __contains__(self, x: _c.Hashable, /) -> bool: + return x in self._data + + def __iter__(self, /) -> _c.Iterator: + return self._data.__iter__() + + def __len__(self, /) -> int: + return self._data.__len__() + + def add(self, value: _c.Hashable, /) -> None: + self._data[value] = None + + def discard(self, value: _c.Hashable, /) -> None: + with _suppress(KeyError): + del self._data[value] + + def clear(self, /) -> None: + self._data.clear() + + def copy(self, /) -> OrderedSet: + return OrderedSet(self._data.copy()) + + def __and__(self, other: _c.Container, /) -> OrderedSet: + # overridden, because the ABC produces an arguably unexpected sorting + return OrderedSet(value for value in self if value in other) + + difference = _c.MutableSet.__sub__ + difference_update = _c.MutableSet.__isub__ + intersection = __and__ + __rand__ = _c.MutableSet.__and__ + intersection_update = _c.MutableSet.__iand__ + issubset = _c.MutableSet.__le__ + issuperset = _c.MutableSet.__ge__ + symmetric_difference = _c.MutableSet.__xor__ + symmetric_difference_update = _c.MutableSet.__ixor__ + union = _c.MutableSet.__or__ + update = _c.MutableSet.__ior__ + + +def is_string(seq: _c.Iterable) -> bool: """Identify whether the input has a string-like type (including bytes).""" # AnsibleVaultEncryptedUnicode inherits from Sequence, but is expected to be a string like object - return isinstance(seq, (text_type, binary_type)) or getattr(seq, '__ENCRYPTED__', False) + return isinstance(seq, (str, bytes)) or getattr(seq, '__ENCRYPTED__', False) -def is_iterable(seq, include_strings=False): +def is_iterable(seq: _c.Iterable, include_strings: bool = False) -> bool: """Identify whether the input is an iterable.""" if not include_strings and is_string(seq): return False @@ -82,7 +143,7 @@ def is_iterable(seq, include_strings=False): return False -def is_sequence(seq, include_strings=False): +def is_sequence(seq: _c.Iterable, include_strings: bool = False) -> bool: """Identify whether the input is a sequence. Strings and bytes are not sequences here, @@ -93,10 +154,10 @@ def is_sequence(seq, include_strings=False): if not include_strings and is_string(seq): return False - return isinstance(seq, Sequence) + return isinstance(seq, _c.Sequence) -def count(seq): +def count(seq: _c.Iterable) -> dict[_c.Hashable, int]: """Returns a dictionary with the number of appearances of each element of the iterable. Resembles the collections.Counter class functionality. It is meant to be used when the @@ -105,7 +166,13 @@ def count(seq): """ if not is_iterable(seq): raise Exception('Argument provided is not an iterable') - counters = dict() + counters: dict[_c.Hashable, int] = {} for elem in seq: counters[elem] = counters.get(elem, 0) + 1 return counters + + +Hashable = _c.Hashable +Mapping = _c.Mapping +MutableMapping = _c.MutableMapping +Sequence = _c.Sequence diff --git a/lib/ansible/module_utils/urls.py b/lib/ansible/module_utils/urls.py index c90f0b78fd4..20dd86f1aad 100644 --- a/lib/ansible/module_utils/urls.py +++ b/lib/ansible/module_utils/urls.py @@ -65,7 +65,7 @@ else: GzipFile = gzip.GzipFile # type: ignore[assignment,misc] from ansible.module_utils.basic import missing_required_lib -from ansible.module_utils.common.collections import Mapping, is_sequence +from ansible.module_utils.common.collections import Mapping, OrderedSet, is_sequence from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text try: @@ -516,7 +516,7 @@ def get_ca_certs(cafile=None, capath=None): # Using a dict, instead of a set for order, the value is meaningless and will be None # Not directly using a bytearray to avoid duplicates with fast lookup - cadata = {} + cadata = OrderedSet() # If cafile is passed, we are only using that for verification, # don't add additional ca certs @@ -525,7 +525,7 @@ def get_ca_certs(cafile=None, capath=None): with open(to_bytes(cafile, errors='surrogate_or_strict'), 'r', errors='surrogateescape') as f: for pem in extract_pem_certs(f.read()): b_der = ssl.PEM_cert_to_DER_cert(pem) - cadata[b_der] = None + cadata.add(b_der) return bytearray().join(cadata), paths_checked default_verify_paths = ssl.get_default_verify_paths() @@ -576,7 +576,7 @@ def get_ca_certs(cafile=None, capath=None): try: for pem in extract_pem_certs(cert): b_der = ssl.PEM_cert_to_DER_cert(pem) - cadata[b_der] = None + cadata.add(b_der) except Exception: continue except (OSError, IOError): diff --git a/test/units/executor/module_common/test_recursive_finder.py b/test/units/executor/module_common/test_recursive_finder.py index 92d7c206e0b..befd1b0f56b 100644 --- a/test/units/executor/module_common/test_recursive_finder.py +++ b/test/units/executor/module_common/test_recursive_finder.py @@ -57,6 +57,7 @@ MODULE_UTILS_BASIC_FILES = frozenset(('ansible/__init__.py', 'ansible/module_utils/common/arg_spec.py', 'ansible/module_utils/compat/__init__.py', 'ansible/module_utils/compat/selinux.py', + 'ansible/module_utils/compat/typing.py', 'ansible/module_utils/distro/__init__.py', 'ansible/module_utils/distro/_distro.py', 'ansible/module_utils/errors.py', diff --git a/test/units/module_utils/common/test_collections.py b/test/units/module_utils/common/test_collections.py index 381d583004c..58027ae91c3 100644 --- a/test/units/module_utils/common/test_collections.py +++ b/test/units/module_utils/common/test_collections.py @@ -8,7 +8,7 @@ from __future__ import annotations import pytest from collections.abc import Sequence -from ansible.module_utils.common.collections import ImmutableDict, is_iterable, is_sequence +from ansible.module_utils.common.collections import ImmutableDict, OrderedSet, is_iterable, is_sequence class SeqStub: @@ -159,3 +159,34 @@ class TestImmutableDict: actual_repr = repr(imdict) expected_repr = "ImmutableDict({0})".format(initial_data_repr) assert actual_repr == expected_repr + + +class TestOrderedSet: + def test_sorting(self): + expected = ['foo', 'bar', 'baz'] + assert list(OrderedSet(expected)) == expected + + def test_sorting_add_discard(self): + o = OrderedSet() + o.add('foo') + o.update(['bar', 'baz']) + assert list(o) == ['foo', 'bar', 'baz'] + + o.discard('foo') + assert list(o) == ['bar', 'baz'] + o.add('foo') + assert list(o) == ['bar', 'baz', 'foo'] + + def test_sorting_set_opts(self): + o1 = OrderedSet(['foo', 'bar', 'baz', 'qux']) + o2 = OrderedSet(['qux', 'bar', 'ham', 'sandwich']) + + difference = o1 - o2 + intersect = o1 & o2 + union = o1 | o2 + symmetric_difference = o1 ^ o2 + + assert list(difference) == ['foo', 'baz'] + assert list(intersect) == ['bar', 'qux'] + assert list(union) == ['foo', 'bar', 'baz', 'qux', 'ham', 'sandwich'] + assert list(symmetric_difference) == ['foo', 'baz', 'ham', 'sandwich']