diff --git a/changelogs/fragments/to-text-to-bytes.yml b/changelogs/fragments/to-text-to-bytes.yml new file mode 100644 index 00000000000..2345539bb94 --- /dev/null +++ b/changelogs/fragments/to-text-to-bytes.yml @@ -0,0 +1,2 @@ +minor_changes: + - Python type hints applied to ``to_text`` and ``to_bytes`` functions for better type hint interactions with code utilizing these functions. diff --git a/lib/ansible/galaxy/collection/__init__.py b/lib/ansible/galaxy/collection/__init__.py index 38737468dcd..4056f0c177e 100644 --- a/lib/ansible/galaxy/collection/__init__.py +++ b/lib/ansible/galaxy/collection/__init__.py @@ -339,12 +339,12 @@ def verify_local_collection(local_collection, remote_collection, artifacts_manag ] # Find any paths not in the FILES.json - for root, dirs, files in os.walk(b_collection_path): - for name in files: + for root, dirs, filenames in os.walk(b_collection_path): + for name in filenames: full_path = os.path.join(root, name) path = to_text(full_path[len(b_collection_path) + 1::], errors='surrogate_or_strict') if any(fnmatch.fnmatch(full_path, b_pattern) for b_pattern in b_ignore_patterns): - display.v("Ignoring verification for %s" % full_path) + display.v("Ignoring verification for %s" % to_text(full_path)) continue if full_path not in collection_files: diff --git a/lib/ansible/module_utils/common/text/converters.py b/lib/ansible/module_utils/common/text/converters.py index ac044511853..0bc0cd4f252 100644 --- a/lib/ansible/module_utils/common/text/converters.py +++ b/lib/ansible/module_utils/common/text/converters.py @@ -8,9 +8,9 @@ from __future__ import annotations import codecs import json +from ansible.module_utils.compat import typing as _t from ansible.module_utils._internal import _no_six - try: codecs.lookup_error('surrogateescape') HAS_SURROGATEESCAPE = True @@ -22,8 +22,54 @@ _COMPOSED_ERROR_HANDLERS = frozenset((None, 'surrogate_or_replace', 'surrogate_or_strict', 'surrogate_then_replace')) +_T = _t.TypeVar('_T') + +_NonStringPassthru: _t.TypeAlias = _t.Literal['passthru'] +_NonStringOther: _t.TypeAlias = _t.Literal['simplerepr', 'empty', 'strict'] +_NonStringAll: _t.TypeAlias = _t.Union[_NonStringPassthru, _NonStringOther] + + +@_t.overload +def to_bytes( + obj: object, + encoding: str = 'utf-8', + errors: str | None = None, +) -> bytes: ... + + +@_t.overload +def to_bytes( + obj: bytes | str, + encoding: str = 'utf-8', + errors: str | None = None, + nonstring: _NonStringPassthru = 'passthru', +) -> bytes: ... + + +@_t.overload +def to_bytes( + obj: _T, + encoding: str = 'utf-8', + errors: str | None = None, + nonstring: _NonStringPassthru = 'passthru', +) -> _T: ... + + +@_t.overload +def to_bytes( + obj: object, + encoding: str = 'utf-8', + errors: str | None = None, + nonstring: _NonStringOther = 'simplerepr', +) -> bytes: ... -def to_bytes(obj, encoding='utf-8', errors=None, nonstring='simplerepr'): + +def to_bytes( + obj: _T, + encoding: str = 'utf-8', + errors: str | None = None, + nonstring: _NonStringAll = 'simplerepr' +) -> _T | bytes: """Make sure that a string is a byte string :arg obj: An object to make sure is a byte string. In most cases this @@ -81,7 +127,7 @@ def to_bytes(obj, encoding='utf-8', errors=None, nonstring='simplerepr'): string is valid in the specified encoding. If it's important that the byte string is in the specified encoding do:: - encoded_string = to_bytes(to_text(input_string, 'latin-1'), 'utf-8') + encoded_string = to_bytes(to_text(input_string, encoding='latin-1'), encoding='utf-8') .. version_changed:: 2.3 @@ -126,21 +172,60 @@ def to_bytes(obj, encoding='utf-8', errors=None, nonstring='simplerepr'): value = repr(obj) except UnicodeError: # Giving up - return to_bytes('') + return b'' elif nonstring == 'passthru': return obj elif nonstring == 'empty': - # python2.4 doesn't have b'' - return to_bytes('') + return b'' elif nonstring == 'strict': raise TypeError('obj must be a string type') else: raise TypeError('Invalid value %s for to_bytes\' nonstring parameter' % nonstring) - return to_bytes(value, encoding, errors) + return to_bytes(value, encoding=encoding, errors=errors) + + +@_t.overload +def to_text( + obj: object, + encoding: str = 'utf-8', + errors: str | None = None, +) -> str: ... + + +@_t.overload +def to_text( + obj: str | bytes, + encoding: str = 'utf-8', + errors: str | None = None, + nonstring: _NonStringPassthru = 'passthru', +) -> str: ... + + +@_t.overload +def to_text( + obj: _T, + encoding: str = 'utf-8', + errors: str | None = None, + nonstring: _NonStringPassthru = 'passthru', +) -> _T: ... + + +@_t.overload +def to_text( + obj: object, + encoding: str = 'utf-8', + errors: str | None = None, + nonstring: _NonStringOther = 'simplerepr', +) -> str: ... -def to_text(obj, encoding='utf-8', errors=None, nonstring='simplerepr'): +def to_text( + obj: _T, + encoding: str = 'utf-8', + errors: str | None = None, + nonstring: _NonStringAll = 'simplerepr' +) -> _T | str: """Make sure that a string is a text string :arg obj: An object to make sure is a text string. In most cases this @@ -218,17 +303,17 @@ def to_text(obj, encoding='utf-8', errors=None, nonstring='simplerepr'): value = repr(obj) except UnicodeError: # Giving up - return u'' + return '' elif nonstring == 'passthru': return obj elif nonstring == 'empty': - return u'' + return '' elif nonstring == 'strict': raise TypeError('obj must be a string type') else: raise TypeError('Invalid value %s for to_text\'s nonstring parameter' % nonstring) - return to_text(value, encoding, errors) + return to_text(value, encoding=encoding, errors=errors) to_native = to_text diff --git a/lib/ansible/parsing/dataloader.py b/lib/ansible/parsing/dataloader.py index af470340832..22deaa606cd 100644 --- a/lib/ansible/parsing/dataloader.py +++ b/lib/ansible/parsing/dataloader.py @@ -31,7 +31,7 @@ display = Display() # Tries to determine if a path is inside a role, last dir must be 'tasks' # this is not perfect but people should really avoid 'tasks' dirs outside roles when using Ansible. -RE_TASKS = re.compile(u'(?:^|%s)+tasks%s?$' % (os.path.sep, os.path.sep)) +RE_TASKS = re.compile('(?:^|%s)+tasks%s?$' % (os.path.sep, os.path.sep)) class DataLoader: @@ -53,23 +53,22 @@ class DataLoader: ds = dl.load_from_file('/path/to/file') """ - def __init__(self): + def __init__(self) -> None: - self._basedir = '.' + self._basedir: str = '.' # NOTE: not effective with forks as the main copy does not get updated. # avoids rereading files - self._FILE_CACHE = dict() + self._FILE_CACHE: dict[str, object] = {} # NOTE: not thread safe, also issues with forks not returning data to main proc # so they need to be cleaned independently. See WorkerProcess for example. # used to keep track of temp files for cleaning - self._tempfiles = set() + self._tempfiles: set[str] = set() # initialize the vault stuff with an empty password # TODO: replace with a ref to something that can get the password # a creds/auth provider - self._vaults = {} self._vault = VaultLib() self.set_vault_secrets(None) @@ -229,23 +228,19 @@ class DataLoader: def set_basedir(self, basedir: str) -> None: """ sets the base directory, used to find files when a relative path is given """ - - if basedir is not None: - self._basedir = to_text(basedir) + self._basedir = basedir def path_dwim(self, given: str) -> str: """ make relative paths work like folks expect. """ - given = to_text(given, errors='surrogate_or_strict') given = unquote(given) - if given.startswith(to_text(os.path.sep)) or given.startswith(u'~'): + if given.startswith(os.path.sep) or given.startswith('~'): path = given else: - basedir = to_text(self._basedir, errors='surrogate_or_strict') - path = os.path.join(basedir, given) + path = os.path.join(self._basedir, given) return unfrackpath(path, follow=False) @@ -293,10 +288,9 @@ class DataLoader: """ search = [] - source = to_text(source, errors='surrogate_or_strict') # I have full path, nothing else needs to be looked at - if source.startswith(to_text(os.path.sep)) or source.startswith(u'~'): + if source.startswith(os.path.sep) or source.startswith('~'): search.append(unfrackpath(source, follow=False)) else: # base role/play path + templates/files/vars + relative filename @@ -363,7 +357,7 @@ class DataLoader: if os.path.exists(to_bytes(test_path, errors='surrogate_or_strict')): result = test_path else: - display.debug(u'evaluation_path:\n\t%s' % '\n\t'.join(paths)) + display.debug('evaluation_path:\n\t%s' % '\n\t'.join(paths)) for path in paths: upath = unfrackpath(path, follow=False) b_upath = to_bytes(upath, errors='surrogate_or_strict') @@ -384,9 +378,9 @@ class DataLoader: search.append(os.path.join(to_bytes(self.get_basedir(), errors='surrogate_or_strict'), b_dirname, b_source)) search.append(os.path.join(to_bytes(self.get_basedir(), errors='surrogate_or_strict'), b_source)) - display.debug(u'search_path:\n\t%s' % to_text(b'\n\t'.join(search))) + display.debug('search_path:\n\t%s' % to_text(b'\n\t'.join(search))) for b_candidate in search: - display.vvvvv(u'looking for "%s" at "%s"' % (source, to_text(b_candidate))) + display.vvvvv('looking for "%s" at "%s"' % (source, to_text(b_candidate))) if os.path.exists(b_candidate): result = to_text(b_candidate) break @@ -420,8 +414,7 @@ class DataLoader: if not file_path or not isinstance(file_path, (bytes, str)): raise AnsibleParserError("Invalid filename: '%s'" % to_native(file_path)) - b_file_path = to_bytes(file_path, errors='surrogate_or_strict') - if not self.path_exists(b_file_path) or not self.is_file(b_file_path): + if not self.path_exists(file_path) or not self.is_file(file_path): raise AnsibleFileNotFound(file_name=file_path) real_path = self.path_dwim(file_path) @@ -479,7 +472,7 @@ class DataLoader: """ b_path = to_bytes(os.path.join(path, name)) - found = [] + found: list[str] = [] if extensions is None: # Look for file with no extension first to find dir before file @@ -488,27 +481,29 @@ class DataLoader: for ext in extensions: if '.' in ext: - full_path = b_path + to_bytes(ext) + b_full_path = b_path + to_bytes(ext) elif ext: - full_path = b'.'.join([b_path, to_bytes(ext)]) + b_full_path = b'.'.join([b_path, to_bytes(ext)]) else: - full_path = b_path + b_full_path = b_path + + full_path = to_text(b_full_path) if self.path_exists(full_path): if self.is_directory(full_path): if allow_dir: - found.extend(self._get_dir_vars_files(to_text(full_path), extensions)) + found.extend(self._get_dir_vars_files(full_path, extensions)) else: continue else: - found.append(to_text(full_path)) + found.append(full_path) break return found def _get_dir_vars_files(self, path: str, extensions: list[str]) -> list[str]: found = [] for spath in sorted(self.list_directory(path)): - if not spath.startswith(u'.') and not spath.endswith(u'~'): # skip hidden and backups + if not spath.startswith('.') and not spath.endswith('~'): # skip hidden and backups ext = os.path.splitext(spath)[-1] full_spath = os.path.join(path, spath) diff --git a/lib/ansible/plugins/inventory/toml.py b/lib/ansible/plugins/inventory/toml.py index c7a434659f1..eb38b5f9556 100644 --- a/lib/ansible/plugins/inventory/toml.py +++ b/lib/ansible/plugins/inventory/toml.py @@ -89,7 +89,6 @@ import tomllib from collections.abc import MutableMapping, MutableSequence from ansible.errors import AnsibleFileNotFound, AnsibleParserError -from ansible.module_utils.common.text.converters import to_bytes, to_native from ansible.plugins.inventory import BaseFileInventoryPlugin from ansible.utils.display import Display @@ -147,10 +146,9 @@ class InventoryModule(BaseFileInventoryPlugin): def _load_file(self, file_name): if not file_name or not isinstance(file_name, str): - raise AnsibleParserError("Invalid filename: '%s'" % to_native(file_name)) + raise AnsibleParserError("Invalid filename: '%s'" % file_name) - b_file_name = to_bytes(self.loader.path_dwim(file_name)) - if not self.loader.path_exists(b_file_name): + if not self.loader.path_exists(file_name): raise AnsibleFileNotFound("Unable to retrieve file contents", file_name=file_name) try: diff --git a/test/units/module_utils/common/text/converters/test_to_str.py b/test/units/module_utils/common/text/converters/test_to_str.py index 4c2f63ae5ee..a06a91b72ee 100644 --- a/test/units/module_utils/common/text/converters/test_to_str.py +++ b/test/units/module_utils/common/text/converters/test_to_str.py @@ -45,3 +45,62 @@ def test_to_bytes(in_string, encoding, expected): def test_to_native(in_string, encoding, expected): """test happy path of encoding to native strings""" assert to_native(in_string, encoding) == expected + + +def test_type_hints() -> None: + """This test isn't really here to test the functionality of to_text/to_bytes + but more to ensure the overloads are properly validated for type hinting + """ + d: dict[str, str] = {'k': 'v'} + s: str = 's' + b: bytes = b'b' + + to_bytes_bytes: bytes = to_bytes(b) + to_bytes_str: bytes = to_bytes(s) + to_bytes_dict: bytes = to_bytes(d) + assert to_bytes_dict == repr(d).encode('utf-8') + + to_bytes_bytes_repr: bytes = to_bytes(b, nonstring='simplerepr') + to_bytes_str_repr: bytes = to_bytes(s, nonstring='simplerepr') + to_bytes_dict_repr: bytes = to_bytes(d, nonstring='simplerepr') + assert to_bytes_dict_repr == repr(d).encode('utf-8') + + to_bytes_bytes_passthru: bytes = to_bytes(b, nonstring='passthru') + to_bytes_str_passthru: bytes = to_bytes(s, nonstring='passthru') + to_bytes_dict_passthru: dict[str, str] = to_bytes(d, nonstring='passthru') + assert to_bytes_dict_passthru == d + + to_bytes_bytes_empty: bytes = to_bytes(b, nonstring='empty') + to_bytes_str_empty: bytes = to_bytes(s, nonstring='empty') + to_bytes_dict_empty: bytes = to_bytes(d, nonstring='empty') + assert to_bytes_dict_empty == b'' + + to_bytes_bytes_strict: bytes = to_bytes(b, nonstring='strict') + to_bytes_str_strict: bytes = to_bytes(s, nonstring='strict') + with pytest.raises(TypeError): + to_bytes_dict_strict: bytes = to_bytes(d, nonstring='strict') + + to_text_bytes: str = to_text(b) + to_text_str: str = to_text(s) + to_text_dict: str = to_text(d) + assert to_text_dict == repr(d) + + to_text_bytes_repr: str = to_text(b, nonstring='simplerepr') + to_text_str_repr: str = to_text(s, nonstring='simplerepr') + to_text_dict_repr: str = to_text(d, nonstring='simplerepr') + assert to_text_dict_repr == repr(d) + + to_text_bytes_passthru: str = to_text(b, nonstring='passthru') + to_text_str_passthru: str = to_text(s, nonstring='passthru') + to_text_dict_passthru: dict[str, str] = to_text(d, nonstring='passthru') + assert to_text_dict_passthru == d + + to_text_bytes_empty: str = to_text(b, nonstring='empty') + to_text_str_empty: str = to_text(s, nonstring='empty') + to_text_dict_empty: str = to_text(d, nonstring='empty') + assert to_text_dict_empty == '' + + to_text_bytes_strict: str = to_text(b, nonstring='strict') + to_text_str_strict: str = to_text(s, nonstring='strict') + with pytest.raises(TypeError): + to_text_dict_strict: str = to_text(d, nonstring='strict')