Add typing to to_text and to_bytes, improve typing and type juggling in DataLoader (#85746)

pull/85808/head
Matt Martz 3 months ago committed by GitHub
parent 4209d714db
commit c59db5349e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -0,0 +1,2 @@
minor_changes:
- Added Python type hints to the ``to_text`` and ``to_bytes`` functions so that code calling these functions gets more accurate static type checking.

@ -339,12 +339,12 @@ def verify_local_collection(local_collection, remote_collection, artifacts_manag
] ]
# Find any paths not in the FILES.json # Find any paths not in the FILES.json
for root, dirs, files in os.walk(b_collection_path): for root, dirs, filenames in os.walk(b_collection_path):
for name in files: for name in filenames:
full_path = os.path.join(root, name) full_path = os.path.join(root, name)
path = to_text(full_path[len(b_collection_path) + 1::], errors='surrogate_or_strict') path = to_text(full_path[len(b_collection_path) + 1::], errors='surrogate_or_strict')
if any(fnmatch.fnmatch(full_path, b_pattern) for b_pattern in b_ignore_patterns): if any(fnmatch.fnmatch(full_path, b_pattern) for b_pattern in b_ignore_patterns):
display.v("Ignoring verification for %s" % full_path) display.v("Ignoring verification for %s" % to_text(full_path))
continue continue
if full_path not in collection_files: if full_path not in collection_files:

@ -8,9 +8,9 @@ from __future__ import annotations
import codecs import codecs
import json import json
from ansible.module_utils.compat import typing as _t
from ansible.module_utils._internal import _no_six from ansible.module_utils._internal import _no_six
try: try:
codecs.lookup_error('surrogateescape') codecs.lookup_error('surrogateescape')
HAS_SURROGATEESCAPE = True HAS_SURROGATEESCAPE = True
@ -22,8 +22,54 @@ _COMPOSED_ERROR_HANDLERS = frozenset((None, 'surrogate_or_replace',
'surrogate_or_strict', 'surrogate_or_strict',
'surrogate_then_replace')) 'surrogate_then_replace'))
# Type variable standing in for the caller's ``obj`` type; the 'passthru' overloads of
# to_bytes/to_text are annotated as returning this same type.
_T = _t.TypeVar('_T')
# The ``nonstring`` value for which a non-string ``obj`` is returned as-is.
_NonStringPassthru: _t.TypeAlias = _t.Literal['passthru']
# The remaining ``nonstring`` values; overloads annotated with this alias return bytes (to_bytes) or str (to_text).
_NonStringOther: _t.TypeAlias = _t.Literal['simplerepr', 'empty', 'strict']
# Union of the two aliases above; used to annotate ``nonstring`` on the to_bytes/to_text implementations.
_NonStringAll: _t.TypeAlias = _t.Union[_NonStringPassthru, _NonStringOther]
# Overload 1: ``nonstring`` is not passed -> the result is typed as bytes for any ``obj``.
@_t.overload
def to_bytes(
obj: object,
encoding: str = 'utf-8',
errors: str | None = None,
) -> bytes: ...
# Overload 2: ``obj`` is bytes or str and nonstring is 'passthru' -> bytes.
# NOTE(review): the 'passthru' default written in this stub differs from the
# implementation's default ('simplerepr') - confirm the stub default is intentional.
@_t.overload
def to_bytes(
obj: bytes | str,
encoding: str = 'utf-8',
errors: str | None = None,
nonstring: _NonStringPassthru = 'passthru',
) -> bytes: ...
# Overload 3: any other ``obj`` with nonstring 'passthru' -> typed as the input's own type.
@_t.overload
def to_bytes(
obj: _T,
encoding: str = 'utf-8',
errors: str | None = None,
nonstring: _NonStringPassthru = 'passthru',
) -> _T: ...
# Overload 4: nonstring is 'simplerepr', 'empty' or 'strict' -> bytes for any ``obj``.
@_t.overload
def to_bytes(
obj: object,
encoding: str = 'utf-8',
errors: str | None = None,
nonstring: _NonStringOther = 'simplerepr',
) -> bytes: ...
def to_bytes(obj, encoding='utf-8', errors=None, nonstring='simplerepr'):
def to_bytes(
obj: _T,
encoding: str = 'utf-8',
errors: str | None = None,
nonstring: _NonStringAll = 'simplerepr'
) -> _T | bytes:
"""Make sure that a string is a byte string """Make sure that a string is a byte string
:arg obj: An object to make sure is a byte string. In most cases this :arg obj: An object to make sure is a byte string. In most cases this
@ -81,7 +127,7 @@ def to_bytes(obj, encoding='utf-8', errors=None, nonstring='simplerepr'):
string is valid in the specified encoding. If it's important that the string is valid in the specified encoding. If it's important that the
byte string is in the specified encoding do:: byte string is in the specified encoding do::
encoded_string = to_bytes(to_text(input_string, 'latin-1'), 'utf-8') encoded_string = to_bytes(to_text(input_string, encoding='latin-1'), encoding='utf-8')
.. version_changed:: 2.3 .. version_changed:: 2.3
@ -126,21 +172,60 @@ def to_bytes(obj, encoding='utf-8', errors=None, nonstring='simplerepr'):
value = repr(obj) value = repr(obj)
except UnicodeError: except UnicodeError:
# Giving up # Giving up
return to_bytes('') return b''
elif nonstring == 'passthru': elif nonstring == 'passthru':
return obj return obj
elif nonstring == 'empty': elif nonstring == 'empty':
# python2.4 doesn't have b'' return b''
return to_bytes('')
elif nonstring == 'strict': elif nonstring == 'strict':
raise TypeError('obj must be a string type') raise TypeError('obj must be a string type')
else: else:
raise TypeError('Invalid value %s for to_bytes\' nonstring parameter' % nonstring) raise TypeError('Invalid value %s for to_bytes\' nonstring parameter' % nonstring)
return to_bytes(value, encoding, errors) return to_bytes(value, encoding=encoding, errors=errors)
# Overload 1: ``nonstring`` is not passed -> the result is typed as str for any ``obj``.
@_t.overload
def to_text(
obj: object,
encoding: str = 'utf-8',
errors: str | None = None,
) -> str: ...
# Overload 2: ``obj`` is str or bytes and nonstring is 'passthru' -> str.
# NOTE(review): the 'passthru' default written in this stub differs from the
# implementation's default ('simplerepr') - confirm the stub default is intentional.
@_t.overload
def to_text(
obj: str | bytes,
encoding: str = 'utf-8',
errors: str | None = None,
nonstring: _NonStringPassthru = 'passthru',
) -> str: ...
# Overload 3: any other ``obj`` with nonstring 'passthru' -> typed as the input's own type.
@_t.overload
def to_text(
obj: _T,
encoding: str = 'utf-8',
errors: str | None = None,
nonstring: _NonStringPassthru = 'passthru',
) -> _T: ...
# Overload 4: nonstring is 'simplerepr', 'empty' or 'strict' -> str for any ``obj``.
@_t.overload
def to_text(
obj: object,
encoding: str = 'utf-8',
errors: str | None = None,
nonstring: _NonStringOther = 'simplerepr',
) -> str: ...
def to_text(obj, encoding='utf-8', errors=None, nonstring='simplerepr'): def to_text(
obj: _T,
encoding: str = 'utf-8',
errors: str | None = None,
nonstring: _NonStringAll = 'simplerepr'
) -> _T | str:
"""Make sure that a string is a text string """Make sure that a string is a text string
:arg obj: An object to make sure is a text string. In most cases this :arg obj: An object to make sure is a text string. In most cases this
@ -218,17 +303,17 @@ def to_text(obj, encoding='utf-8', errors=None, nonstring='simplerepr'):
value = repr(obj) value = repr(obj)
except UnicodeError: except UnicodeError:
# Giving up # Giving up
return u'' return ''
elif nonstring == 'passthru': elif nonstring == 'passthru':
return obj return obj
elif nonstring == 'empty': elif nonstring == 'empty':
return u'' return ''
elif nonstring == 'strict': elif nonstring == 'strict':
raise TypeError('obj must be a string type') raise TypeError('obj must be a string type')
else: else:
raise TypeError('Invalid value %s for to_text\'s nonstring parameter' % nonstring) raise TypeError('Invalid value %s for to_text\'s nonstring parameter' % nonstring)
return to_text(value, encoding, errors) return to_text(value, encoding=encoding, errors=errors)
to_native = to_text to_native = to_text

@ -31,7 +31,7 @@ display = Display()
# Tries to determine if a path is inside a role, last dir must be 'tasks' # Tries to determine if a path is inside a role, last dir must be 'tasks'
# this is not perfect but people should really avoid 'tasks' dirs outside roles when using Ansible. # this is not perfect but people should really avoid 'tasks' dirs outside roles when using Ansible.
RE_TASKS = re.compile(u'(?:^|%s)+tasks%s?$' % (os.path.sep, os.path.sep)) RE_TASKS = re.compile('(?:^|%s)+tasks%s?$' % (os.path.sep, os.path.sep))
class DataLoader: class DataLoader:
@ -53,23 +53,22 @@ class DataLoader:
ds = dl.load_from_file('/path/to/file') ds = dl.load_from_file('/path/to/file')
""" """
def __init__(self): def __init__(self) -> None:
self._basedir = '.' self._basedir: str = '.'
# NOTE: not effective with forks as the main copy does not get updated. # NOTE: not effective with forks as the main copy does not get updated.
# avoids rereading files # avoids rereading files
self._FILE_CACHE = dict() self._FILE_CACHE: dict[str, object] = {}
# NOTE: not thread safe, also issues with forks not returning data to main proc # NOTE: not thread safe, also issues with forks not returning data to main proc
# so they need to be cleaned independently. See WorkerProcess for example. # so they need to be cleaned independently. See WorkerProcess for example.
# used to keep track of temp files for cleaning # used to keep track of temp files for cleaning
self._tempfiles = set() self._tempfiles: set[str] = set()
# initialize the vault stuff with an empty password # initialize the vault stuff with an empty password
# TODO: replace with a ref to something that can get the password # TODO: replace with a ref to something that can get the password
# a creds/auth provider # a creds/auth provider
self._vaults = {}
self._vault = VaultLib() self._vault = VaultLib()
self.set_vault_secrets(None) self.set_vault_secrets(None)
@ -229,23 +228,19 @@ class DataLoader:
def set_basedir(self, basedir: str) -> None: def set_basedir(self, basedir: str) -> None:
""" sets the base directory, used to find files when a relative path is given """ """ sets the base directory, used to find files when a relative path is given """
self._basedir = basedir
if basedir is not None:
self._basedir = to_text(basedir)
def path_dwim(self, given: str) -> str: def path_dwim(self, given: str) -> str:
""" """
make relative paths work like folks expect. make relative paths work like folks expect.
""" """
given = to_text(given, errors='surrogate_or_strict')
given = unquote(given) given = unquote(given)
if given.startswith(to_text(os.path.sep)) or given.startswith(u'~'): if given.startswith(os.path.sep) or given.startswith('~'):
path = given path = given
else: else:
basedir = to_text(self._basedir, errors='surrogate_or_strict') path = os.path.join(self._basedir, given)
path = os.path.join(basedir, given)
return unfrackpath(path, follow=False) return unfrackpath(path, follow=False)
@ -293,10 +288,9 @@ class DataLoader:
""" """
search = [] search = []
source = to_text(source, errors='surrogate_or_strict')
# I have full path, nothing else needs to be looked at # I have full path, nothing else needs to be looked at
if source.startswith(to_text(os.path.sep)) or source.startswith(u'~'): if source.startswith(os.path.sep) or source.startswith('~'):
search.append(unfrackpath(source, follow=False)) search.append(unfrackpath(source, follow=False))
else: else:
# base role/play path + templates/files/vars + relative filename # base role/play path + templates/files/vars + relative filename
@ -363,7 +357,7 @@ class DataLoader:
if os.path.exists(to_bytes(test_path, errors='surrogate_or_strict')): if os.path.exists(to_bytes(test_path, errors='surrogate_or_strict')):
result = test_path result = test_path
else: else:
display.debug(u'evaluation_path:\n\t%s' % '\n\t'.join(paths)) display.debug('evaluation_path:\n\t%s' % '\n\t'.join(paths))
for path in paths: for path in paths:
upath = unfrackpath(path, follow=False) upath = unfrackpath(path, follow=False)
b_upath = to_bytes(upath, errors='surrogate_or_strict') b_upath = to_bytes(upath, errors='surrogate_or_strict')
@ -384,9 +378,9 @@ class DataLoader:
search.append(os.path.join(to_bytes(self.get_basedir(), errors='surrogate_or_strict'), b_dirname, b_source)) search.append(os.path.join(to_bytes(self.get_basedir(), errors='surrogate_or_strict'), b_dirname, b_source))
search.append(os.path.join(to_bytes(self.get_basedir(), errors='surrogate_or_strict'), b_source)) search.append(os.path.join(to_bytes(self.get_basedir(), errors='surrogate_or_strict'), b_source))
display.debug(u'search_path:\n\t%s' % to_text(b'\n\t'.join(search))) display.debug('search_path:\n\t%s' % to_text(b'\n\t'.join(search)))
for b_candidate in search: for b_candidate in search:
display.vvvvv(u'looking for "%s" at "%s"' % (source, to_text(b_candidate))) display.vvvvv('looking for "%s" at "%s"' % (source, to_text(b_candidate)))
if os.path.exists(b_candidate): if os.path.exists(b_candidate):
result = to_text(b_candidate) result = to_text(b_candidate)
break break
@ -420,8 +414,7 @@ class DataLoader:
if not file_path or not isinstance(file_path, (bytes, str)): if not file_path or not isinstance(file_path, (bytes, str)):
raise AnsibleParserError("Invalid filename: '%s'" % to_native(file_path)) raise AnsibleParserError("Invalid filename: '%s'" % to_native(file_path))
b_file_path = to_bytes(file_path, errors='surrogate_or_strict') if not self.path_exists(file_path) or not self.is_file(file_path):
if not self.path_exists(b_file_path) or not self.is_file(b_file_path):
raise AnsibleFileNotFound(file_name=file_path) raise AnsibleFileNotFound(file_name=file_path)
real_path = self.path_dwim(file_path) real_path = self.path_dwim(file_path)
@ -479,7 +472,7 @@ class DataLoader:
""" """
b_path = to_bytes(os.path.join(path, name)) b_path = to_bytes(os.path.join(path, name))
found = [] found: list[str] = []
if extensions is None: if extensions is None:
# Look for file with no extension first to find dir before file # Look for file with no extension first to find dir before file
@ -488,27 +481,29 @@ class DataLoader:
for ext in extensions: for ext in extensions:
if '.' in ext: if '.' in ext:
full_path = b_path + to_bytes(ext) b_full_path = b_path + to_bytes(ext)
elif ext: elif ext:
full_path = b'.'.join([b_path, to_bytes(ext)]) b_full_path = b'.'.join([b_path, to_bytes(ext)])
else: else:
full_path = b_path b_full_path = b_path
full_path = to_text(b_full_path)
if self.path_exists(full_path): if self.path_exists(full_path):
if self.is_directory(full_path): if self.is_directory(full_path):
if allow_dir: if allow_dir:
found.extend(self._get_dir_vars_files(to_text(full_path), extensions)) found.extend(self._get_dir_vars_files(full_path, extensions))
else: else:
continue continue
else: else:
found.append(to_text(full_path)) found.append(full_path)
break break
return found return found
def _get_dir_vars_files(self, path: str, extensions: list[str]) -> list[str]: def _get_dir_vars_files(self, path: str, extensions: list[str]) -> list[str]:
found = [] found = []
for spath in sorted(self.list_directory(path)): for spath in sorted(self.list_directory(path)):
if not spath.startswith(u'.') and not spath.endswith(u'~'): # skip hidden and backups if not spath.startswith('.') and not spath.endswith('~'): # skip hidden and backups
ext = os.path.splitext(spath)[-1] ext = os.path.splitext(spath)[-1]
full_spath = os.path.join(path, spath) full_spath = os.path.join(path, spath)

@ -89,7 +89,6 @@ import tomllib
from collections.abc import MutableMapping, MutableSequence from collections.abc import MutableMapping, MutableSequence
from ansible.errors import AnsibleFileNotFound, AnsibleParserError from ansible.errors import AnsibleFileNotFound, AnsibleParserError
from ansible.module_utils.common.text.converters import to_bytes, to_native
from ansible.plugins.inventory import BaseFileInventoryPlugin from ansible.plugins.inventory import BaseFileInventoryPlugin
from ansible.utils.display import Display from ansible.utils.display import Display
@ -147,10 +146,9 @@ class InventoryModule(BaseFileInventoryPlugin):
def _load_file(self, file_name): def _load_file(self, file_name):
if not file_name or not isinstance(file_name, str): if not file_name or not isinstance(file_name, str):
raise AnsibleParserError("Invalid filename: '%s'" % to_native(file_name)) raise AnsibleParserError("Invalid filename: '%s'" % file_name)
b_file_name = to_bytes(self.loader.path_dwim(file_name)) if not self.loader.path_exists(file_name):
if not self.loader.path_exists(b_file_name):
raise AnsibleFileNotFound("Unable to retrieve file contents", file_name=file_name) raise AnsibleFileNotFound("Unable to retrieve file contents", file_name=file_name)
try: try:

@ -45,3 +45,62 @@ def test_to_bytes(in_string, encoding, expected):
def test_to_native(in_string, encoding, expected): def test_to_native(in_string, encoding, expected):
"""test happy path of encoding to native strings""" """test happy path of encoding to native strings"""
assert to_native(in_string, encoding) == expected assert to_native(in_string, encoding) == expected
def test_type_hints() -> None:
    """Call ``to_bytes`` and ``to_text`` in every ``nonstring`` mode with annotated results.

    The annotated assignments are here so a static type checker validates the
    overloads; the runtime assertions only confirm the results for the
    non-string (dict) input.
    """
    mapping: dict[str, str] = {'k': 'v'}
    text: str = 's'
    raw: bytes = b'b'

    # --- to_bytes, nonstring omitted ---
    b_default_from_bytes: bytes = to_bytes(raw)
    b_default_from_str: bytes = to_bytes(text)
    b_default_from_dict: bytes = to_bytes(mapping)
    assert b_default_from_dict == repr(mapping).encode('utf-8')

    # --- to_bytes, nonstring='simplerepr' ---
    b_repr_from_bytes: bytes = to_bytes(raw, nonstring='simplerepr')
    b_repr_from_str: bytes = to_bytes(text, nonstring='simplerepr')
    b_repr_from_dict: bytes = to_bytes(mapping, nonstring='simplerepr')
    assert b_repr_from_dict == repr(mapping).encode('utf-8')

    # --- to_bytes, nonstring='passthru' (a dict comes back as a dict) ---
    b_passthru_from_bytes: bytes = to_bytes(raw, nonstring='passthru')
    b_passthru_from_str: bytes = to_bytes(text, nonstring='passthru')
    b_passthru_from_dict: dict[str, str] = to_bytes(mapping, nonstring='passthru')
    assert b_passthru_from_dict == mapping

    # --- to_bytes, nonstring='empty' ---
    b_empty_from_bytes: bytes = to_bytes(raw, nonstring='empty')
    b_empty_from_str: bytes = to_bytes(text, nonstring='empty')
    b_empty_from_dict: bytes = to_bytes(mapping, nonstring='empty')
    assert b_empty_from_dict == b''

    # --- to_bytes, nonstring='strict' (a non-string input must raise) ---
    b_strict_from_bytes: bytes = to_bytes(raw, nonstring='strict')
    b_strict_from_str: bytes = to_bytes(text, nonstring='strict')
    with pytest.raises(TypeError):
        b_strict_from_dict: bytes = to_bytes(mapping, nonstring='strict')

    # --- to_text, nonstring omitted ---
    t_default_from_bytes: str = to_text(raw)
    t_default_from_str: str = to_text(text)
    t_default_from_dict: str = to_text(mapping)
    assert t_default_from_dict == repr(mapping)

    # --- to_text, nonstring='simplerepr' ---
    t_repr_from_bytes: str = to_text(raw, nonstring='simplerepr')
    t_repr_from_str: str = to_text(text, nonstring='simplerepr')
    t_repr_from_dict: str = to_text(mapping, nonstring='simplerepr')
    assert t_repr_from_dict == repr(mapping)

    # --- to_text, nonstring='passthru' (a dict comes back as a dict) ---
    t_passthru_from_bytes: str = to_text(raw, nonstring='passthru')
    t_passthru_from_str: str = to_text(text, nonstring='passthru')
    t_passthru_from_dict: dict[str, str] = to_text(mapping, nonstring='passthru')
    assert t_passthru_from_dict == mapping

    # --- to_text, nonstring='empty' ---
    t_empty_from_bytes: str = to_text(raw, nonstring='empty')
    t_empty_from_str: str = to_text(text, nonstring='empty')
    t_empty_from_dict: str = to_text(mapping, nonstring='empty')
    assert t_empty_from_dict == ''

    # --- to_text, nonstring='strict' (a non-string input must raise) ---
    t_strict_from_bytes: str = to_text(raw, nonstring='strict')
    t_strict_from_str: str = to_text(text, nonstring='strict')
    with pytest.raises(TypeError):
        t_strict_from_dict: str = to_text(mapping, nonstring='strict')

Loading…
Cancel
Save