# Source: ansible/test/units/utils/test_serialization_profiles.py (Python, 382 lines, 14 KiB)
# DTFIX-FUTURE: these tests need to be split so they can run under both module_utils and controller contexts
from __future__ import annotations
import contextlib
import dataclasses
import datetime
import hashlib
import itertools
import json
import pathlib
import pkgutil
import pprint
import typing as t
import pytest
from ansible.module_utils._internal._json import _profiles as target_serialization_profiles
from ansible.module_utils._internal import _json as _serialization
from ansible.module_utils._internal._datatag import AnsibleDatatagBase, NotTaggableError, AnsibleTagHelper
from ansible.module_utils._internal._datatag._tags import Deprecated
from ansible._internal._templating._lazy_containers import _AnsibleLazyTemplateMixin
from ansible._internal._templating._engine import TemplateEngine, TemplateOptions
from ansible._internal._templating._utils import TemplateContext
from ansible._internal._datatag._tags import TrustedAsTemplate, VaultedValue, Origin, SourceWasEncrypted
from ansible._internal._json import _profiles as controller_serialization_profiles
from ansible.module_utils.common.json import get_encoder, get_decoder
from ansible.module_utils._internal._json._profiles import _fallback_to_str
from ansible._internal._json._profiles import _cache_persistence
from ansible.errors import AnsibleRuntimeError
from ..mock.custom_types import CustomMapping, CustomSequence, CustomStr, CustomInt, CustomFloat
# Values exercised against every serialization profile. Each entry is
# round-tripped through a profile's encoder/decoder and compared to recorded
# expected output, so the collection deliberately mixes scalars, containers,
# custom subclasses, and hashable non-scalar dictionary keys.
basic_values = (
    None,
    True,
    1,
    1.1,
    'hi',
    '汉语',  # non-ASCII string
    b'hi',
    datetime.datetime(2024, 1, 2, 3, 4, 5, 6, datetime.timezone.utc, fold=1),
    datetime.time(1, 2, 3, 4, datetime.timezone.utc, fold=1),
    datetime.date(2024, 1, 2),
    (1,),
    [1],
    CustomSequence([1]),
    CustomStr('hello'),
    CustomInt(42),
    CustomFloat(42.0),
    {1},
    dict(a=1),
    CustomMapping(dict(a=1)),
    {(1, 2): "three"},  # hashable non-scalar key
    {frozenset((1, 2)): "three"},  # hashable non-scalar key
)
# DTFIX5: we need tests for recursion, specifically things like custom sequences and mappings when:
# 1) using the legacy serializer
# 2) containing types in the type map, such as tagged values
# e.g. -- does trust inversion get applied to a value inside a custom sequence or mapping

# One representative instance per tag type; profiles advertise which tag types
# they support, and matching values from here are attached to test inputs.
tag_values = {
    Deprecated: Deprecated(msg='x'),  # DTFIX5: we need more exhaustive testing of the values supported by this tag to ensure schema ID is robust
    TrustedAsTemplate: TrustedAsTemplate(),
    Origin: Origin(path='/tmp/x', line_num=1, col_num=2, description='y'),
    VaultedValue: VaultedValue(ciphertext='x'),
    SourceWasEncrypted: SourceWasEncrypted(),
}
def test_cache_persistence_schema() -> None:
    """
    Check the schema ID for the cache_persistence schema to ensure it is updated when the schema changes.

    Failure to update the schema ID will result in serialization/deserialization failures for persisted data for things like cache plugins.
    This test is only as comprehensive as these unit tests, so ensure profile data types are thoroughly covered.
    If additional capabilities are added to the cache_persistence profile which are not tested, they will go undetected, leading to runtime failures.
    """
    # DTFIX5: update tests to ensure new fields on contracts will fail this test if they have defaults which are omitted from serialization
    # one possibility: monkeypatch the default field value omission away so that any new field will invalidate the schema
    # DTFIX5: ensure all types/attrs included in _profiles._common_module_response_types are represented here, since they can appear in cached responses
    expected_schema_id = 1
    expected_schema_hash = "bf52e60cf1d25a3f8b6bfdf734781ee07cfe46e94189d2f538815c5000b617c6"

    # Hash the recorded expected-output file for the cache_persistence profile;
    # any change in the profile's serialized forms surfaces as a hash change.
    recorded_bytes = (DataSet.PROFILE_DIR / _cache_persistence._Profile.profile_name).with_suffix('.txt').read_bytes()

    actual_schema_id = _cache_persistence._Profile.schema_id
    actual_schema_hash = hashlib.sha256(recorded_bytes).hexdigest()
    next_schema_id = actual_schema_id + 1

    schema_check_failure_instructions = f"""The cache_persistence schema check hash has changed. The solution depends on the reason why:
1) The schema and tests have changed:
i. Increment `ansible._internal._json._profiles._cache_persistence._Profile.schema_id` to {next_schema_id}.
ii. Update `expected_schema_id` to {next_schema_id}.
iii. Update `expected_schema_hash` to {actual_schema_hash!r}.
2) The schema is unchanged, but the tests have changed:
i. Double-check that the schema really hasn't changed.
ii. Don't forget about added/changed/removed types as well as fields on those types.
iii. Update `expected_schema_hash` to {actual_schema_hash!r}.
"""

    if actual_schema_id != expected_schema_id:
        raise Exception(f"The actual schema ID {actual_schema_id} does not match the expected schema ID {expected_schema_id}.")

    if actual_schema_hash != expected_schema_hash:
        raise Exception(schema_check_failure_instructions)
def get_profile_names() -> tuple[str, ...]:
    """Return the sorted names of every serialization profile found in both the module_utils and controller profile packages."""
    names: list[str] = []

    for package in (target_serialization_profiles, controller_serialization_profiles):
        modules = list(pkgutil.iter_modules(package.__path__, f'{package.__name__}.'))

        assert modules  # ensure at least one serialization profile module was found

        names.extend(_serialization.get_serialization_profile(module.name).profile_name for module in modules)

    return tuple(sorted(names))
@dataclasses.dataclass(frozen=True)
class _TestParameters:
    """A single profile/value/tags/lazy combination to run through a serialization round-trip."""

    profile_name: str
    value: t.Any
    tags: tuple[AnsibleDatatagBase, ...] = ()
    lazy: bool = False

    def __hash__(self):
        # `value` may be unhashable (e.g. list/dict/set), so hash its repr instead
        return hash((self.profile_name, repr(self.value), self.tags))

    def __repr__(self):
        # include only fields which differ from their defaults, keeping recorded data files compact
        args = [
            f'{field.name}={getattr(self, field.name)!r}'
            for field in dataclasses.fields(self)
            if getattr(self, field.name) != field.default
        ]

        return f"{type(self).__name__}({', '.join(args)})"

    def get_test_output(self) -> _TestOutput:
        """Round-trip `value` through the profile's encoder/decoder, capturing the payload (or exception), result and tags."""
        encoder = get_encoder(self.profile_name)
        decoder = get_decoder(self.profile_name)

        if self.lazy:
            # lazy containers require an active template context
            ctx = TemplateContext(
                template_value=self.value,
                templar=TemplateEngine(),
                options=TemplateOptions.DEFAULT,
                stop_on_template=False,
            )
        else:
            ctx = contextlib.nullcontext()

        with ctx:
            try:
                value = AnsibleTagHelper.tag(self.value, self.tags)
            except NotTaggableError:
                value = self.value  # untaggable values are exercised untagged

            if self.lazy:
                value = _AnsibleLazyTemplateMixin._try_create(value)

            payload: str | Exception

            try:
                payload = json.dumps(value, cls=encoder)
            except Exception as ex:
                payload = ex
                round_trip = None
            else:
                try:
                    round_trip = json.loads(payload, cls=decoder)
                except Exception as ex:
                    round_trip = ex

            return _TestOutput(
                payload=payload,
                round_trip=AnsibleTagHelper.as_native_type(round_trip),
                tags=tuple(sorted(AnsibleTagHelper.tags(round_trip), key=lambda item: type(item).__name__)),
            )
@dataclasses.dataclass(frozen=True)
class _TestOutput:
    """The observed (or recorded expected) result of one serialization round-trip."""

    payload: str | Exception  # JSON document produced by the encoder, or the exception raised while encoding
    round_trip: t.Any  # decoded value converted to its native type, or the exception raised while decoding
    tags: tuple[AnsibleDatatagBase, ...]  # tags present on the decoded value, sorted by type name
@dataclasses.dataclass(frozen=True)
class _TestCase:
    """Pairs test parameters with the recorded expected output; `str()` yields the pytest test ID."""

    parameters: _TestParameters
    expected: _TestOutput

    def __str__(self) -> str:
        params = self.parameters
        parts = [f'profile={params.profile_name}', f'value={params.value}']

        if params.tags:
            tag_names = sorted(type(obj).__name__ for obj in params.tags)
            parts.append(f"tags={','.join(tag_names)}")

        if params.lazy:
            parts.append('lazy')

        return '; '.join(parts)
class DataSet:
    """Loads, records and saves the expected serialization results, keyed by test parameters."""

    # directory holding one recorded expected-output file per profile
    PROFILE_DIR = pathlib.Path(__file__).parent / 'expected_serialization_profiles'

    def __init__(self, generate: bool) -> None:
        self.data: dict[_TestParameters, _TestOutput] = {}
        self.path = self.PROFILE_DIR
        self.generate = generate  # when True, results are captured instead of loaded

    def load(self) -> None:
        """Populate `data` from the recorded fixture files; no-op when generating."""
        if self.generate:
            return

        for source in self.path.glob('*.txt'):
            # fixture files are trusted repr-style Python literals written by `save`, hence eval is acceptable here
            self.data.update(eval(source.read_text()))

    def save(self) -> None:
        """Write `data` out as one file per profile; no-op unless generating."""
        if not self.generate:
            return

        # groupby requires sorted input; additional items appended to the end means the data set is unsorted
        sorted_items = sorted(self.data.items(), key=lambda item: item[0].profile_name)

        for group_name, group in itertools.groupby(sorted_items, key=lambda item: item[0].profile_name):
            content = self.generate_content(dict(group))
            (self.path / f'{group_name}.txt').write_text(content)

    @staticmethod
    def generate_content(profiles: dict[_TestParameters, _TestOutput]) -> str:
        """Render the given mapping as a Python dict literal with one entry per line."""
        lines = ["{"]

        # loop the dictionary entries manually to ensure one entry per line
        for key, value in profiles.items():
            rendered_key = pprint.pformat(key, width=10000, indent=0, sort_dicts=False)
            rendered_value = pprint.pformat(value, width=10000, indent=0, sort_dicts=False)
            lines.append(f"{rendered_key}: {rendered_value},")

        lines.append("}")

        return '\n'.join(lines) + '\n'

    def fetch_or_create_expected(self, test_params: _TestParameters) -> _TestOutput:
        """Return the expected output for `test_params`, computing and recording it when generating."""
        if self.generate:
            output = self.data[test_params] = test_params.get_test_output()
        else:
            try:
                output = self.data[test_params]
            except KeyError:
                raise Exception(f'Missing {test_params} in data set. Use `generate=True` to update the data set and then review the changes.') from None

        return output
class ProfileHelper:
    """Builds the plain/tagged/lazy test parameter combinations for a single serialization profile."""

    def __init__(self, profile_name: str) -> None:
        self.profile_name = profile_name

        profile = _serialization.get_serialization_profile(profile_name)

        # tag types the profile knows how to serialize (insertion-ordered set via dict keys)
        supported_tags = {obj: None for obj in profile.serialize_map if issubclass(obj, AnsibleDatatagBase)}

        if supported_tags:
            self.supported_tag_values = tuple(value for tag_type, value in tag_values.items() if tag_type in supported_tags)

            if not self.supported_tag_values:
                raise Exception(f'Profile {profile} supports tags {supported_tags}, but no supported tag value is available.')
        else:
            self.supported_tag_values = tuple()

        # a single tag the profile does not support, used to exercise unsupported-tag handling
        self.unsupported_tag_value = next((value for tag_type, value in tag_values.items() if tag_type not in supported_tags), None)

        if not self.unsupported_tag_value and profile.profile_name != _cache_persistence._Profile.profile_name:
            raise Exception(f'Profile {profile} supports tags {supported_tags}, but no unsupported tag value is available.')

    def create_parameters_from_values(self, *values: t.Any) -> list[_TestParameters]:
        """Create test parameters for each of the given values."""
        return [parameters for value in values for parameters in self.create_parameters_from_value(value)]

    def create_parameters_from_value(self, value: t.Any) -> list[_TestParameters]:
        """Create the untagged, supported-tagged and unsupported-tagged parameter variants for one value."""
        test_parameters = [_TestParameters(profile_name=self.profile_name, value=value)]

        if self.supported_tag_values:
            test_parameters.append(_TestParameters(profile_name=self.profile_name, value=value, tags=self.supported_tag_values))

        if self.unsupported_tag_value:
            test_parameters.append(_TestParameters(profile_name=self.profile_name, value=value, tags=(self.unsupported_tag_value,)))

        # test lazy containers on all non m2c profiles
        if not self.profile_name.endswith("_m2c") and isinstance(value, (list, dict)):
            test_parameters += [dataclasses.replace(parameters, lazy=True) for parameters in test_parameters]

        return test_parameters
additional_test_parameters: list[_TestParameters] = []

# DTFIX5: need better testing for containers, especially for tagged values in containers

# hand-picked edge cases that only apply to the fallback-to-str profile
additional_test_parameters.extend(ProfileHelper(_fallback_to_str._Profile.profile_name).create_parameters_from_values(
    b'\x00',  # valid utf-8 strict, JSON escape sequence required
    b'\x80',  # utf-8 strict decoding fails, forcing the use of an error handler such as surrogateescape, JSON escape sequence required
    '\udc80',  # same as above, but already a string (verify that the string version is handled the same as the bytes version)
    {1: "1"},  # integer key
    {b'hi': "1"},  # bytes key
    {TrustedAsTemplate().tag(b'hi'): "2"},  # tagged bytes key
    {(b'hi',): 3},  # tuple[bytes] key
))

_generate = False
"""Set to True to regenerate all test data; a test failure will occur until it is set back to False."""
def get_test_cases() -> list[_TestCase]:
    """Build the full parametrized test matrix, pairing each parameter set with its recorded expected output."""
    data_set = DataSet(generate=_generate)
    data_set.load()

    test_parameters: list[_TestParameters] = []

    for profile_name in get_profile_names():
        test_parameters.extend(ProfileHelper(profile_name).create_parameters_from_values(*basic_values))

    test_parameters.extend(additional_test_parameters)

    test_cases = [_TestCase(parameters=parameters, expected=data_set.fetch_or_create_expected(parameters)) for parameters in test_parameters]

    data_set.save()  # persists newly captured data when generating, otherwise a no-op

    return test_cases
@pytest.mark.parametrize("test_case", get_test_cases(), ids=str)
def test_profile(test_case: _TestCase) -> None:
    """Verify a fresh serialization round-trip matches the recorded payload, round-tripped value and tags."""
    output = test_case.parameters.get_test_output()
    expected = test_case.expected

    if isinstance(output.payload, Exception):
        # exception instances don't compare equal, so match on type and message instead
        if type(output.payload) is not type(expected.payload):
            raise Exception('unexpected exception') from output.payload

        assert str(output.payload) == str(expected.payload)
    else:
        assert output.payload == expected.payload

    assert type(output.round_trip) is type(expected.round_trip)

    if isinstance(output.round_trip, AnsibleRuntimeError):
        # compare the original message, since the full message may embed environment-specific detail
        assert str(output.round_trip._original_message) == str(expected.round_trip._original_message)
    else:
        assert output.round_trip == expected.round_trip

    # tag comparison is order-insensitive
    assert not set(output.tags).symmetric_difference(expected.tags)
def test_not_generate_mode():
    """Fail if the data-set regeneration switch was accidentally left enabled."""
    assert not _generate, "set _generate=False to statically test expected behavior"