Add OrderedSet class

pull/84134/head
Matt Martz 1 month ago
parent 56bab1d097
commit 1f9c69b43d
No known key found for this signature in database
GPG Key ID: 40832D88E9FC91D8

@ -0,0 +1,2 @@
minor_changes:
- Add new OrderedSet class for situations a unique ordered list is needed

@ -5,29 +5,30 @@
from __future__ import annotations
import collections.abc as _c
from contextlib import suppress as _suppress
from ansible.module_utils.six import binary_type, text_type
from ansible.module_utils.six.moves.collections_abc import Hashable, Mapping, MutableMapping, Sequence # pylint: disable=unused-import
import ansible.module_utils.compat.typing as _t
class ImmutableDict(Hashable, Mapping):
class ImmutableDict(_c.Hashable, _c.Mapping):
"""Dictionary that cannot be updated"""
def __init__(self, *args, **kwargs):
self._store = dict(*args, **kwargs)
def __init__(self, *args, **kwargs) -> None:
self._store: dict[_c.Hashable, _t.Any] = dict(*args, **kwargs)
def __getitem__(self, key):
def __getitem__(self, key: _c.Hashable) -> _t.Any:
return self._store[key]
def __iter__(self):
def __iter__(self) -> _c.Iterator:
return self._store.__iter__()
def __len__(self):
def __len__(self) -> int:
return self._store.__len__()
def __hash__(self):
def __hash__(self) -> int:
return hash(frozenset(self.items()))
def __eq__(self, other):
def __eq__(self, other: _t.Any) -> bool:
try:
if self.__hash__() == hash(other):
return True
@ -36,10 +37,10 @@ class ImmutableDict(Hashable, Mapping):
return False
def __repr__(self):
def __repr__(self) -> str:
return 'ImmutableDict({0})'.format(repr(self._store))
def union(self, overriding_mapping):
def union(self, overriding_mapping: _c.Mapping) -> ImmutableDict:
"""
Create an ImmutableDict as a combination of the original and overriding_mapping
@ -51,7 +52,7 @@ class ImmutableDict(Hashable, Mapping):
"""
return ImmutableDict(self._store, **overriding_mapping)
def difference(self, subtractive_iterable):
def difference(self, subtractive_iterable: _c.Iterable) -> ImmutableDict:
"""
Create an ImmutableDict as a combination of the original minus keys in subtractive_iterable
@ -64,13 +65,73 @@ class ImmutableDict(Hashable, Mapping):
return ImmutableDict((k, self._store[k]) for k in keys)
def is_string(seq):
class OrderedSet(_c.MutableSet):
def __init__(
self,
iterable: _c.Iterable[_c.Hashable] | None = None,
/
) -> None:
self._data: dict[_c.Hashable, None]
if iterable is None:
self._data = {}
else:
self._data = dict.fromkeys(iterable)
def __repr__(self, /) -> str:
return f'OrderedSet({list(self._data)!r})'
def __eq__(self, other: _t.Any, /) -> bool:
if not isinstance(other, OrderedSet):
return NotImplemented
return len(self) == len(other) and tuple(self) == tuple(other)
def __contains__(self, x: _c.Hashable, /) -> bool:
return x in self._data
def __iter__(self, /) -> _c.Iterator:
return self._data.__iter__()
def __len__(self, /) -> int:
return self._data.__len__()
def add(self, value: _c.Hashable, /) -> None:
self._data[value] = None
def discard(self, value: _c.Hashable, /) -> None:
with _suppress(KeyError):
del self._data[value]
def clear(self, /) -> None:
self._data.clear()
def copy(self, /) -> OrderedSet:
return OrderedSet(self._data.copy())
def __and__(self, other: _c.Container, /) -> OrderedSet:
# overridden, because the ABC produces an arguably unexpected sorting
return OrderedSet(value for value in self if value in other)
difference = _c.MutableSet.__sub__
difference_update = _c.MutableSet.__isub__
intersection = __and__
__rand__ = _c.MutableSet.__and__
intersection_update = _c.MutableSet.__iand__
issubset = _c.MutableSet.__le__
issuperset = _c.MutableSet.__ge__
symmetric_difference = _c.MutableSet.__xor__
symmetric_difference_update = _c.MutableSet.__ixor__
union = _c.MutableSet.__or__
update = _c.MutableSet.__ior__
def is_string(seq: _c.Iterable) -> bool:
"""Identify whether the input has a string-like type (including bytes)."""
# AnsibleVaultEncryptedUnicode inherits from Sequence, but is expected to be a string like object
return isinstance(seq, (text_type, binary_type)) or getattr(seq, '__ENCRYPTED__', False)
return isinstance(seq, (str, bytes)) or getattr(seq, '__ENCRYPTED__', False)
def is_iterable(seq, include_strings=False):
def is_iterable(seq: _c.Iterable, include_strings: bool = False) -> bool:
"""Identify whether the input is an iterable."""
if not include_strings and is_string(seq):
return False
@ -82,7 +143,7 @@ def is_iterable(seq, include_strings=False):
return False
def is_sequence(seq, include_strings=False):
def is_sequence(seq: _c.Iterable, include_strings: bool = False) -> bool:
"""Identify whether the input is a sequence.
Strings and bytes are not sequences here,
@ -93,10 +154,10 @@ def is_sequence(seq, include_strings=False):
if not include_strings and is_string(seq):
return False
return isinstance(seq, Sequence)
return isinstance(seq, _c.Sequence)
def count(seq):
def count(seq: _c.Iterable) -> dict[_c.Hashable, int]:
"""Returns a dictionary with the number of appearances of each element of the iterable.
Resembles the collections.Counter class functionality. It is meant to be used when the
@ -105,7 +166,13 @@ def count(seq):
"""
if not is_iterable(seq):
raise Exception('Argument provided is not an iterable')
counters = dict()
counters: dict[_c.Hashable, int] = {}
for elem in seq:
counters[elem] = counters.get(elem, 0) + 1
return counters
Hashable = _c.Hashable
Mapping = _c.Mapping
MutableMapping = _c.MutableMapping
Sequence = _c.Sequence

@ -65,7 +65,7 @@ else:
GzipFile = gzip.GzipFile # type: ignore[assignment,misc]
from ansible.module_utils.basic import missing_required_lib
from ansible.module_utils.common.collections import Mapping, is_sequence
from ansible.module_utils.common.collections import Mapping, OrderedSet, is_sequence
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
try:
@ -516,7 +516,7 @@ def get_ca_certs(cafile=None, capath=None):
# Using a dict, instead of a set for order, the value is meaningless and will be None
# Not directly using a bytearray to avoid duplicates with fast lookup
cadata = {}
cadata = OrderedSet()
# If cafile is passed, we are only using that for verification,
# don't add additional ca certs
@ -525,7 +525,7 @@ def get_ca_certs(cafile=None, capath=None):
with open(to_bytes(cafile, errors='surrogate_or_strict'), 'r', errors='surrogateescape') as f:
for pem in extract_pem_certs(f.read()):
b_der = ssl.PEM_cert_to_DER_cert(pem)
cadata[b_der] = None
cadata.add(b_der)
return bytearray().join(cadata), paths_checked
default_verify_paths = ssl.get_default_verify_paths()
@ -576,7 +576,7 @@ def get_ca_certs(cafile=None, capath=None):
try:
for pem in extract_pem_certs(cert):
b_der = ssl.PEM_cert_to_DER_cert(pem)
cadata[b_der] = None
cadata.add(b_der)
except Exception:
continue
except (OSError, IOError):

@ -57,6 +57,7 @@ MODULE_UTILS_BASIC_FILES = frozenset(('ansible/__init__.py',
'ansible/module_utils/common/arg_spec.py',
'ansible/module_utils/compat/__init__.py',
'ansible/module_utils/compat/selinux.py',
'ansible/module_utils/compat/typing.py',
'ansible/module_utils/distro/__init__.py',
'ansible/module_utils/distro/_distro.py',
'ansible/module_utils/errors.py',

@ -8,7 +8,7 @@ from __future__ import annotations
import pytest
from collections.abc import Sequence
from ansible.module_utils.common.collections import ImmutableDict, is_iterable, is_sequence
from ansible.module_utils.common.collections import ImmutableDict, OrderedSet, is_iterable, is_sequence
class SeqStub:
@ -159,3 +159,34 @@ class TestImmutableDict:
actual_repr = repr(imdict)
expected_repr = "ImmutableDict({0})".format(initial_data_repr)
assert actual_repr == expected_repr
class TestOrderedSet:
def test_sorting(self):
expected = ['foo', 'bar', 'baz']
assert list(OrderedSet(expected)) == expected
def test_sorting_add_discard(self):
o = OrderedSet()
o.add('foo')
o.update(['bar', 'baz'])
assert list(o) == ['foo', 'bar', 'baz']
o.discard('foo')
assert list(o) == ['bar', 'baz']
o.add('foo')
assert list(o) == ['bar', 'baz', 'foo']
def test_sorting_set_opts(self):
o1 = OrderedSet(['foo', 'bar', 'baz', 'qux'])
o2 = OrderedSet(['qux', 'bar', 'ham', 'sandwich'])
difference = o1 - o2
intersect = o1 & o2
union = o1 | o2
symmetric_difference = o1 ^ o2
assert list(difference) == ['foo', 'baz']
assert list(intersect) == ['bar', 'qux']
assert list(union) == ['foo', 'bar', 'baz', 'qux', 'ham', 'sandwich']
assert list(symmetric_difference) == ['foo', 'baz', 'ham', 'sandwich']

Loading…
Cancel
Save