diff --git a/changelogs/fragments/deprecate_api.yml b/changelogs/fragments/deprecate_api.yml deleted file mode 100644 index 41429413ec4..00000000000 --- a/changelogs/fragments/deprecate_api.yml +++ /dev/null @@ -1,3 +0,0 @@ ---- -deprecated_features: - - fact_cache - deprecate first_order_merge API (https://github.com/ansible/ansible/pull/84568). diff --git a/changelogs/fragments/fix-is-filter-is-test.yml b/changelogs/fragments/fix-is-filter-is-test.yml new file mode 100644 index 00000000000..e6563846537 --- /dev/null +++ b/changelogs/fragments/fix-is-filter-is-test.yml @@ -0,0 +1,3 @@ +bugfixes: + - Correctly return ``False`` when using the ``filter`` and ``test`` Jinja tests on plugin names which are not filters or tests, respectively. + (resolves issue https://github.com/ansible/ansible/issues/82084) diff --git a/changelogs/fragments/templates_types_datatagging.yml b/changelogs/fragments/templates_types_datatagging.yml new file mode 100644 index 00000000000..6a5e565bb71 --- /dev/null +++ b/changelogs/fragments/templates_types_datatagging.yml @@ -0,0 +1,179 @@ +# DTFIX-RELEASE: document EncryptedString replacing AnsibleVaultEncryptedUnicode + +major_changes: + - variables - The type system underlying Ansible's variable storage has been significantly overhauled and formalized. + Attempts to store unsupported Python object types in variables will now result in an error. # DTFIX-RELEASE: link to type system docs TBD + - variables - To support new Ansible features, many variable objects are now represented by subclasses of their respective native Python types. + In most cases, they behave indistinguishably from their original types, but some Python libraries do not handle builtin object subclasses properly. + Custom plugins that interact with such libraries may require changes to convert and pass the native types. # DTFIX-RELEASE: link to plugin/data tagging API docs TBD + - ansible-test - Packages beneath ``module_utils`` can now contain ``__init__.py`` files. + - Jinja plugins - Jinja builtin filter and test plugins are now accessible via their fully-qualified names ``ansible.builtin.{name}``. + +minor_changes: + - templating - Templating errors now provide more information about both the location and context of the error, especially for deeply-nested and/or indirected templating scenarios. + - templating - Handling of omitted values is now a first-class feature of the template engine, and is usable in all Ansible Jinja template contexts. + Any template that resolves to ``omit`` is automatically removed from its parent container during templating. # DTFIX-RELEASE: porting guide entry + - templating - Unified ``omit`` behavior now requires that plugins calling ``Templar.template()`` handle cases where the entire template result is omitted, + by catching the ``AnsibleValueOmittedError`` that is raised. + Previously, this condition caused a randomly-generated string marker to appear in the template result. # DTFIX-RELEASE: porting guide entry? + - templating - Template evaluation is lazier than in previous versions. + Template expressions which resolve only portions of a data structure no longer result in the entire structure being templated. + - handlers - Templated handler names with syntax errors, or that resolve to ``omit`` are now skipped like handlers with undefined variables in their name. + - env lookup - The error message generated for a missing environment variable when ``default`` is an undefined value (e.g. 
``undef('something')``) will contain the hint from that undefined value, + except when the undefined value is the default of ``undef()`` with no arguments. Previously, any existing undefined hint would be ignored. + - templating - Embedding ``range()`` values in containers such as lists will result in an error on use. + Previously the value would be converted to a string representing the range parameters, such as ``range(0, 3)``. + - Jinja plugins - Plugins can declare support for undefined values. # DTFIX-RELEASE: examples, porting guide entry + - templating - Variables of type ``set`` and ``tuple`` are now converted to ``list`` when exiting the final pass of templating. + - templating - Access to an undefined variable from inside a lookup, filter, or test (which raises MarkerError) no longer ends processing of the current template. + The triggering undefined value is returned as the result of the offending plugin invocation, and the template continues to execute. # DTFIX-RELEASE: porting guide entry, samples needed + - plugin error handling - When raising exceptions in an exception handler, be sure to use ``raise ... from`` as appropriate. + This supersedes the use of the ``AnsibleError`` arg ``orig_exc`` to represent the cause. + Specifying ``orig_exc`` as the cause is still permitted. + Failure to use ``raise ... from`` when ``orig_exc`` is set will result in a warning. + Additionally, if the two cause exceptions do not match, a warning will be issued. # DTFIX-RELEASE: this needs a porting guide entry + - ansible-test - The ``yamllint`` sanity test now enforces string values for the ``!vault`` tag. + - warnings - All warnings (including deprecation warnings) issued during a task's execution are now accessible via the ``warnings`` and ``deprecations`` keys on the task result. + - troubleshooting - Tracebacks can be collected and displayed for most errors, warnings, and deprecation warnings (including those generated by modules). + Tracebacks are no longer enabled with ``-vvv``; the behavior is directly configurable via the ``DISPLAY_TRACEBACK`` config option. + Module tracebacks passed to ``fail_json`` via the ``exception`` kwarg will not be included in the task result unless error tracebacks are configured. + - display - Deduplication of warning and error messages considers the full content of the message (including source and traceback contexts, if enabled). + This may result in fewer messages being omitted. + - modules - Unhandled exceptions during Python module execution are now returned as structured data from the target. + This allows the new traceback handling to be applied to exceptions raised on targets. + - modules - PowerShell modules can now receive ``datetime.date``, ``datetime.time`` and ``datetime.datetime`` values as ISO 8601 strings. + - modules - PowerShell modules can now receive strings sourced from inline vault-encrypted strings. + - from_json filter - The filter accepts a ``profile`` argument, which defaults to ``tagless``. + - to_json / to_nice_json filters - The filters accept a ``profile`` argument, which defaults to ``tagless``. + - undef jinja function - The ``undef`` jinja function now raises an error if a non-string hint is given. + Attempting to use an undefined hint also results in an error, ensuring incorrect use of the function can be distinguished from the function's normal behavior. + - display - The ``collection_name`` arg to ``Display.deprecated`` no longer has any effect. 
+ Information about the calling plugin is automatically captured by the display infrastructure, included in the displayed messages, and made available to callbacks. + - modules - The ``collection_name`` arg to Python module-side ``deprecate`` methods no longer has any effect. + Information about the calling module is automatically captured by the warning infrastructure and included in the module result. + +breaking_changes: + - loops - Omit placeholders no longer leak between loop item templating and task templating. + Previously, ``omit`` placeholders could remain embedded in loop items after templating and be used as an ``omit`` for task templating. + Now, values resolving to ``omit`` are dropped immediately when loop items are templated. + To turn missing values into an ``omit`` for task templating, use ``| default(omit)``. + This solution is backwards compatible with previous versions of ansible-core. # DTFIX-RELEASE: porting guide entry with examples + - serialization of ``omit`` sentinel - Serialization of variables containing ``omit`` sentinels (e.g., by the ``to_json`` and ``to_yaml`` filters or ``ansible-inventory``) will fail if the variable has not completed templating. + Previously, serialization succeeded with placeholder strings emitted in the serialized output. + - conditionals - Conditional expressions that result in non-boolean values are now an error by default. + Such results often indicate unintentional use of templates where they are not supported, resulting in a conditional that is always true. + The error can be temporarily changed to a deprecation warning by enabling the ``ALLOW_BROKEN_CONDITIONALS`` config option. + When this option is enabled, conditional expressions which are a literal ``None`` or empty string will evaluate as true, for backwards compatibility. + - templating - Templates are always rendered in Jinja2 native mode. + As a result, non-string values are no longer automatically converted to strings. + - templating - Templates with embedded inline templates that were not contained within a Jinja string constant now result in an error, as support for multi-pass templating was removed for security reasons. + In most cases, such templates can be easily rewritten to avoid the use of embedded inline templates. + - templating - Conditionals and lookups which use embedded inline templates in Jinja string constants now display a warning. + These templates should be converted to their expression equivalent. + - templating - Templates resulting in ``None`` are no longer automatically converted to an empty string. + - template lookup - The ``convert_data`` option is deprecated and no longer has any effect. + Use the ``from_json`` filter on the lookup result instead. + - templating - ``#jinja2:`` overrides in templates with invalid override names or types are now templating errors. + - set_fact - The string values "yes", "no", "true" and "false" were previously converted (ignoring case) to boolean values when not using Jinja2 native mode. + Since Jinja2 native mode is always used, this conversion no longer occurs. + When boolean values are required, native boolean syntax should be used where variables are defined, such as in YAML. + When native boolean syntax is not an option, the ``bool`` filter can be used to parse string values into booleans. + - templating - The ``allow_unsafe_lookups`` option no longer has any effect. + Lookup plugins are responsible for tagging strings containing templates to allow evaluation as a template.
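A rough sketch of the trust-tagging responsibility described in the entry above, for a lookup plugin (illustrative only; it reuses the ``TrustedAsTemplate`` tag introduced later in this changeset, and the helper name and import path are assumptions, not part of this diff):

    # hypothetical helper inside a lookup plugin (import path assumed)
    from ansible._internal._datatag._tags import TrustedAsTemplate

    def trust(value: str) -> str:
        # tag() returns a tagged copy of the string; untagged strings are
        # ignored by the template engine rather than evaluated as templates
        return TrustedAsTemplate().tag(value)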
+ - assert - The ``quiet`` argument must be a commonly-accepted boolean value. + Previously, unrecognized values were silently treated as False. + - plugins - Any plugin that sources or creates templates must properly tag them as trusted. # DTFIX-RELEASE: porting guide entry for "how?" Don't forget to mention inventory plugin ``trusted_by_default`` config. + - first_found lookup - When specifying ``files`` or ``paths`` as a templated list containing undefined values, the undefined list elements will be discarded with a warning. + Previously, the entire list would be discarded without any warning. + - templating - The result of the ``range()`` global function cannot be returned from a template; it should always be passed to a filter (e.g., ``random``). + Previously, range objects returned from an intermediate template were always converted to a list, which is inconsistent with inline consumption of range objects. + - plugins - Custom Jinja plugins that accept undefined top-level arguments must opt in to receiving them. # DTFIX-RELEASE: porting guide entry + backcompat behavior description + - plugins - Custom Jinja plugins that use ``environment.getitem`` to retrieve undefined values will now trigger a ``MarkerError`` exception. + This exception must be handled to allow the plugin to return a ``Marker``, or the plugin must opt in to accepting ``Marker`` values. # DTFIX-RELEASE: mention the decorator + - templating - Many Jinja plugins (filters, lookups, tests) and methods previously silently ignored undefined inputs, which often masked subtle errors. + Passing an undefined argument to a Jinja plugin or method that does not declare undefined support now results in an undefined value. # DTFIX-RELEASE: common examples, porting guide, `is defined`, `is undefined`, etc; porting guide should also mention that overly-broad exception handling may mask Undefined errors; also that lazy handling of Undefined can invoke a plugin and bomb out in the middle where it was previously never invoked (plugins with side effects, just don't) + - lookup plugins - Lookup plugins called as ``with_(lookup)`` will no longer have the ``_subdir`` attribute set. # DTFIX-RELEASE: porting guide re: `ansible_lookup_context` + - lookup plugins - ``terms`` will always be passed to ``run`` as the first positional arg, where previously it was sometimes passed as a keyword arg when using ``with_`` syntax (a sketch of the resulting ``run`` signature appears below). + - callback plugins - The structure of the ``exception``, ``warnings`` and ``deprecations`` values visible to callbacks has changed. Callbacks that inspect or serialize these values may require special handling. # DTFIX-RELEASE: porting guide re ErrorDetail/WarningMessageDetail/DeprecationMessageDetail + - modules - Ansible modules using ``sys.excepthook`` must use a standard ``try/except`` instead. + - templating - Access to ``_`` prefixed attributes and methods, and methods with known side effects, is no longer permitted. + In cases where a matching mapping key is present, the associated value will be returned instead of an error. + This increases template environment isolation and ensures more consistent behavior between the ``.`` and ``[]`` operators. + - inventory - Invalid variable names provided by inventories result in an inventory parse failure. This behavior is now consistent with other variable name usages throughout Ansible. + - internals - The ``ansible.utils.native_jinja`` Python module has been removed.
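A minimal sketch of a lookup plugin ``run`` method under the calling-convention changes above (illustrative, not code from this changeset; ``LookupBase`` is the existing public base class):

    from ansible.plugins.lookup import LookupBase

    class LookupModule(LookupBase):
        def run(self, terms, variables=None, **kwargs):
            # `terms` now always arrives as the first positional argument,
            # for both lookup()/query() calls and with_ syntax
            return [str(term) for term in terms]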
+ - internals - The ``AnsibleLoader`` and ``AnsibleDumper`` classes for working with YAML are now factory functions and cannot be extended. + - public API - The ``ansible.vars.fact_cache.FactCache`` wrapper has been removed. + +security_fixes: + - templating - Ansible's template engine no longer processes Jinja templates in strings unless they are marked as coming from a trusted source. + Untrusted strings containing Jinja template markers are ignored with a warning. + Examples of trusted sources include playbooks, vars files, and many inventory sources. + Examples of untrusted sources include module results and facts. + Plugins which have not been updated to preserve trust while manipulating strings may inadvertently cause them to lose their trusted status. + - templating - Changes to conditional expression handling removed numerous instances of insecure multi-pass templating (which could result in execution of untrusted template expressions). + +known_issues: + - variables - The values ``None``, ``True`` and ``False`` cannot be tagged because they are singletons. Attempts to apply tags to these values will be silently ignored. + - variables - Tagged values cannot be used for dictionary keys in many circumstances. # DTFIX-RELEASE: Explain this in more detail. + - templating - Any string value starting with ``#jinja2:`` which is templated will always be interpreted as Jinja2 configuration overrides. + To include this literal value at the start of a string, a space or other character must precede it. + +bugfixes: + - module defaults - Module defaults are no longer templated unless they are used by a task that does not override them. + Previously, all module defaults for all modules were templated for every task. + - omitting task args - Use of omit for task args now properly falls back to args of lower precedence, such as module defaults. + Previously, an omitted value would obliterate values of lower precedence. # DTFIX-RELEASE: do we need obliterate, is this a breaking change? + - regex_search filter - Corrected return value documentation to reflect None (not empty string) for no match. + - first_found lookup - Corrected return value documentation to reflect None (not empty string) for no files found. + - vars lookup - The ``default`` substitution only applies when trying to look up a variable which is not defined. + If the variable is defined, but templates to an undefined value, the ``default`` substitution will not apply. + Use the ``default`` filter to coerce those values instead. + - to_yaml/to_nice_yaml filters - Eliminated possibility of keyword arg collisions with internally-set defaults. + - Jinja plugins - Errors raised will always be derived from ``AnsibleTemplatePluginError`` (a usage sketch appears below). + - ansible-test - Fixed traceback when handling certain YAML errors in the ``yamllint`` sanity test. + - YAML parsing - The ``!unsafe`` tag no longer coerces non-string scalars to strings. + - default callback - Error context is now shown for failing tasks that use the ``debug`` action. + - module arg templating - When using a templated raw task arg and a templated ``args`` keyword, args are now merged. + Previously, use of templated raw task args silently ignored all values from the templated ``args`` keyword. + - action plugins - Action plugins that raise unhandled exceptions no longer terminate playbook loops. Previously, exceptions raised by an action plugin caused abnormal loop termination and loss of loop iteration results.
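A hedged sketch of the catch pattern enabled by the ``AnsibleTemplatePluginError`` bugfix above (the import location and helper are assumptions, not part of this changeset):

    from ansible.errors import AnsibleTemplatePluginError  # import location assumed

    def render_or_default(templar, expression, default=None):
        # hypothetical helper: fall back to `default` when any
        # filter/test/lookup plugin raises during rendering
        try:
            return templar.template(expression)
        except AnsibleTemplatePluginError:
            # all template plugin failures now derive from this one base type
            return default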
+ - display - The ``Display.deprecated`` method once again properly handles the ``removed=True`` argument (https://github.com/ansible/ansible/issues/82358). + - stability - Fixed silent process failure on unhandled IOError/OSError under ``linear`` strategy. + - lookup plugins - The ``terms`` arg to the ``run`` method is now always a list. + Previously, there were cases where a non-list could be received. + +deprecated_features: + - templating - The ``ansible_managed`` variable available for certain templating scenarios, such as the ``template`` action and ``template`` lookup, has been deprecated. + Define and use a custom variable instead of relying on ``ansible_managed``. + - display - The ``Display.get_deprecation_message`` method has been deprecated. + Call ``Display.deprecated`` to display a deprecation message, or call it with ``removed=True`` to raise an ``AnsibleError``. + - config - The ``DEFAULT_JINJA2_NATIVE`` option has no effect. + Jinja2 native mode is now the default and only option. + - config - The ``DEFAULT_NULL_REPRESENTATION`` option has no effect. + Null values are no longer automatically converted to another value during templating of single variable references. + - template lookup - The ``jinja2_native`` option is no longer used in the Ansible Core code base. + Jinja2 native mode is now the default and only option. + - conditionals - Conditionals using Jinja templating delimiters (e.g., ``{{``, ``{%``) should be rewritten as expressions without delimiters, unless the entire conditional value is a single template that resolves to a trusted string expression. + This is useful for dynamic indirection of conditional expressions, but is limited to trusted literal string expressions. + - templating - The ``disable_lookups`` option has no effect, since plugins must be updated to apply trust before any templating can be performed. + - to_yaml/to_nice_yaml filters - Implicit YAML dumping of vaulted value ciphertext is deprecated. + Set ``dump_vault_tags`` to explicitly specify the desired behavior. + - plugins - The ``listify_lookup_plugin_terms`` function is obsolete and in most cases no longer needed. # DTFIX-RELEASE: add a porting guide entry for this + - plugin error handling - The ``AnsibleError`` constructor arg ``suppress_extended_error`` is deprecated. + Using ``suppress_extended_error=True`` has the same effect as ``show_content=False``. + - config - The ``ACTION_WARNINGS`` config has no effect. It previously disabled command warnings, which have since been removed. + - templating - Support for enabling Jinja2 extensions (not plugins) has been deprecated. + - playbook variables - The ``play_hosts`` variable has been deprecated; use ``ansible_play_batch`` instead. + - bool filter - Support for coercing unrecognized input values (including None) has been deprecated. Consult the filter documentation for acceptable values, or consider use of the ``truthy`` and ``falsy`` tests. # DTFIX-RELEASE: porting guide + - oneline callback - The ``oneline`` callback and its associated ad-hoc CLI args (``-o``, ``--one-line``) are deprecated. + - tree callback - The ``tree`` callback and its associated ad-hoc CLI args (``-t``, ``--tree``) are deprecated. + - CLI - The ``--inventory-file`` option alias is deprecated. Use the ``-i`` or ``--inventory`` option instead. + - first_found lookup - Splitting of file paths on ``,;:`` is deprecated. Pass a list of paths instead. + The ``split`` method on strings can be used to split variables into a list as needed.
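The deprecated implicit splitting is roughly equivalent to the following (an illustrative Python stand-in, not code from this changeset; in playbooks, pass a list or use the string ``split`` method instead):

    import re

    # what splitting a single path string on `,;:` did implicitly
    paths = [p for p in re.split(r'[,;:]', 'a.cfg,b.cfg;c.cfg:d.cfg') if p]
    assert paths == ['a.cfg', 'b.cfg', 'c.cfg', 'd.cfg']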
+ - cache plugins - The ``ansible.plugins.cache.base`` Python module is deprecated. Use ``ansible.plugins.cache`` instead. + - file loading - Loading text files with ``DataLoader`` containing data that cannot be decoded under the expected encoding is deprecated. + In most cases the encoding must be UTF-8, although some plugins allow choosing a different encoding. + Previously, invalid data was silently wrapped in Unicode surrogate escape sequences, often resulting in later errors or other data corruption. + +removed_features: + - modules - Modules returning non-UTF-8 strings now result in an error. + The ``MODULE_STRICT_UTF8_RESPONSE`` setting can be used to disable this check. diff --git a/changelogs/fragments/toml-library-support-dropped.yml b/changelogs/fragments/toml-library-support-dropped.yml new file mode 100644 index 00000000000..e31ec432699 --- /dev/null +++ b/changelogs/fragments/toml-library-support-dropped.yml @@ -0,0 +1,4 @@ +breaking_changes: + - Support for the ``toml`` library has been removed from TOML inventory parsing and dumping. + Use ``tomli`` for parsing on Python 3.10. Python 3.11 and later have built-in support for parsing. + Use ``tomli-w`` to support outputting inventory in TOML format. diff --git a/hacking/test-module.py b/hacking/test-module.py index a9df1a79b8f..ca0e1ab425d 100755 --- a/hacking/test-module.py +++ b/hacking/test-module.py @@ -40,10 +40,10 @@ import shutil from pathlib import Path +from ansible.module_utils.common.messages import PluginInfo from ansible.release import __version__ import ansible.utils.vars as utils_vars from ansible.parsing.dataloader import DataLoader -from ansible.parsing.utils.jsonify import jsonify from ansible.parsing.splitter import parse_kv from ansible.plugins.loader import init_plugin_loader from ansible.executor import module_common @@ -89,6 +89,22 @@ def parse(): return options, args +def jsonify(result, format=False): + """ format JSON output (compressed or uncompressed) """ + + if result is None: + return "{}" + + indent = None + if format: + indent = 4 + + try: + return json.dumps(result, sort_keys=True, indent=indent, ensure_ascii=False) + except UnicodeDecodeError: + return json.dumps(result, sort_keys=True, indent=indent) + + def write_argsfile(argstring, json=False): """ Write args to a file for old-style module's use.
""" argspath = Path("~/.ansible_test_module_arguments").expanduser() @@ -152,16 +168,27 @@ def boilerplate_module(modfile, args, interpreters, check, destfile): if check: complex_args['_ansible_check_mode'] = True + modfile = os.path.abspath(modfile) modname = os.path.basename(modfile) modname = os.path.splitext(modname)[0] - (module_data, module_style, shebang) = module_common.modify_module( - modname, - modfile, - complex_args, - Templar(loader=loader), + + plugin = PluginInfo( + requested_name=modname, + resolved_name=modname, + type='module', + ) + + built_module = module_common.modify_module( + module_name=modname, + plugin=plugin, + module_path=modfile, + module_args=complex_args, + templar=Templar(loader=loader), task_vars=task_vars ) + module_data, module_style = built_module.b_module_data, built_module.module_style + if module_style == 'new' and '_ANSIBALLZ_WRAPPER = True' in to_native(module_data): module_style = 'ansiballz' diff --git a/lib/ansible/_internal/__init__.py b/lib/ansible/_internal/__init__.py new file mode 100644 index 00000000000..eaf75bb1069 --- /dev/null +++ b/lib/ansible/_internal/__init__.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +import importlib +import typing as t + +from ansible.module_utils import _internal +from ansible.module_utils._internal._json import _profiles + + +def get_controller_serialize_map() -> dict[type, t.Callable]: + """ + Injected into module_utils code to augment serialization maps with controller-only types. + This implementation replaces the no-op version in module_utils._internal in controller contexts. + """ + from ansible._internal._templating import _lazy_containers + from ansible.parsing.vault import EncryptedString + + return { + _lazy_containers._AnsibleLazyTemplateDict: _profiles._JSONSerializationProfile.discard_tags, + _lazy_containers._AnsibleLazyTemplateList: _profiles._JSONSerializationProfile.discard_tags, + EncryptedString: str, # preserves tags since this is an intance of EncryptedString; if tags should be discarded from str, another entry will handle it + } + + +def import_controller_module(module_name: str, /) -> t.Any: + """ + Injected into module_utils code to import and return the specified module. + This implementation replaces the no-op version in module_utils._internal in controller contexts. + """ + return importlib.import_module(module_name) + + +_T = t.TypeVar('_T') + + +def experimental(obj: _T) -> _T: + """ + Decorator for experimental types and methods outside the `_internal` package which accept or expose internal types. + As with internal APIs, these are subject to change at any time without notice. + """ + return obj + + +def setup() -> None: + """No-op function to ensure that side-effect only imports of this module are not flagged/removed as 'unused'.""" + + +# DTFIX-RELEASE: this is really fragile- disordered/incorrect imports (among other things) can mess it up. Consider a hosting-env-managed context +# with an enum with at least Controller/Target/Unknown values, and possibly using lazy-init module shims or some other mechanism to allow controller-side +# notification/augmentation of this kind of metadata. 
+_internal.get_controller_serialize_map = get_controller_serialize_map +_internal.import_controller_module = import_controller_module +_internal.is_controller = True diff --git a/lib/ansible/_internal/_ansiballz.py b/lib/ansible/_internal/_ansiballz.py new file mode 100644 index 00000000000..b60d02de1b1 --- /dev/null +++ b/lib/ansible/_internal/_ansiballz.py @@ -0,0 +1,265 @@ +# shebang placeholder + +from __future__ import annotations + +import datetime + +# For the test-module.py script to tell this is an ANSIBALLZ_WRAPPER +_ANSIBALLZ_WRAPPER = True + +# This code is part of Ansible, but is an independent component. +# The code in this particular templatable string, and this templatable string +# only, is BSD licensed. Modules which end up using this snippet, which is +# dynamically combined together by Ansible still belong to the author of the +# module, and they may assign their own license to the complete work. +# +# Copyright (c), James Cammarata, 2016 +# Copyright (c), Toshio Kuratomi, 2016 +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +def _ansiballz_main( + zipdata: str, + ansible_module: str, + module_fqn: str, + params: str, + profile: str, + plugin_info_dict: dict[str, object], + date_time: datetime.datetime, + coverage_config: str | None, + coverage_output: str | None, + rlimit_nofile: int, +) -> None: + import os + import os.path + + # Access to the working directory is required by Python when using pipelining, as well as for the coverage module. + # Some platforms, such as macOS, may not allow querying the working directory when using become to drop privileges. + try: + os.getcwd() + except OSError: + try: + os.chdir(os.path.expanduser('~')) + except OSError: + os.chdir('/') + + if rlimit_nofile: + import resource + + existing_soft, existing_hard = resource.getrlimit(resource.RLIMIT_NOFILE) + + # adjust soft limit subject to existing hard limit + requested_soft = min(existing_hard, rlimit_nofile) + + if requested_soft != existing_soft: + try: + resource.setrlimit(resource.RLIMIT_NOFILE, (requested_soft, existing_hard)) + except ValueError: + # some platforms (e.g. macOS) lie about their hard limit + pass + + import sys + import __main__ + + # For some distros and python versions we pick up this script in the temporary + # directory.
This leads to problems when the ansible module masks a python + # library that another import needs. We have not figured out what about the + # specific distros and python versions causes this to behave differently. + # + # Tested distros: + # Fedora23 with python3.4 Works + # Ubuntu15.10 with python2.7 Works + # Ubuntu15.10 with python3.4 Fails without this + # Ubuntu16.04.1 with python3.5 Fails without this + # To test on another platform: + # * use the copy module (since this shadows the stdlib copy module) + # * Turn off pipelining + # * Make sure that the destination file does not exist + # * ansible ubuntu16-test -m copy -a 'src=/etc/motd dest=/var/tmp/m' + # This will traceback in shutil. Looking at the complete traceback will show + # that shutil is importing copy which finds the ansible module instead of the + # stdlib module + scriptdir = None + try: + scriptdir = os.path.dirname(os.path.realpath(__main__.__file__)) + except (AttributeError, OSError): + # Some platforms don't set __file__ when reading from stdin + # OSX raises OSError if using abspath() in a directory we don't have + # permission to read (realpath calls abspath) + pass + + # Strip cwd from sys.path to avoid potential permissions issues + excludes = {'', '.', scriptdir} + sys.path = [p for p in sys.path if p not in excludes] + + import base64 + import shutil + import tempfile + import zipfile + + def invoke_module(modlib_path: str, json_params: bytes) -> None: + # When installed via setuptools (including python setup.py install), + # ansible may be installed with an easy-install.pth file. That file + # may load the system-wide install of ansible rather than the one in + # the module. sitecustomize is the only way to override that setting. + z = zipfile.ZipFile(modlib_path, mode='a') + + # py3: modlib_path will be text, py2: it's bytes. Need bytes at the end + sitecustomize = u'import sys\\nsys.path.insert(0,"%s")\\n' % modlib_path + sitecustomize = sitecustomize.encode('utf-8') + # Use a ZipInfo to work around zipfile limitation on hosts with + # clocks set to a pre-1980 year (for instance, Raspberry Pi) + zinfo = zipfile.ZipInfo() + zinfo.filename = 'sitecustomize.py' + zinfo.date_time = date_time.utctimetuple()[:6] + z.writestr(zinfo, sitecustomize) + z.close() + + # Put the zipped up module_utils we got from the controller first in the python path so that we + # can monkeypatch the right basic + sys.path.insert(0, modlib_path) + + from ansible.module_utils._internal._ansiballz import run_module + + run_module( + json_params=json_params, + profile=profile, + plugin_info_dict=plugin_info_dict, + module_fqn=module_fqn, + modlib_path=modlib_path, + coverage_config=coverage_config, + coverage_output=coverage_output, + ) + + def debug(command: str, modlib_path: str, json_params: bytes) -> None: + # The code here normally doesn't run. It's only used for debugging on the + # remote machine. + # + # The subcommands in this function make it easier to debug ansiballz + # modules. Here's the basic steps: + # + # Run ansible with the environment variable: ANSIBLE_KEEP_REMOTE_FILES=1 and -vvv + # to save the module file remotely:: + # $ ANSIBLE_KEEP_REMOTE_FILES=1 ansible host1 -m ping -a 'data=october' -vvv + # + # Part of the verbose output will tell you where on the remote machine the + # module was written to:: + # [...] 
+ # SSH: EXEC ssh -C -q -o ControlMaster=auto -o ControlPersist=60s -o KbdInteractiveAuthentication=no -o + # PreferredAuthentications=gssapi-with-mic,gssapi-keyex,hostbased,publickey -o PasswordAuthentication=no -o ConnectTimeout=10 -o + # ControlPath=/home/badger/.ansible/cp/ansible-ssh-%h-%p-%r -tt rhel7 '/bin/sh -c '"'"'LANG=en_US.UTF-8 LC_ALL=en_US.UTF-8 + # LC_MESSAGES=en_US.UTF-8 /usr/bin/python /home/badger/.ansible/tmp/ansible-tmp-1461173013.93-9076457629738/ping'"'"'' + # [...] + # + # Login to the remote machine and run the module file via from the previous + # step with the explode subcommand to extract the module payload into + # source files:: + # $ ssh host1 + # $ /usr/bin/python /home/badger/.ansible/tmp/ansible-tmp-1461173013.93-9076457629738/ping explode + # Module expanded into: + # /home/badger/.ansible/tmp/ansible-tmp-1461173408.08-279692652635227/ansible + # + # You can now edit the source files to instrument the code or experiment with + # different parameter values. When you're ready to run the code you've modified + # (instead of the code from the actual zipped module), use the execute subcommand like this:: + # $ /usr/bin/python /home/badger/.ansible/tmp/ansible-tmp-1461173013.93-9076457629738/ping execute + + # Okay to use __file__ here because we're running from a kept file + basedir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'debug_dir') + args_path = os.path.join(basedir, 'args') + + if command == 'explode': + # transform the ZIPDATA into an exploded directory of code and then + # print the path to the code. This is an easy way for people to look + # at the code on the remote machine for debugging it in that + # environment + z = zipfile.ZipFile(modlib_path) + for filename in z.namelist(): + if filename.startswith('/'): + raise Exception('Something wrong with this module zip file: should not contain absolute paths') + + dest_filename = os.path.join(basedir, filename) + if dest_filename.endswith(os.path.sep) and not os.path.exists(dest_filename): + os.makedirs(dest_filename) + else: + directory = os.path.dirname(dest_filename) + if not os.path.exists(directory): + os.makedirs(directory) + with open(dest_filename, 'wb') as writer: + writer.write(z.read(filename)) + + # write the args file + with open(args_path, 'wb') as writer: + writer.write(json_params) + + print('Module expanded into:') + print(basedir) + + elif command == 'execute': + # Execute the exploded code instead of executing the module from the + # embedded ZIPDATA. This allows people to easily run their modified + # code on the remote machine to see how changes will affect it. + + # Set pythonpath to the debug dir + sys.path.insert(0, basedir) + + # read in the args file which the user may have modified + with open(args_path, 'rb') as reader: + json_params = reader.read() + + from ansible.module_utils._internal._ansiballz import run_module + + run_module( + json_params=json_params, + profile=profile, + plugin_info_dict=plugin_info_dict, + module_fqn=module_fqn, + modlib_path=modlib_path, + ) + + else: + print('WARNING: Unknown debug command. Doing nothing.') + + # + # See comments in the debug() method for information on debugging + # + + encoded_params = params.encode() + + # There's a race condition with the controller removing the + # remote_tmpdir and this module executing under async. 
So we cannot + # store this in remote_tmpdir (use system tempdir instead) + # Only need to use [ansible_module]_payload_ in the temp_path until we move to zipimport + # (this helps ansible-test produce coverage stats) + temp_path = tempfile.mkdtemp(prefix='ansible_' + ansible_module + '_payload_') + + try: + zipped_mod = os.path.join(temp_path, 'ansible_' + ansible_module + '_payload.zip') + + with open(zipped_mod, 'wb') as modlib: + modlib.write(base64.b64decode(zipdata)) + + if len(sys.argv) == 2: + debug(sys.argv[1], zipped_mod, encoded_params) + else: + invoke_module(zipped_mod, encoded_params) + finally: + shutil.rmtree(temp_path, ignore_errors=True) diff --git a/lib/ansible/_internal/_datatag/__init__.py b/lib/ansible/_internal/_datatag/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/lib/ansible/_internal/_datatag/_tags.py b/lib/ansible/_internal/_datatag/_tags.py new file mode 100644 index 00000000000..e8e39f28328 --- /dev/null +++ b/lib/ansible/_internal/_datatag/_tags.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +import dataclasses +import os +import types +import typing as t + +from ansible.module_utils._internal._datatag import _tag_dataclass_kwargs, AnsibleDatatagBase, AnsibleSingletonTagBase + + +@dataclasses.dataclass(**_tag_dataclass_kwargs) +class Origin(AnsibleDatatagBase): + """ + A tag that stores origin metadata for a tagged value, intended for forensic/diagnostic use. + Origin metadata should not be used to make runtime decisions, as it is not guaranteed to be present or accurate. + Setting both `path` and `line_num` can result in diagnostic display of referenced file contents. + Either `path` or `description` must be present. + """ + + path: str | None = None + """The path from which the tagged content originated.""" + description: str | None = None + """A description of the origin, for display to users.""" + line_num: int | None = None + """An optional line number, starting at 1.""" + col_num: int | None = None + """An optional column number, starting at 1.""" + + UNKNOWN: t.ClassVar[t.Self] + + @classmethod + def get_or_create_tag(cls, value: t.Any, path: str | os.PathLike | None) -> Origin: + """Return the tag from the given value, creating a tag from the provided path if no tag was found.""" + if not (origin := cls.get_tag(value)): + if path: + origin = Origin(path=str(path)) # convert tagged strings and path-like values to a native str + else: + origin = Origin.UNKNOWN + + return origin + + def replace( + self, + path: str | types.EllipsisType = ..., + description: str | types.EllipsisType = ..., + line_num: int | None | types.EllipsisType = ..., + col_num: int | None | types.EllipsisType = ..., + ) -> t.Self: + """Return a new origin based on an existing one, with the given fields replaced.""" + return dataclasses.replace( + self, + **{ + key: value + for key, value in dict( + path=path, + description=description, + line_num=line_num, + col_num=col_num, + ).items() + if value is not ... 
+ }, # type: ignore[arg-type] + ) + + def _post_validate(self) -> None: + if self.path: + if not self.path.startswith('/'): + raise RuntimeError('The `path` field must be an absolute path.') + elif not self.description: + raise RuntimeError('The `path` or `description` field must be specified.') + + def __str__(self) -> str: + """Renders the origin in the form of path:line_num:col_num, omitting missing/invalid elements from the right.""" + if self.path: + value = self.path + else: + value = self.description + + if self.line_num and self.line_num > 0: + value += f':{self.line_num}' + + if self.col_num and self.col_num > 0: + value += f':{self.col_num}' + + if self.path and self.description: + value += f' ({self.description})' + + return value + + +Origin.UNKNOWN = Origin(description='') + + +@dataclasses.dataclass(**_tag_dataclass_kwargs) +class VaultedValue(AnsibleDatatagBase): + """Tag for vault-encrypted strings that carries the original ciphertext for round-tripping.""" + + ciphertext: str + + def _get_tag_to_propagate(self, src: t.Any, value: object, *, value_type: t.Optional[type] = None) -> t.Self | None: + # Since VaultedValue stores the encrypted representation of the value on which it is tagged, + # it is incorrect to propagate the tag to a value which is not equal to the original. + # If the tag were copied to another value and subsequently serialized as the original encrypted value, + # the result would then differ from the value on which the tag was applied. + + # Comparisons which can trigger an exception are indicative of a bug and should not be handled here. + # For example: + # * When `src` is an undecryptable `EncryptedString` -- it is not valid to apply this tag to that type. + # * When `value` is a `Marker` -- this requires templating, but vaulted values do not support templating. + + if src == value: # assume the tag was correctly applied to src + return self # same plaintext value, tag propagation with same ciphertext is safe + + return self.get_tag(value) # different value, preserve the existing tag, if any + + +@dataclasses.dataclass(**_tag_dataclass_kwargs) +class TrustedAsTemplate(AnsibleSingletonTagBase): + """ + Indicates the tagged string is trusted to parse and render as a template. + Do *NOT* apply this tag to data from untrusted sources, as this would allow code injection during templating. + """ + + +@dataclasses.dataclass(**_tag_dataclass_kwargs) +class SourceWasEncrypted(AnsibleSingletonTagBase): + """ + For internal use only. + Indicates the tagged value was sourced from an encrypted file. + Currently applied only by DataLoader.get_text_file_contents() and by extension DataLoader.load_from_file(). + """ diff --git a/lib/ansible/_internal/_datatag/_utils.py b/lib/ansible/_internal/_datatag/_utils.py new file mode 100644 index 00000000000..bf57ae29ac3 --- /dev/null +++ b/lib/ansible/_internal/_datatag/_utils.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from ansible.module_utils._internal._datatag import AnsibleTagHelper + + +def str_problematic_strip(value: str) -> str: + """ + Return a copy of `value` with leading and trailing whitespace removed. + Used where `str.strip` is needed, but tags must be preserved *AND* the stripping behavior likely shouldn't exist. + If the stripping behavior is non-problematic, use `AnsibleTagHelper.tag_copy` around `str.strip` instead.
+ """ + if (stripped_value := value.strip()) == value: + return value + + # FUTURE: consider deprecating some/all usages of this method; they generally imply a code smell or pattern we shouldn't be supporting + + stripped_value = AnsibleTagHelper.tag_copy(value, stripped_value) + + return stripped_value diff --git a/lib/ansible/_internal/_datatag/_wrappers.py b/lib/ansible/_internal/_datatag/_wrappers.py new file mode 100644 index 00000000000..51cb4d54635 --- /dev/null +++ b/lib/ansible/_internal/_datatag/_wrappers.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import io +import typing as _t + +from .._wrapt import ObjectProxy +from ...module_utils._internal import _datatag + + +class TaggedStreamWrapper(ObjectProxy): + """ + Janky proxy around IOBase to allow streams to carry tags and support basic interrogation by the tagging API. + Most tagging operations will have undefined behavior for this type. + """ + + _self__ansible_tags_mapping: _datatag._AnsibleTagsMapping + + def __init__(self, stream: io.IOBase, tags: _datatag.AnsibleDatatagBase | _t.Iterable[_datatag.AnsibleDatatagBase]) -> None: + super().__init__(stream) + + tag_list: list[_datatag.AnsibleDatatagBase] + + # noinspection PyProtectedMember + if type(tags) in _datatag._known_tag_types: + tag_list = [tags] # type: ignore[list-item] + else: + tag_list = list(tags) # type: ignore[arg-type] + + self._self__ansible_tags_mapping = _datatag._AnsibleTagsMapping((type(tag), tag) for tag in tag_list) + + @property + def _ansible_tags_mapping(self) -> _datatag._AnsibleTagsMapping: + return self._self__ansible_tags_mapping diff --git a/lib/ansible/_internal/_errors/__init__.py b/lib/ansible/_internal/_errors/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/lib/ansible/_internal/_errors/_captured.py b/lib/ansible/_internal/_errors/_captured.py new file mode 100644 index 00000000000..736e915625f --- /dev/null +++ b/lib/ansible/_internal/_errors/_captured.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +import dataclasses +import typing as t + +from ansible.errors import AnsibleRuntimeError +from ansible.module_utils.common.messages import ErrorSummary, Detail, _dataclass_kwargs + + +class AnsibleCapturedError(AnsibleRuntimeError): + """An exception representing error detail captured in another context where the error detail must be serialized to be preserved.""" + + context: t.ClassVar[str] + + def __init__( + self, + *, + obj: t.Any = None, + error_summary: ErrorSummary, + ) -> None: + super().__init__( + obj=obj, + ) + + self._error_summary = error_summary + + @property + def error_summary(self) -> ErrorSummary: + return self._error_summary + + +class AnsibleResultCapturedError(AnsibleCapturedError): + """An exception representing error detail captured in a foreign context where an action/module result dictionary is involved.""" + + def __init__(self, error_summary: ErrorSummary, result: dict[str, t.Any]) -> None: + super().__init__(error_summary=error_summary) + + self._result = result + + @classmethod + def maybe_raise_on_result(cls, result: dict[str, t.Any]) -> None: + """Normalize the result and raise an exception if the result indicated failure.""" + if error_summary := cls.normalize_result_exception(result): + raise error_summary.error_type(error_summary, result) + + @classmethod + def find_first_remoted_error(cls, exception: BaseException) -> t.Self | None: + """Find the first captured module error in the cause chain, starting with the given exception, returning None if not 
found.""" + while exception: + if isinstance(exception, cls): + return exception + + exception = exception.__cause__ + + return None + + @classmethod + def normalize_result_exception(cls, result: dict[str, t.Any]) -> CapturedErrorSummary | None: + """ + Normalize the result `exception`, if any, to be a `CapturedErrorSummary` instance. + If a new `CapturedErrorSummary` was created, the `error_type` will be `cls`. + The `exception` key will be removed if falsey. + A `CapturedErrorSummary` instance will be returned if `failed` is truthy. + """ + if type(cls) is AnsibleResultCapturedError: # pylint: disable=unidiomatic-typecheck + raise TypeError('The normalize_result_exception method cannot be called on the AnsibleCapturedError base type, use a derived type.') + + if not isinstance(result, dict): + raise TypeError(f'Malformed result. Received {type(result)} instead of {dict}.') + + failed = result.get('failed') # DTFIX-FUTURE: warn if failed is present and not a bool, or exception is present without failed being True + exception = result.pop('exception', None) + + if not failed and not exception: + return None + + if isinstance(exception, CapturedErrorSummary): + error_summary = exception + elif isinstance(exception, ErrorSummary): + error_summary = CapturedErrorSummary( + details=exception.details, + formatted_traceback=cls._normalize_traceback(exception.formatted_traceback), + error_type=cls, + ) + else: + # translate non-ErrorDetail errors + error_summary = CapturedErrorSummary( + details=(Detail(msg=str(result.get('msg', 'Unknown error.'))),), + formatted_traceback=cls._normalize_traceback(exception), + error_type=cls, + ) + + result.update(exception=error_summary) + + return error_summary if failed else None # even though error detail was normalized, only return it if the result indicated failure + + @classmethod + def _normalize_traceback(cls, value: object | None) -> str | None: + """Normalize the provided traceback value, returning None if it is falsey.""" + if not value: + return None + + value = str(value).rstrip() + + if not value: + return None + + return value + '\n' + + +class AnsibleActionCapturedError(AnsibleResultCapturedError): + """An exception representing error detail sourced directly by an action in its result dictionary.""" + + _default_message = 'Action failed.' + context = 'action' + + +class AnsibleModuleCapturedError(AnsibleResultCapturedError): + """An exception representing error detail captured in a module context and returned from an action's result dictionary.""" + + _default_message = 'Module failed.' + context = 'target' + + +@dataclasses.dataclass(**_dataclass_kwargs) +class CapturedErrorSummary(ErrorSummary): + # DTFIX-RELEASE: where to put this, name, etc. 
since it shows up in results, it's not exactly private (and contains a type ref to an internal type) + error_type: type[AnsibleResultCapturedError] | None = None diff --git a/lib/ansible/_internal/_errors/_handler.py b/lib/ansible/_internal/_errors/_handler.py new file mode 100644 index 00000000000..94a391c3786 --- /dev/null +++ b/lib/ansible/_internal/_errors/_handler.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import contextlib +import enum +import typing as t + +from ansible.utils.display import Display +from ansible.constants import config + +display = Display() + +# FUTURE: add sanity test to detect use of skip_on_ignore without Skippable (and vice versa) + + +class ErrorAction(enum.Enum): + """Action to take when an error is encountered.""" + + IGNORE = enum.auto() + WARN = enum.auto() + FAIL = enum.auto() + + @classmethod + def from_config(cls, setting: str, variables: dict[str, t.Any] | None = None) -> t.Self: + """Return an `ErrorAction` enum from the specified Ansible config setting.""" + return cls[config.get_config_value(setting, variables=variables).upper()] + + +class _SkipException(BaseException): + """Internal flow control exception for skipping code blocks within a `Skippable` context manager.""" + + def __init__(self) -> None: + super().__init__('Skipping ignored action due to use of `skip_on_ignore`. It is a bug to encounter this message outside of debugging.') + + +class _SkippableContextManager: + """Internal context manager to support flow control for skipping code blocks.""" + + def __enter__(self) -> None: + pass + + def __exit__(self, exc_type, _exc_val, _exc_tb) -> bool: + if exc_type is None: + raise RuntimeError('A `Skippable` context manager was entered, but a `skip_on_ignore` handler was never invoked.') + + return exc_type is _SkipException # only mask a _SkipException, allow all others to raise + + +Skippable = _SkippableContextManager() +"""Context manager singleton required to enclose `ErrorHandler.handle` invocations when `skip_on_ignore` is `True`.""" + + +class ErrorHandler: + """ + Provides a configurable error handler context manager for a specific list of exception types. + Unhandled errors leaving the context manager can be ignored, treated as warnings, or allowed to raise by setting `ErrorAction`. + """ + + def __init__(self, action: ErrorAction) -> None: + self.action = action + + @contextlib.contextmanager + def handle(self, *args: type[BaseException], skip_on_ignore: bool = False) -> t.Iterator[None]: + """ + Handle the specified exception(s) using the defined error action. + If `skip_on_ignore` is `True`, the body of the context manager will be skipped for `ErrorAction.IGNORE`. + Use of `skip_on_ignore` requires enclosure within the `Skippable` context manager. 
+ """ + if not args: + raise ValueError('At least one exception type is required.') + + if skip_on_ignore and self.action == ErrorAction.IGNORE: + raise _SkipException() # skipping ignored action + + try: + yield + except args as ex: + match self.action: + case ErrorAction.WARN: + display.error_as_warning(msg=None, exception=ex) + case ErrorAction.FAIL: + raise + case _: # ErrorAction.IGNORE + pass + + if skip_on_ignore: + raise _SkipException() # completed skippable action, ensures the `Skippable` context was used + + @classmethod + def from_config(cls, setting: str, variables: dict[str, t.Any] | None = None) -> t.Self: + """Return an `ErrorHandler` instance configured using the specified Ansible config setting.""" + return cls(ErrorAction.from_config(setting, variables=variables)) diff --git a/lib/ansible/_internal/_errors/_utils.py b/lib/ansible/_internal/_errors/_utils.py new file mode 100644 index 00000000000..cd997a0ff57 --- /dev/null +++ b/lib/ansible/_internal/_errors/_utils.py @@ -0,0 +1,310 @@ +from __future__ import annotations + +import dataclasses +import itertools +import pathlib +import sys +import textwrap +import typing as t + +from ansible.module_utils.common.messages import Detail, ErrorSummary +from ansible._internal._datatag._tags import Origin +from ansible.module_utils._internal import _ambient_context, _traceback +from ansible import errors + +if t.TYPE_CHECKING: + from ansible.utils.display import Display + + +class RedactAnnotatedSourceContext(_ambient_context.AmbientContextBase): + """ + When active, this context will redact annotated source lines, showing only the origin. + """ + + +def _dedupe_and_concat_message_chain(message_parts: list[str]) -> str: + message_parts = list(reversed(message_parts)) + + message = message_parts.pop(0) + + for message_part in message_parts: + # avoid duplicate messages where the cause was already concatenated to the exception message + if message_part.endswith(message): + message = message_part + else: + message = concat_message(message_part, message) + + return message + + +def _collapse_error_details(error_details: t.Sequence[Detail]) -> list[Detail]: + """ + Return a potentially modified error chain, with redundant errors collapsed into previous error(s) in the chain. + This reduces the verbosity of messages by eliminating repetition when multiple errors in the chain share the same contextual information. 
+ """ + previous_error = error_details[0] + previous_warnings: list[str] = [] + collapsed_error_details: list[tuple[Detail, list[str]]] = [(previous_error, previous_warnings)] + + for error in error_details[1:]: + details_present = error.formatted_source_context or error.help_text + details_changed = error.formatted_source_context != previous_error.formatted_source_context or error.help_text != previous_error.help_text + + if details_present and details_changed: + previous_error = error + previous_warnings = [] + collapsed_error_details.append((previous_error, previous_warnings)) + else: + previous_warnings.append(error.msg) + + final_error_details: list[Detail] = [] + + for error, messages in collapsed_error_details: + final_error_details.append(dataclasses.replace(error, msg=_dedupe_and_concat_message_chain([error.msg] + messages))) + + return final_error_details + + +def _get_cause(exception: BaseException) -> BaseException | None: + # deprecated: description='remove support for orig_exc (deprecated in 2.23)' core_version='2.27' + + if not isinstance(exception, errors.AnsibleError): + return exception.__cause__ + + if exception.__cause__: + if exception.orig_exc and exception.orig_exc is not exception.__cause__: + _get_display().warning( + msg=f"The `orig_exc` argument to `{type(exception).__name__}` was given, but differed from the cause given by `raise ... from`.", + ) + + return exception.__cause__ + + if exception.orig_exc: + # encourage the use of `raise ... from` before deprecating `orig_exc` + _get_display().warning(msg=f"The `orig_exc` argument to `{type(exception).__name__}` was given without using `raise ... from orig_exc`.") + + return exception.orig_exc + + return None + + +class _TemporaryDisplay: + # DTFIX-FUTURE: generalize this and hide it in the display module so all users of Display can benefit + + @staticmethod + def warning(*args, **kwargs): + print(f'FALLBACK WARNING: {args} {kwargs}', file=sys.stderr) + + @staticmethod + def deprecated(*args, **kwargs): + print(f'FALLBACK DEPRECATION: {args} {kwargs}', file=sys.stderr) + + +def _get_display() -> Display | _TemporaryDisplay: + try: + from ansible.utils.display import Display + except ImportError: + return _TemporaryDisplay() + + return Display() + + +def _create_error_summary(exception: BaseException, event: _traceback.TracebackEvent | None = None) -> ErrorSummary: + from . 
import _captured # avoid circular import due to AnsibleError import + + current_exception: BaseException | None = exception + error_details: list[Detail] = [] + + if event: + formatted_traceback = _traceback.maybe_extract_traceback(exception, event) + else: + formatted_traceback = None + + while current_exception: + if isinstance(current_exception, errors.AnsibleError): + include_cause_message = current_exception._include_cause_message + edc = Detail( + msg=current_exception._original_message.strip(), + formatted_source_context=current_exception._formatted_source_context, + help_text=current_exception._help_text, + ) + else: + include_cause_message = True + edc = Detail( + msg=str(current_exception).strip(), + ) + + error_details.append(edc) + + if isinstance(current_exception, _captured.AnsibleCapturedError): + detail = current_exception.error_summary + error_details.extend(detail.details) + + if formatted_traceback and detail.formatted_traceback: + formatted_traceback = ( + f'{detail.formatted_traceback}\n' + f'The {current_exception.context} exception above was the direct cause of the following controller exception:\n\n' + f'{formatted_traceback}' + ) + + if not include_cause_message: + break + + current_exception = _get_cause(current_exception) + + return ErrorSummary(details=tuple(error_details), formatted_traceback=formatted_traceback) + + +def concat_message(left: str, right: str) -> str: + """Normalize `left` by removing trailing punctuation and spaces before appending new punctuation and `right`.""" + return f'{left.rstrip(". ")}: {right}' + + +def get_chained_message(exception: BaseException) -> str: + """ + Return the full chain of exception messages by concatenating the cause(s) until all are exhausted. + """ + error_summary = _create_error_summary(exception) + message_parts = [edc.msg for edc in error_summary.details] + + return _dedupe_and_concat_message_chain(message_parts) + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class SourceContext: + origin: Origin + annotated_source_lines: list[str] + target_line: str | None + + def __str__(self) -> str: + msg_lines = [f'Origin: {self.origin}'] + + if self.annotated_source_lines: + msg_lines.append('') + msg_lines.extend(self.annotated_source_lines) + + return '\n'.join(msg_lines) + + @classmethod + def from_value(cls, value: t.Any) -> SourceContext | None: + """Attempt to retrieve source and render a contextual indicator from the value's origin (if any).""" + if value is None: + return None + + if isinstance(value, Origin): + origin = value + value = None + else: + origin = Origin.get_tag(value) + + if RedactAnnotatedSourceContext.current(optional=True): + return cls.error('content redacted') + + if origin and origin.path: + return cls.from_origin(origin) + + # DTFIX-RELEASE: redaction context may not be sufficient to avoid secret disclosure without SensitiveData and other enhancements + if value is None: + truncated_value = None + annotated_source_lines = [] + else: + # DTFIX-FUTURE: cleanup/share width + try: + value = str(value) + except Exception as ex: + value = f'<< context unavailable: {ex} >>' + + truncated_value = textwrap.shorten(value, width=120) + annotated_source_lines = [truncated_value] + + return SourceContext( + origin=origin or Origin.UNKNOWN, + annotated_source_lines=annotated_source_lines, + target_line=truncated_value, + ) + + @staticmethod + def error(message: str | None, origin: Origin | None = None) -> SourceContext: + return SourceContext( + origin=origin, + annotated_source_lines=[f'(source not 
shown: {message})'] if message else [], + target_line=None, + ) + + @classmethod + def from_origin(cls, origin: Origin) -> SourceContext: + """Attempt to retrieve source and render a contextual indicator of an error location.""" + from ansible.parsing.vault import is_encrypted # avoid circular import + + # DTFIX-FUTURE: support referencing the column after the end of the target line, so we can indicate where a missing character (quote) needs to be added + # this is also useful for cases like end-of-stream reported by the YAML parser + + # DTFIX-FUTURE: Implement line wrapping and match annotated line width to the terminal display width. + + context_line_count: t.Final = 2 + max_annotated_line_width: t.Final = 120 + truncation_marker: t.Final = '...' + + target_line_num = origin.line_num + + if RedactAnnotatedSourceContext.current(optional=True): + return cls.error('content redacted', origin) + + if not target_line_num or target_line_num < 1: + return cls.error(None, origin) # message omitted since lack of line number is obvious from pos + + start_line_idx = max(0, (target_line_num - 1) - context_line_count) # if near start of file + target_col_num = origin.col_num + + try: + with pathlib.Path(origin.path).open() as src: + first_line = src.readline() + lines = list(itertools.islice(itertools.chain((first_line,), src), start_line_idx, target_line_num)) + except Exception as ex: + return cls.error(type(ex).__name__, origin) + + if is_encrypted(first_line): + return cls.error('content encrypted', origin) + + if len(lines) != target_line_num - start_line_idx: + return cls.error('file truncated', origin) + + annotated_source_lines = [] + + line_label_width = len(str(target_line_num)) + max_src_line_len = max_annotated_line_width - line_label_width - 1 + + usable_line_len = max_src_line_len + + for line_num, line in enumerate(lines, start_line_idx + 1): + line = line.rstrip('\n') # universal newline default mode on `open` ensures we'll never see anything but \n + line = line.replace('\t', ' ') # mixed tab/space handling is intentionally disabled since we're both format and display config agnostic + + if len(line) > max_src_line_len: + line = line[: max_src_line_len - len(truncation_marker)] + truncation_marker + usable_line_len = max_src_line_len - len(truncation_marker) + + annotated_source_lines.append(f'{str(line_num).rjust(line_label_width)}{" " if line else ""}{line}') + + if target_col_num and usable_line_len >= target_col_num >= 1: + column_marker = f'column {target_col_num}' + + target_col_idx = target_col_num - 1 + + if target_col_idx + 2 + len(column_marker) > max_src_line_len: + column_marker = f'{" " * (target_col_idx - len(column_marker) - 1)}{column_marker} ^' + else: + column_marker = f'{" " * target_col_idx}^ {column_marker}' + + column_marker = f'{" " * line_label_width} {column_marker}' + + annotated_source_lines.append(column_marker) + elif target_col_num is None: + underline_length = len(annotated_source_lines[-1]) - line_label_width - 1 + annotated_source_lines.append(f'{" " * line_label_width} {"^" * underline_length}') + + return SourceContext( + origin=origin, + annotated_source_lines=annotated_source_lines, + target_line=lines[-1].rstrip('\n'), # universal newline default mode on `open` ensures we'll never see anything but \n + ) diff --git a/lib/ansible/_internal/_json/__init__.py b/lib/ansible/_internal/_json/__init__.py new file mode 100644 index 00000000000..81cb409aeb9 --- /dev/null +++ b/lib/ansible/_internal/_json/__init__.py @@ -0,0 +1,160 @@ +"""Internal 
utilities for serialization and deserialization.""" + +# DTFIX-RELEASE: most of this isn't JSON specific, find a better home + +from __future__ import annotations + +import json +import typing as t + +from ansible.errors import AnsibleVariableTypeError + +from ansible.module_utils._internal._datatag import ( + _ANSIBLE_ALLOWED_MAPPING_VAR_TYPES, + _ANSIBLE_ALLOWED_NON_SCALAR_COLLECTION_VAR_TYPES, + _ANSIBLE_ALLOWED_VAR_TYPES, + _AnsibleTaggedStr, + AnsibleTagHelper, +) +from ansible.module_utils._internal._json._profiles import _tagless +from ansible.parsing.vault import EncryptedString +from ansible._internal._datatag._tags import Origin, TrustedAsTemplate +from ansible.module_utils import _internal + +_T = t.TypeVar('_T') +_sentinel = object() + + +class HasCurrent(t.Protocol): + """Utility protocol for mixin type safety.""" + + _current: t.Any + + +class StateTrackingMixIn(HasCurrent): + """Mixin for use with `AnsibleVariableVisitor` to track current visitation context.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self._stack: list[t.Any] = [] + + def __enter__(self) -> None: + self._stack.append(self._current) + + def __exit__(self, *_args, **_kwargs) -> None: + self._stack.pop() + + def _get_stack(self) -> list[t.Any]: + if not self._stack: + return [] + + return self._stack[1:] + [self._current] + + +class AnsibleVariableVisitor: + """Utility visitor base class to recursively apply various behaviors and checks to variable object graphs.""" + + def __init__( + self, + *, + trusted_as_template: bool = False, + origin: Origin | None = None, + convert_mapping_to_dict: bool = False, + convert_sequence_to_list: bool = False, + convert_custom_scalars: bool = False, + allow_encrypted_string: bool = False, + ): + super().__init__() # supports StateTrackingMixIn + + self.trusted_as_template = trusted_as_template + self.origin = origin + self.convert_mapping_to_dict = convert_mapping_to_dict + self.convert_sequence_to_list = convert_sequence_to_list + self.convert_custom_scalars = convert_custom_scalars + self.allow_encrypted_string = allow_encrypted_string + + self._current: t.Any = None # supports StateTrackingMixIn + + def __enter__(self) -> t.Any: + """No-op context manager dispatcher (delegates to mixin behavior if present).""" + if func := getattr(super(), '__enter__', None): + func() + + def __exit__(self, *args, **kwargs) -> t.Any: + """No-op context manager dispatcher (delegates to mixin behavior if present).""" + if func := getattr(super(), '__exit__', None): + func(*args, **kwargs) + + def visit(self, value: _T) -> _T: + """ + Enforces Ansible's variable type system restrictions before a var is accepted in inventory. Also, conditionally implements template trust + compatibility, depending on the plugin's declared understanding (or lack thereof). This always recursively copies inputs to fully isolate + inventory data from what the plugin provided, and prevent any later mutation. 
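To make that contract concrete, here is a stripped-down sketch of the recursive visit behavior; it ignores tagging, trust, and the conversion options, and flattens all sequences to lists, so it is illustrative only:

import typing as t

_ALLOWED_SCALARS = (str, int, float, bool, type(None))

def visit(value: t.Any) -> t.Any:
    """Recursively copy a variable graph, rejecting unsupported types (simplified)."""
    if isinstance(value, dict):
        return {k: visit(v) for k, v in value.items()}
    if isinstance(value, (list, tuple, set)):
        return [visit(v) for v in value]  # the real visitor preserves the container type
    if isinstance(value, _ALLOWED_SCALARS):
        return value
    raise TypeError(f'unsupported variable type {type(value).__name__!r}')

assert visit({'name': 'web', 'ports': [80, 443]}) == {'name': 'web', 'ports': [80, 443]}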
+ """ + return self._visit(None, value) + + def _early_visit(self, value, value_type) -> t.Any: + """Overridable hook point to allow custom string handling in derived visitors.""" + if value_type in (str, _AnsibleTaggedStr): + # apply compatibility behavior + if self.trusted_as_template: + result = TrustedAsTemplate().tag(value) + else: + result = value + else: + result = _sentinel + + return result + + def _visit(self, key: t.Any, value: _T) -> _T: + """Internal implementation to recursively visit a data structure's contents.""" + self._current = key # supports StateTrackingMixIn + + value_type = type(value) + + result: _T + + # DTFIX-RELEASE: the visitor is ignoring dict/mapping keys except for debugging and schema-aware checking, it should be doing type checks on keys + # DTFIX-RELEASE: some type lists being consulted (the ones from datatag) are probably too permissive, and perhaps should not be dynamic + + if (result := self._early_visit(value, value_type)) is not _sentinel: + pass + # DTFIX-RELEASE: de-duplicate and optimize; extract inline generator expressions and fallback function or mapping for native type calculation? + elif value_type in _ANSIBLE_ALLOWED_MAPPING_VAR_TYPES: # check mappings first, because they're also collections + with self: # supports StateTrackingMixIn + result = AnsibleTagHelper.tag_copy(value, ((k, self._visit(k, v)) for k, v in value.items()), value_type=value_type) + elif value_type in _ANSIBLE_ALLOWED_NON_SCALAR_COLLECTION_VAR_TYPES: + with self: # supports StateTrackingMixIn + result = AnsibleTagHelper.tag_copy(value, (self._visit(k, v) for k, v in enumerate(t.cast(t.Iterable, value))), value_type=value_type) + elif self.allow_encrypted_string and isinstance(value, EncryptedString): + return value # type: ignore[return-value] # DTFIX-RELEASE: this should probably only be allowed for values in dict, not keys (set, dict) + elif self.convert_mapping_to_dict and _internal.is_intermediate_mapping(value): + with self: # supports StateTrackingMixIn + result = {k: self._visit(k, v) for k, v in value.items()} # type: ignore[assignment] + elif self.convert_sequence_to_list and _internal.is_intermediate_iterable(value): + with self: # supports StateTrackingMixIn + result = [self._visit(k, v) for k, v in enumerate(t.cast(t.Iterable, value))] # type: ignore[assignment] + elif self.convert_custom_scalars and isinstance(value, str): + result = str(value) # type: ignore[assignment] + elif self.convert_custom_scalars and isinstance(value, float): + result = float(value) # type: ignore[assignment] + elif self.convert_custom_scalars and isinstance(value, int) and not isinstance(value, bool): + result = int(value) # type: ignore[assignment] + else: + if value_type not in _ANSIBLE_ALLOWED_VAR_TYPES: + raise AnsibleVariableTypeError.from_value(obj=value) + + # supported scalar type that requires no special handling, just return as-is + result = value + + if self.origin and not Origin.is_tagged_on(result): + # apply shared instance default origin tag + result = self.origin.tag(result) + + return result + + +def json_dumps_formatted(value: object) -> str: + """Return a JSON dump of `value` with formatting and keys sorted.""" + return json.dumps(value, cls=_tagless.Encoder, sort_keys=True, indent=4) diff --git a/lib/ansible/_internal/_json/_legacy_encoder.py b/lib/ansible/_internal/_json/_legacy_encoder.py new file mode 100644 index 00000000000..431c245a1c9 --- /dev/null +++ b/lib/ansible/_internal/_json/_legacy_encoder.py @@ -0,0 +1,34 @@ +from __future__ import annotations as 
_annotations + +import typing as _t + +from ansible.module_utils._internal._json import _profiles +from ansible._internal._json._profiles import _legacy +from ansible.parsing import vault as _vault + + +class LegacyControllerJSONEncoder(_legacy.Encoder): + """Compatibility wrapper over `legacy` profile JSON encoder to support trust stripping and vault value plaintext conversion.""" + + def __init__(self, preprocess_unsafe: bool = False, vault_to_text: bool = False, _decode_bytes: bool = False, **kwargs) -> None: + self._preprocess_unsafe = preprocess_unsafe + self._vault_to_text = vault_to_text + self._decode_bytes = _decode_bytes + + super().__init__(**kwargs) + + def default(self, o: _t.Any) -> _t.Any: + """Hooked default that can conditionally bypass base encoder behavior based on this instance's config.""" + if type(o) is _profiles._WrappedValue: # pylint: disable=unidiomatic-typecheck + o = o.wrapped + + if not self._preprocess_unsafe and type(o) is _legacy._Untrusted: # pylint: disable=unidiomatic-typecheck + return o.value # if not emitting unsafe markers, bypass custom unsafe serialization and just return the raw value + + if self._vault_to_text and type(o) is _vault.EncryptedString: # pylint: disable=unidiomatic-typecheck + return str(o) # decrypt and return the plaintext (or fail trying) + + if self._decode_bytes and isinstance(o, bytes): + return o.decode(errors='surrogateescape') # backward compatibility with `ansible.module_utils.basic.jsonify` + + return super().default(o) diff --git a/lib/ansible/_internal/_json/_profiles/__init__.py b/lib/ansible/_internal/_json/_profiles/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/lib/ansible/_internal/_json/_profiles/_cache_persistence.py b/lib/ansible/_internal/_json/_profiles/_cache_persistence.py new file mode 100644 index 00000000000..a6c16e1a794 --- /dev/null +++ b/lib/ansible/_internal/_json/_profiles/_cache_persistence.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import datetime as _datetime + +from ansible.module_utils._internal import _datatag +from ansible.module_utils._internal._json import _profiles +from ansible.parsing import vault as _vault +from ansible._internal._datatag import _tags + + +class _Profile(_profiles._JSONSerializationProfile): + """Profile for external cache persistence of inventory/fact data that preserves most tags.""" + + serialize_map = {} + schema_id = 1 + + @classmethod + def post_init(cls, **kwargs): + cls.allowed_ansible_serializable_types = ( + _profiles._common_module_types + | _profiles._common_module_response_types + | { + _datatag._AnsibleTaggedDate, + _datatag._AnsibleTaggedTime, + _datatag._AnsibleTaggedDateTime, + _datatag._AnsibleTaggedStr, + _datatag._AnsibleTaggedInt, + _datatag._AnsibleTaggedFloat, + _datatag._AnsibleTaggedList, + _datatag._AnsibleTaggedSet, + _datatag._AnsibleTaggedTuple, + _datatag._AnsibleTaggedDict, + _tags.SourceWasEncrypted, + _tags.Origin, + _tags.TrustedAsTemplate, + _vault.EncryptedString, + _vault.VaultedValue, + } + ) + + cls.serialize_map = { + set: cls.serialize_as_list, + tuple: cls.serialize_as_list, + _datetime.date: _datatag.AnsibleSerializableDate, + _datetime.time: _datatag.AnsibleSerializableTime, + _datetime.datetime: _datatag.AnsibleSerializableDateTime, + } + + +class Encoder(_profiles.AnsibleProfileJSONEncoder): + _profile = _Profile + + +class Decoder(_profiles.AnsibleProfileJSONDecoder): + _profile = _Profile diff --git a/lib/ansible/_internal/_json/_profiles/_inventory_legacy.py 
b/lib/ansible/_internal/_json/_profiles/_inventory_legacy.py new file mode 100644 index 00000000000..aa9c8ea1057 --- /dev/null +++ b/lib/ansible/_internal/_json/_profiles/_inventory_legacy.py @@ -0,0 +1,40 @@ +""" +Backwards compatibility profile for serialization for persisted ansible-inventory output. +Behavior is equivalent to pre 2.18 `AnsibleJSONEncoder` with vault_to_text=True. +""" + +from __future__ import annotations + +from ... import _json +from . import _legacy + + +class _InventoryVariableVisitor(_legacy._LegacyVariableVisitor, _json.StateTrackingMixIn): + """State-tracking visitor implementation that only applies trust to `_meta.hostvars` and `vars` inventory values.""" + + # DTFIX-RELEASE: does the variable visitor need to support conversion of sequence/mapping for inventory? + + @property + def _allow_trust(self) -> bool: + stack = self._get_stack() + + if len(stack) >= 4 and stack[:2] == ['_meta', 'hostvars']: + return True + + if len(stack) >= 3 and stack[1] == 'vars': + return True + + return False + + +class _Profile(_legacy._Profile): + visitor_type = _InventoryVariableVisitor + encode_strings_as_utf8 = True + + +class Encoder(_legacy.Encoder): + _profile = _Profile + + +class Decoder(_legacy.Decoder): + _profile = _Profile diff --git a/lib/ansible/_internal/_json/_profiles/_legacy.py b/lib/ansible/_internal/_json/_profiles/_legacy.py new file mode 100644 index 00000000000..2b333e6da12 --- /dev/null +++ b/lib/ansible/_internal/_json/_profiles/_legacy.py @@ -0,0 +1,198 @@ +""" +Backwards compatibility profile for serialization other than inventory (which should use inventory_legacy for backward-compatible trust behavior). +Behavior is equivalent to pre 2.18 `AnsibleJSONEncoder` with vault_to_text=True. +""" + +from __future__ import annotations as _annotations + +import datetime as _datetime +import typing as _t + +from ansible._internal._datatag import _tags +from ansible.module_utils._internal import _datatag +from ansible.module_utils._internal._json import _profiles +from ansible.parsing import vault as _vault + +from ... import _json + + +class _Untrusted: + """ + Temporarily wraps strings which are not trusted for templating. + Used before serialization of strings not tagged TrustedAsTemplate when trust inversion is enabled and trust is allowed in the string's context. + Used during deserialization of `__ansible_unsafe` strings to indicate they should not be tagged TrustedAsTemplate. + """ + + __slots__ = ('value',) + + def __init__(self, value: str) -> None: + self.value = value + + +class _LegacyVariableVisitor(_json.AnsibleVariableVisitor): + """Variable visitor that supports optional trust inversion for legacy serialization.""" + + def __init__( + self, + *, + trusted_as_template: bool = False, + invert_trust: bool = False, + origin: _tags.Origin | None = None, + convert_mapping_to_dict: bool = False, + convert_sequence_to_list: bool = False, + convert_custom_scalars: bool = False, + ): + super().__init__( + trusted_as_template=trusted_as_template, + origin=origin, + convert_mapping_to_dict=convert_mapping_to_dict, + convert_sequence_to_list=convert_sequence_to_list, + convert_custom_scalars=convert_custom_scalars, + allow_encrypted_string=True, + ) + + self.invert_trust = invert_trust + + if trusted_as_template and invert_trust: + raise ValueError('trusted_as_template is mutually exclusive with invert_trust') + + @property + def _allow_trust(self) -> bool: + """ + This profile supports trust application in all contexts. 
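For reference, the wire shapes this legacy profile round-trips (via the `serialize_*` and `deserialize_*` handlers shown below) look roughly like the following; both payloads are illustrative and the ciphertext is abbreviated, not real:

untrusted = {'motd': {'__ansible_unsafe': '{{ do_not_template_me }}'}}
vaulted = {'secret': {'__ansible_vault': '$ANSIBLE_VAULT;1.1;AES256\n3961...'}}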
+ Derived implementations can override this behavior for application-dependent/schema-aware trust. + """ + return True + + def _early_visit(self, value, value_type) -> _t.Any: + """Similar to base implementation, but supports an intermediate wrapper for trust inversion.""" + if value_type in (str, _datatag._AnsibleTaggedStr): + # apply compatibility behavior + if self.trusted_as_template and self._allow_trust: + result = _tags.TrustedAsTemplate().tag(value) + elif self.invert_trust and not _tags.TrustedAsTemplate.is_tagged_on(value) and self._allow_trust: + result = _Untrusted(value) + else: + result = value + elif value_type is _Untrusted: + result = value.value + else: + result = _json._sentinel + + return result + + +class _Profile(_profiles._JSONSerializationProfile["Encoder", "Decoder"]): + visitor_type = _LegacyVariableVisitor + + @classmethod + def serialize_untrusted(cls, value: _Untrusted) -> dict[str, str] | str: + return dict( + __ansible_unsafe=_datatag.AnsibleTagHelper.untag(value.value), + ) + + @classmethod + def serialize_tagged_str(cls, value: _datatag.AnsibleTaggedObject) -> _t.Any: + if ciphertext := _vault.VaultHelper.get_ciphertext(value, with_tags=False): + return dict( + __ansible_vault=ciphertext, + ) + + return _datatag.AnsibleTagHelper.untag(value) + + @classmethod + def deserialize_unsafe(cls, value: dict[str, _t.Any]) -> _Untrusted: + ansible_unsafe = value['__ansible_unsafe'] + + if type(ansible_unsafe) is not str: # pylint: disable=unidiomatic-typecheck + raise TypeError(f"__ansible_unsafe is {type(ansible_unsafe)} not {str}") + + return _Untrusted(ansible_unsafe) + + @classmethod + def deserialize_vault(cls, value: dict[str, _t.Any]) -> _vault.EncryptedString: + ansible_vault = value['__ansible_vault'] + + if type(ansible_vault) is not str: # pylint: disable=unidiomatic-typecheck + raise TypeError(f"__ansible_vault is {type(ansible_vault)} not {str}") + + encrypted_string = _vault.EncryptedString(ciphertext=ansible_vault) + + return encrypted_string + + @classmethod + def serialize_encrypted_string(cls, value: _vault.EncryptedString) -> dict[str, str]: + return dict( + __ansible_vault=_vault.VaultHelper.get_ciphertext(value, with_tags=False), + ) + + @classmethod + def post_init(cls) -> None: + cls.serialize_map = { + set: cls.serialize_as_list, + tuple: cls.serialize_as_list, + _datetime.date: cls.serialize_as_isoformat, # existing devel behavior + _datetime.time: cls.serialize_as_isoformat, # always failed pre-2.18, so okay to include for consistency + _datetime.datetime: cls.serialize_as_isoformat, # existing devel behavior + _datatag._AnsibleTaggedDate: cls.discard_tags, + _datatag._AnsibleTaggedTime: cls.discard_tags, + _datatag._AnsibleTaggedDateTime: cls.discard_tags, + _vault.EncryptedString: cls.serialize_encrypted_string, + _datatag._AnsibleTaggedStr: cls.serialize_tagged_str, # for VaultedValue tagged str + _datatag._AnsibleTaggedInt: cls.discard_tags, + _datatag._AnsibleTaggedFloat: cls.discard_tags, + _datatag._AnsibleTaggedList: cls.discard_tags, + _datatag._AnsibleTaggedSet: cls.discard_tags, + _datatag._AnsibleTaggedTuple: cls.discard_tags, + _datatag._AnsibleTaggedDict: cls.discard_tags, + _Untrusted: cls.serialize_untrusted, # equivalent to AnsibleJSONEncoder(preprocess_unsafe=True) in devel + } + + cls.deserialize_map = { + '__ansible_unsafe': cls.deserialize_unsafe, + '__ansible_vault': cls.deserialize_vault, + } + + @classmethod + def pre_serialize(cls, encoder: Encoder, o: _t.Any) -> _t.Any: + # DTFIX-RELEASE: these conversion args 
probably aren't needed + avv = cls.visitor_type(invert_trust=True, convert_mapping_to_dict=True, convert_sequence_to_list=True, convert_custom_scalars=True) + + return avv.visit(o) + + @classmethod + def post_deserialize(cls, decoder: Decoder, o: _t.Any) -> _t.Any: + avv = cls.visitor_type(trusted_as_template=decoder._trusted_as_template, origin=decoder._origin) + + return avv.visit(o) + + @classmethod + def handle_key(cls, k: _t.Any) -> _t.Any: + if isinstance(k, str): + return k + + # DTFIX-RELEASE: decide if this is a deprecation warning, error, or what? + # Non-string variable names have been disallowed by set_fact and other things since at least 2021. + # DTFIX-RELEASE: document why this behavior is here, also verify the legacy tagless use case doesn't need this same behavior + return str(k) + + +class Encoder(_profiles.AnsibleProfileJSONEncoder): + _profile = _Profile + + +class Decoder(_profiles.AnsibleProfileJSONDecoder): + _profile = _Profile + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + # NB: these can only be sampled properly when loading strings, eg, `json.loads`; the global `json.load` function does not expose the file-like to us + self._origin: _tags.Origin | None = None + self._trusted_as_template: bool = False + + def raw_decode(self, s: str, idx: int = 0) -> tuple[_t.Any, int]: + self._origin = _tags.Origin.get_tag(s) + self._trusted_as_template = _tags.TrustedAsTemplate.is_tagged_on(s) + + return super().raw_decode(s, idx) diff --git a/lib/ansible/_internal/_locking.py b/lib/ansible/_internal/_locking.py new file mode 100644 index 00000000000..1b04fa37c82 --- /dev/null +++ b/lib/ansible/_internal/_locking.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +import contextlib +import fcntl +import typing as t + + +@contextlib.contextmanager +def named_mutex(path: str) -> t.Iterator[None]: + """ + Lightweight context manager wrapper over `fcntl.flock` to provide IPC locking via a shared filename. + Entering the context manager blocks until the lock is acquired. + The lock file will be created automatically, but creation of the parent directory and deletion of the lockfile are the caller's responsibility. 
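A usage sketch for `named_mutex` (the lock path is hypothetical, and this is internal API, so treat it as illustrative):

from ansible._internal._locking import named_mutex

# serialize access to a shared on-disk resource across processes
with named_mutex('/tmp/example-cache.lock'):
    ...  # critical section; the flock is released when the context exits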
+ """ + with open(path, 'a') as file: + fcntl.flock(file, fcntl.LOCK_EX) + + try: + yield + finally: + fcntl.flock(file, fcntl.LOCK_UN) diff --git a/lib/ansible/_internal/_plugins/__init__.py b/lib/ansible/_internal/_plugins/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/lib/ansible/_internal/_plugins/_cache.py b/lib/ansible/_internal/_plugins/_cache.py new file mode 100644 index 00000000000..463b0a8ed66 --- /dev/null +++ b/lib/ansible/_internal/_plugins/_cache.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import functools +import json +import json.encoder +import json.decoder +import typing as t + +from .._wrapt import ObjectProxy +from .._json._profiles import _cache_persistence + + +class PluginInterposer(ObjectProxy): + """Proxies a Cache plugin instance to implement transparent encapsulation of serialized Ansible internal data types.""" + + _PAYLOAD_KEY = '__payload__' + """The key used to store the serialized payload.""" + + def get(self, key: str) -> dict[str, object]: + return self._decode(self.__wrapped__.get(self._get_key(key))) + + def set(self, key: str, value: dict[str, object]) -> None: + self.__wrapped__.set(self._get_key(key), self._encode(value)) + + def keys(self) -> t.Sequence[str]: + return [k for k in (self._restore_key(k) for k in self.__wrapped__.keys()) if k is not None] + + def contains(self, key: t.Any) -> bool: + return self.__wrapped__.contains(self._get_key(key)) + + def delete(self, key: str) -> None: + self.__wrapped__.delete(self._get_key(key)) + + @classmethod + def _restore_key(cls, wrapped_key: str) -> str | None: + prefix = cls._get_wrapped_key_prefix() + + if not wrapped_key.startswith(prefix): + return None + + return wrapped_key[len(prefix) :] + + @classmethod + @functools.cache + def _get_wrapped_key_prefix(cls) -> str: + return f's{_cache_persistence._Profile.schema_id}_' + + @classmethod + def _get_key(cls, key: str) -> str: + """Augment the supplied key with a schema identifier to allow for side-by-side caching across incompatible schemas.""" + return f'{cls._get_wrapped_key_prefix()}{key}' + + def _encode(self, value: dict[str, object]) -> dict[str, object]: + return {self._PAYLOAD_KEY: json.dumps(value, cls=_cache_persistence.Encoder)} + + def _decode(self, value: dict[str, t.Any]) -> dict[str, object]: + return json.loads(value[self._PAYLOAD_KEY], cls=_cache_persistence.Decoder) diff --git a/lib/ansible/_internal/_task.py b/lib/ansible/_internal/_task.py new file mode 100644 index 00000000000..6a5e8a63f8b --- /dev/null +++ b/lib/ansible/_internal/_task.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import dataclasses +import typing as t + +from collections import abc as c + +from ansible import constants +from ansible._internal._templating import _engine +from ansible._internal._templating._chain_templar import ChainTemplar +from ansible.errors import AnsibleError +from ansible.module_utils._internal._ambient_context import AmbientContextBase +from ansible.module_utils.datatag import native_type_name +from ansible.parsing import vault as _vault +from ansible.utils.display import Display + +if t.TYPE_CHECKING: + from ansible.playbook.task import Task + + +@dataclasses.dataclass +class TaskContext(AmbientContextBase): + """Ambient context that wraps task execution on workers. 
It provides access to the currently executing task.""" + + task: Task + + +TaskArgsFinalizerCallback = t.Callable[[str, t.Any, _engine.TemplateEngine, t.Any], t.Any] +"""Type alias for the shape of the `ActionBase.finalize_task_arg` method.""" + + +class TaskArgsChainTemplar(ChainTemplar): + """ + A ChainTemplar that carries a user-provided context object, optionally provided by `ActionBase.get_finalize_task_args_context`. + TaskArgsFinalizer provides the context to each `ActionBase.finalize_task_arg` call to allow for more complex/stateful customization. + """ + + def __init__(self, *sources: c.Mapping, templar: _engine.TemplateEngine, callback: TaskArgsFinalizerCallback, context: t.Any) -> None: + super().__init__(*sources, templar=templar) + + self.callback = callback + self.context = context + + def template(self, key: t.Any, value: t.Any) -> t.Any: + return self.callback(key, value, self.templar, self.context) + + +class TaskArgsFinalizer: + """Invoked during task args finalization; allows actions to override default arg processing (e.g., templating).""" + + def __init__(self, *args: c.Mapping[str, t.Any] | str | None, templar: _engine.TemplateEngine) -> None: + self._args_layers = [arg for arg in args if arg is not None] + self._templar = templar + + def finalize(self, callback: TaskArgsFinalizerCallback, context: t.Any) -> dict[str, t.Any]: + resolved_layers: list[c.Mapping[str, t.Any]] = [] + + for layer in self._args_layers: + if isinstance(layer, (str, _vault.EncryptedString)): # EncryptedString can hide a template + if constants.config.get_config_value('INJECT_FACTS_AS_VARS'): + Display().warning( + "Using a template for task args is unsafe in some situations " + "(see https://docs.ansible.com/ansible/devel/reference_appendices/faq.html#argsplat-unsafe).", + obj=layer, + ) + + resolved_layer = self._templar.resolve_to_container(layer, options=_engine.TemplateOptions(value_for_omit={})) + else: + resolved_layer = layer + + if not isinstance(resolved_layer, dict): + raise AnsibleError(f'Task args must resolve to a {native_type_name(dict)!r} not {native_type_name(resolved_layer)!r}.', obj=layer) + + resolved_layers.append(resolved_layer) + + ct = TaskArgsChainTemplar(*reversed(resolved_layers), templar=self._templar, callback=callback, context=context) + + return ct.as_dict() diff --git a/lib/ansible/_internal/_templating/__init__.py b/lib/ansible/_internal/_templating/__init__.py new file mode 100644 index 00000000000..e2fd19558fc --- /dev/null +++ b/lib/ansible/_internal/_templating/__init__.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from jinja2 import __version__ as _jinja2_version + +# DTFIX-FUTURE: sanity test to ensure this doesn't drift from requirements +_MINIMUM_JINJA_VERSION = (3, 1) +_CURRENT_JINJA_VERSION = tuple(map(int, _jinja2_version.split('.', maxsplit=2)[:2])) + +if _CURRENT_JINJA_VERSION < _MINIMUM_JINJA_VERSION: + raise RuntimeError(f'Jinja version {".".join(map(str, _MINIMUM_JINJA_VERSION))} or higher is required (current version {_jinja2_version}).') diff --git a/lib/ansible/_internal/_templating/_access.py b/lib/ansible/_internal/_templating/_access.py new file mode 100644 index 00000000000..d69a92df9fc --- /dev/null +++ b/lib/ansible/_internal/_templating/_access.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import abc +import typing as t + +from contextvars import ContextVar + +from ansible.module_utils._internal._datatag import AnsibleTagHelper + + +class NotifiableAccessContextBase(metaclass=abc.ABCMeta): + """Base class for a 
context manager that, when active, receives notification of managed access for types/tags in which it has registered an interest.""" + + _type_interest: t.FrozenSet[type] = frozenset() + """Set of types (including tag types) for which this context will be notified upon access.""" + + _mask: t.ClassVar[bool] = False + """When true, only the innermost (most recently created) context of this type will be notified.""" + + def __enter__(self): + # noinspection PyProtectedMember + AnsibleAccessContext.current()._register_interest(self) + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + # noinspection PyProtectedMember + AnsibleAccessContext.current()._unregister_interest(self) + return None + + @abc.abstractmethod + def _notify(self, o: t.Any) -> t.Any: + """Derived classes implement custom notification behavior when a registered type or tag is accessed.""" + + +class AnsibleAccessContext: + """ + Broker object for managed access registration and notification. + Each thread or other logical callstack has a dedicated `AnsibleAccessContext` object with which `NotifiableAccessContext` objects can register interest. + When a managed access occurs on an object, each active `NotifiableAccessContext` within the current callstack that has registered interest in that + object's type or a tag present on it will be notified. + """ + + _contextvar: t.ClassVar[ContextVar[AnsibleAccessContext]] = ContextVar('AnsibleAccessContext') + + @staticmethod + def current() -> AnsibleAccessContext: + """Creates or retrieves an `AnsibleAccessContext` for the current logical callstack.""" + try: + ctx: AnsibleAccessContext = AnsibleAccessContext._contextvar.get() + except LookupError: + # didn't exist; create it + ctx = AnsibleAccessContext() + AnsibleAccessContext._contextvar.set(ctx) # we ignore the token, since this should live for the life of the thread/async ctx + + return ctx + + def __init__(self) -> None: + self._notify_contexts: list[NotifiableAccessContextBase] = [] + + def _register_interest(self, context: NotifiableAccessContextBase) -> None: + self._notify_contexts.append(context) + + def _unregister_interest(self, context: NotifiableAccessContextBase) -> None: + ctx = self._notify_contexts.pop() + + if ctx is not context: + raise RuntimeError(f'Out-of-order context deactivation detected. 
Found {ctx} instead of {context}.') + + def access(self, value: t.Any) -> None: + """Notify all contexts which have registered interest in the given value that it is being accessed.""" + if not self._notify_contexts: + return + + value_types = AnsibleTagHelper.tag_types(value) | frozenset((type(value),)) + masked: set[type] = set() + + for ctx in reversed(self._notify_contexts): + if ctx._mask: + if (ctx_type := type(ctx)) in masked: + continue + + masked.add(ctx_type) + + # noinspection PyProtectedMember + if ctx._type_interest.intersection(value_types): + ctx._notify(value) diff --git a/lib/ansible/_internal/_templating/_chain_templar.py b/lib/ansible/_internal/_templating/_chain_templar.py new file mode 100644 index 00000000000..896dcc053aa --- /dev/null +++ b/lib/ansible/_internal/_templating/_chain_templar.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import collections.abc as c +import itertools +import typing as t + +from ansible.errors import AnsibleValueOmittedError, AnsibleError + +from ._engine import TemplateEngine + + +class ChainTemplar: + """A basic variable layering mechanism that supports templating and obliteration of `omit` values.""" + + def __init__(self, *sources: c.Mapping, templar: TemplateEngine) -> None: + self.sources = sources + self.templar = templar + + def template(self, key: t.Any, value: t.Any) -> t.Any: + """ + Render the given value using the templar. + Intended to be overridden by subclasses. + """ + return self.templar.template(value) + + def get(self, key: t.Any) -> t.Any: + """Get the value for the given key, templating the result before returning it.""" + for source in self.sources: + if key not in source: + continue + + value = source[key] + + try: + return self.template(key, value) + except AnsibleValueOmittedError: + break # omit == obliterate - matches historical behavior where dict layers were squashed before templating was applied + except Exception as ex: + raise AnsibleError(f'Error while resolving value for {key!r}.', obj=value) from ex + + raise KeyError(key) + + def keys(self) -> t.Iterable[t.Any]: + """ + Returns a sorted iterable of all keys present in all source layers, without templating associated values. + Values that resolve to `omit` are thus included. + """ + return sorted(set(itertools.chain.from_iterable(self.sources))) + + def items(self) -> t.Iterable[t.Tuple[t.Any, t.Any]]: + """ + Returns a sorted iterable of (key, templated value) tuples. + Any tuple where the templated value resolves to `omit` will not be included in the result. 
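The layering and omit semantics can be modeled without any templating; `OMIT` below stands in for the real `Omit` sentinel, and the helper mirrors the break-then-KeyError flow of `get` above:

OMIT = object()  # stand-in for the real Omit sentinel

def chain_get(key, *sources):
    """First layer containing the key wins; a value of OMIT obliterates the key."""
    for source in sources:
        if key in source:
            if source[key] is OMIT:
                break  # omit hides the key even if lower layers define it
            return source[key]
    raise KeyError(key)

task_args = {'mode': OMIT}
defaults = {'mode': '0644', 'path': '/tmp/x'}

assert chain_get('path', task_args, defaults) == '/tmp/x'
try:
    chain_get('mode', task_args, defaults)
except KeyError:
    pass  # 'mode' was omitted, despite the default layer defining it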
+ """ + for key in self.keys(): + try: + yield key, self.get(key) + except KeyError: + pass + + def as_dict(self) -> dict[t.Any, t.Any]: + """Returns a dict representing all layers, squashed and templated, with `omit` values dropped.""" + return dict(self.items()) diff --git a/lib/ansible/_internal/_templating/_datatag.py b/lib/ansible/_internal/_templating/_datatag.py new file mode 100644 index 00000000000..a7696f8ba41 --- /dev/null +++ b/lib/ansible/_internal/_templating/_datatag.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import contextlib as _contextlib +import dataclasses +import typing as t + +from ansible.module_utils._internal._datatag import AnsibleSingletonTagBase, _tag_dataclass_kwargs +from ansible.module_utils._internal._datatag._tags import Deprecated +from ansible._internal._datatag._tags import Origin +from ansible.utils.display import Display + +from ._access import NotifiableAccessContextBase +from ._utils import TemplateContext + + +display = Display() + + +@dataclasses.dataclass(**_tag_dataclass_kwargs) +class _JinjaConstTemplate(AnsibleSingletonTagBase): + # deprecated: description='embedded Jinja constant string template support' core_version='2.23' + pass + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class _TrippedDeprecationInfo: + template: str + deprecated: Deprecated + + +class DeprecatedAccessAuditContext(NotifiableAccessContextBase): + """When active, captures metadata about managed accesses to `Deprecated` tagged objects.""" + + _type_interest = frozenset([Deprecated]) + + @classmethod + def when(cls, condition: bool, /) -> t.Self | _contextlib.nullcontext: + """Returns a new instance if `condition` is True (usually `TemplateContext.is_top_level`), otherwise a `nullcontext` instance.""" + if condition: + return cls() + + return _contextlib.nullcontext() + + def __init__(self) -> None: + self._tripped_deprecation_info: dict[int, _TrippedDeprecationInfo] = {} + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + result = super().__exit__(exc_type, exc_val, exc_tb) + + for item in self._tripped_deprecation_info.values(): + if Origin.is_tagged_on(item.template): + msg = item.deprecated.msg + else: + # without an origin, we need to include what context we do have (the template) + msg = f'While processing {item.template!r}: {item.deprecated.msg}' + + display._deprecated_with_plugin_info( + msg=msg, + help_text=item.deprecated.help_text, + version=item.deprecated.removal_version, + date=item.deprecated.removal_date, + obj=item.template, + plugin=item.deprecated.plugin, + ) + + return result + + def _notify(self, o: t.Any) -> None: + deprecated = Deprecated.get_required_tag(o) + deprecated_key = id(deprecated) + + if deprecated_key in self._tripped_deprecation_info: + return # record only the first access for each deprecated tag in a given context + + template_ctx = TemplateContext.current(optional=True) + template = template_ctx.template_value if template_ctx else None + + # when the current template input is a container, provide a descriptive string with origin propagated (if possible) + if not isinstance(template, str): + # DTFIX-FUTURE: ascend the template stack to try and find the nearest string source template + origin = Origin.get_tag(template) + + # DTFIX-RELEASE: this should probably use a synthesized description value on the tag + # it is reachable from the data_tagging_controller test: ../playbook_output_validator/filter.py actual_stdout.txt actual_stderr.txt + # -[DEPRECATION WARNING]: `something_old` is 
deprecated, don't use it! This feature will be removed in version 1.2.3. + # +[DEPRECATION WARNING]: While processing '<>': `something_old` is deprecated, don't use it! This feature will be removed in ... + template = '<>' + + if origin: + origin.tag(template) + + self._tripped_deprecation_info[deprecated_key] = _TrippedDeprecationInfo( + template=template, + deprecated=deprecated, + ) diff --git a/lib/ansible/_internal/_templating/_engine.py b/lib/ansible/_internal/_templating/_engine.py new file mode 100644 index 00000000000..b15c64e791c --- /dev/null +++ b/lib/ansible/_internal/_templating/_engine.py @@ -0,0 +1,588 @@ +# (c) 2012-2014, Michael DeHaan +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import annotations + +import copy +import dataclasses +import enum +import textwrap +import typing as t +import collections.abc as c +import re + +from collections import ChainMap + +from ansible.errors import ( + AnsibleError, + AnsibleValueOmittedError, + AnsibleUndefinedVariable, + AnsibleTemplateSyntaxError, + AnsibleBrokenConditionalError, + AnsibleTemplateTransformLimitError, + TemplateTrustCheckFailedError, +) + +from ansible.module_utils._internal._datatag import AnsibleTaggedObject, NotTaggableError, AnsibleTagHelper +from ansible._internal._errors._handler import Skippable +from ansible._internal._datatag._tags import Origin, TrustedAsTemplate +from ansible.utils.display import Display +from ansible.utils.vars import validate_variable_name +from ansible.parsing.dataloader import DataLoader + +from ._datatag import DeprecatedAccessAuditContext +from ._jinja_bits import ( + AnsibleTemplate, + _TemplateCompileContext, + TemplateOverrides, + AnsibleEnvironment, + defer_template_error, + create_template_error, + is_possibly_template, + is_possibly_all_template, + AnsibleTemplateExpression, + _finalize_template_result, + FinalizeMode, +) +from ._jinja_common import _TemplateConfig, MarkerError, ExceptionMarker +from ._lazy_containers import _AnsibleLazyTemplateMixin +from ._marker_behaviors import MarkerBehavior, FAIL_ON_UNDEFINED +from ._transform import _type_transform_mapping +from ._utils import Omit, TemplateContext, IGNORE_SCALAR_VAR_TYPES, LazyOptions +from ...module_utils.datatag import native_type_name + +_display = Display() + + +_shared_empty_unmask_type_names: frozenset[str] = frozenset() + +TRANSFORM_CHAIN_LIMIT: int = 10 +"""Arbitrary limit for chained transforms to prevent cycles; an exception will be raised if exceeded.""" + + +class TemplateMode(enum.Enum): + # DTFIX-FUTURE: this enum ideally wouldn't exist - revisit/rename before making public + DEFAULT = enum.auto() + STOP_ON_TEMPLATE = enum.auto() + STOP_ON_CONTAINER = enum.auto() + ALWAYS_FINALIZE = enum.auto() + + +@dataclasses.dataclass(kw_only=True, slots=True, frozen=True) +class TemplateOptions: + DEFAULT: t.ClassVar[t.Self] + + value_for_omit: object = Omit + escape_backslashes: bool = True + preserve_trailing_newlines: bool = True + # DTFIX-RELEASE: these aren't really overrides anymore, rename the dataclass and this field + # also mention in docstring this has no effect unless used to template a string + overrides: TemplateOverrides = TemplateOverrides.DEFAULT + + +TemplateOptions.DEFAULT = TemplateOptions() + + +class TemplateEncountered(Exception): + pass + + +class TemplateEngine: + """ + The main class for templating, with the main entry-point of template(). 
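A minimal usage sketch (internal API, so names and behavior may change; strings must be tagged `TrustedAsTemplate` to be rendered):

from ansible._internal._templating._engine import TemplateEngine
from ansible._internal._datatag._tags import TrustedAsTemplate

engine = TemplateEngine(variables={'greeting': 'hello'})

assert engine.template(TrustedAsTemplate().tag('{{ greeting | upper }}')) == 'HELLO'
assert engine.template('{{ greeting }}') == '{{ greeting }}'  # untrusted strings pass through (with a warning, by default)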
+ """ + + _sentinel = object() + + def __init__( + self, + loader: DataLoader | None = None, + variables: dict[str, t.Any] | ChainMap[str, t.Any] | None = None, + variables_factory: t.Callable[[], dict[str, t.Any] | ChainMap[str, t.Any]] | None = None, + marker_behavior: MarkerBehavior | None = None, + ): + self._loader = loader + self._variables = variables + self._variables_factory = variables_factory + self._environment: AnsibleEnvironment | None = None + + # inherit marker behavior from the active template context's templar unless otherwise specified + if not marker_behavior: + if template_ctx := TemplateContext.current(optional=True): + marker_behavior = template_ctx.templar.marker_behavior + else: + marker_behavior = FAIL_ON_UNDEFINED + + self._marker_behavior = marker_behavior + + def copy(self) -> t.Self: + new_engine = copy.copy(self) + new_engine._environment = None + + return new_engine + + def extend(self, marker_behavior: MarkerBehavior | None = None) -> t.Self: + # DTFIX-RELEASE: bikeshed name, supported features + new_templar = type(self)( + loader=self._loader, + variables=self._variables, + variables_factory=self._variables_factory, + marker_behavior=marker_behavior or self._marker_behavior, + ) + + if self._environment: + new_templar._environment = self._environment + + return new_templar + + @property + def marker_behavior(self) -> MarkerBehavior: + return self._marker_behavior + + @property + def basedir(self) -> str: + """The basedir from DataLoader.""" + return self._loader.get_basedir() if self._loader else '.' + + @property + def environment(self) -> AnsibleEnvironment: + if not self._environment: + self._environment = AnsibleEnvironment(ansible_basedir=self.basedir) + + return self._environment + + def _create_overlay(self, template: str, overrides: TemplateOverrides) -> tuple[str, AnsibleEnvironment]: + try: + template, overrides = overrides._extract_template_overrides(template) + except Exception as ex: + raise AnsibleTemplateSyntaxError("Syntax error in template.", obj=template) from ex + + env = self.environment + + if overrides is not TemplateOverrides.DEFAULT and (overlay_kwargs := overrides.overlay_kwargs()): + env = t.cast(AnsibleEnvironment, env.overlay(**overlay_kwargs)) + + return template, env + + @staticmethod + def _count_newlines_from_end(in_str): + """ + Counts the number of newlines at the end of a string. This is used during + the jinja2 templating to ensure the count matches the input, since some newlines + may be thrown away during the templating. + """ + + i = len(in_str) + j = i - 1 + + try: + while in_str[j] == '\n': + j -= 1 + except IndexError: + # Uncommon cases: zero length string and string containing only newlines + return i + + return i - 1 - j + + @property + def available_variables(self) -> dict[str, t.Any] | ChainMap[str, t.Any]: + """Available variables this instance will use when templating.""" + # DTFIX-RELEASE: ensure that we're always accessing this as a shallow container-level snapshot, and eliminate uses of anything + # that directly mutates this value. _new_context may resolve this for us? 
+ if self._variables is None: + self._variables = self._variables_factory() if self._variables_factory else {} + + return self._variables + + @available_variables.setter + def available_variables(self, variables: dict[str, t.Any]) -> None: + self._variables = variables + + def resolve_variable_expression( + self, + expression: str, + *, + local_variables: dict[str, t.Any] | None = None, + ) -> t.Any: + """ + Resolve a potentially untrusted string variable expression consisting only of valid identifiers, integers, dots, and indexing containing these. + Optional local variables may be provided, which can only be referenced directly by the given expression. + Valid: x, x.y, x[y].z, x[1], 1, x[y.z] + Error: 'x', x['y'], q('env') + """ + components = re.split(r'[.\[\]]', expression) + + try: + for component in components: + if re.fullmatch('[0-9]*', component): + continue # allow empty strings and integers + + validate_variable_name(component) + except Exception as ex: + raise AnsibleError(f'Invalid variable expression: {expression}', obj=expression) from ex + + return self.evaluate_expression(TrustedAsTemplate().tag(expression), local_variables=local_variables) + + @staticmethod + def variable_name_as_template(name: str) -> str: + """Return a trusted template string that will resolve the provided variable name. Raises an error if `name` is not a valid identifier.""" + validate_variable_name(name) + return AnsibleTagHelper.tag('{{' + name + '}}', (AnsibleTagHelper.tags(name) | {TrustedAsTemplate()})) + + def transform(self, variable: t.Any) -> t.Any: + """Recursively apply transformations to the given value and return the result.""" + return self.template(variable, mode=TemplateMode.ALWAYS_FINALIZE, lazy_options=LazyOptions.SKIP_TEMPLATES_AND_ACCESS) + + def template( + self, + variable: t.Any, # DTFIX-RELEASE: once we settle the new/old API boundaries, rename this (here and in other methods) + *, + options: TemplateOptions = TemplateOptions.DEFAULT, + mode: TemplateMode = TemplateMode.DEFAULT, + lazy_options: LazyOptions = LazyOptions.DEFAULT, + ) -> t.Any: + """Templates (possibly recursively) any given data as input.""" + original_variable = variable + + for _attempt in range(TRANSFORM_CHAIN_LIMIT): + if variable is None or (value_type := type(variable)) in IGNORE_SCALAR_VAR_TYPES: + return variable # quickly ignore supported scalar types which are never templated + + value_is_str = isinstance(variable, str) + + if template_ctx := TemplateContext.current(optional=True): + stop_on_template = template_ctx.stop_on_template + else: + stop_on_template = False + + if mode is TemplateMode.STOP_ON_TEMPLATE: + stop_on_template = True + + with ( + TemplateContext(template_value=variable, templar=self, options=options, stop_on_template=stop_on_template) as ctx, + DeprecatedAccessAuditContext.when(ctx.is_top_level), + ): + try: + if not value_is_str: + # transforms are currently limited to non-str types as an optimization + if (transform := _type_transform_mapping.get(value_type)) and value_type.__name__ not in lazy_options.unmask_type_names: + variable = transform(variable) + continue + + template_result = _AnsibleLazyTemplateMixin._try_create(variable, lazy_options) + elif not lazy_options.template: + template_result = variable + elif not is_possibly_template(variable, options.overrides): + template_result = variable + elif not self._trust_check(variable, skip_handler=stop_on_template): + template_result = variable + elif stop_on_template: + raise TemplateEncountered() + else: + compiled_template =
self._compile_template(variable, options) + + template_result = compiled_template(self.available_variables) + template_result = self._post_render_mutation(variable, template_result, options) + except TemplateEncountered: + raise + except Exception as ex: + template_result = defer_template_error(ex, variable, is_expression=False) + + if ctx.is_top_level or mode is TemplateMode.ALWAYS_FINALIZE: + template_result = self._finalize_top_level_template_result( + variable, options, template_result, stop_on_container=mode is TemplateMode.STOP_ON_CONTAINER + ) + + return template_result + + raise AnsibleTemplateTransformLimitError(obj=original_variable) + + @staticmethod + def _finalize_top_level_template_result( + variable: t.Any, + options: TemplateOptions, + template_result: t.Any, + is_expression: bool = False, + stop_on_container: bool = False, + ) -> t.Any: + """ + This method must be called for expressions and top-level templates to recursively finalize the result. + This renders any embedded templates and triggers `Marker` and omit behaviors. + """ + try: + if template_result is Omit: + # When the template result is Omit, raise an AnsibleValueOmittedError if value_for_omit is Omit, otherwise return value_for_omit. + # Other occurrences of Omit will simply drop out of containers during _finalize_template_result. + if options.value_for_omit is Omit: + raise AnsibleValueOmittedError() + + return options.value_for_omit # trust that value_for_omit is an allowed type + + if stop_on_container and type(template_result) in AnsibleTaggedObject._collection_types: + # Use of stop_on_container implies the caller will perform necessary checks on values, + # most likely by passing them back into the templating system. + try: + return template_result._non_lazy_copy() + except AttributeError: + return template_result # non-lazy containers are returned as-is + + return _finalize_template_result(template_result, FinalizeMode.TOP_LEVEL) + except TemplateEncountered: + raise + except Exception as ex: + raise_from: BaseException + + if isinstance(ex, MarkerError): + exception_to_raise = ex.source._as_exception() + + # MarkerError is never suitable for use as the cause of another exception, it is merely a raiseable container for the source marker + # used for flow control (so its stack trace is rarely useful). However, if the source derives from a ExceptionMarker, its contained + # exception (previously raised) should be used as the cause. Other sources do not contain exceptions, so cannot provide a cause. + raise_from = exception_to_raise if isinstance(ex.source, ExceptionMarker) else None + else: + exception_to_raise = ex + raise_from = ex + + exception_to_raise = create_template_error(exception_to_raise, variable, is_expression) + + if exception_to_raise is ex: + raise # when the exception to raise is the active exception, just re-raise it + + if exception_to_raise is raise_from: + raise_from = exception_to_raise.__cause__ # preserve the exception's cause, if any, otherwise no cause will be used + + raise exception_to_raise from raise_from # always raise from something to avoid the currently active exception becoming __context__ + + def _compile_template(self, template: str, options: TemplateOptions) -> t.Callable[[c.Mapping[str, t.Any]], t.Any]: + # NOTE: Creating an overlay that lives only inside _compile_template means that overrides are not applied + # when templating nested variables, where Templar.environment is used, not the overlay. They are, however, + # applied to includes and imports. 
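For context, the overrides consumed by `_create_overlay` typically come from Ansible's long-standing inline `#jinja2:` template header; a hypothetical template using alternate delimiters looks like:

# the header line is stripped and its key/value pairs become environment overrides
template_text = (
    "#jinja2: variable_start_string:'[%', variable_end_string:'%]'\n"
    "Hello [% name %]!"
)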
+ try: + stripped_template, env = self._create_overlay(template, options.overrides) + + with _TemplateCompileContext(escape_backslashes=options.escape_backslashes): + return t.cast(AnsibleTemplate, env.from_string(stripped_template)) + except Exception as ex: + return self._defer_jinja_compile_error(ex, template, False) + + def _compile_expression(self, expression: str, options: TemplateOptions) -> t.Callable[[c.Mapping[str, t.Any]], t.Any]: + """ + Compile a Jinja expression, applying optional compile-time behavior via an environment overlay (if needed). The overlay is + necessary to avoid mutating settings on the Templar's shared environment, which could be visible to other code running concurrently. + In the specific case of escape_backslashes, the setting only applies to a top-level template at compile-time, not runtime, to + ensure that any nested template calls (e.g., include and import) do not inherit the (lack of) escaping behavior. + """ + try: + with _TemplateCompileContext(escape_backslashes=options.escape_backslashes): + return AnsibleTemplateExpression(self.environment.compile_expression(expression, False)) + except Exception as ex: + return self._defer_jinja_compile_error(ex, expression, True) + + def _defer_jinja_compile_error(self, ex: Exception, variable: str, is_expression: bool) -> t.Callable[[c.Mapping[str, t.Any]], t.Any]: + deferred_error = defer_template_error(ex, variable, is_expression=is_expression) + + def deferred_exception(_jinja_vars: c.Mapping[str, t.Any]) -> t.Any: + # a template/expression compile error always results in a single node representing the compile error + return self.marker_behavior.handle_marker(deferred_error) + + return deferred_exception + + def _post_render_mutation(self, template: str, result: t.Any, options: TemplateOptions) -> t.Any: + if options.preserve_trailing_newlines and isinstance(result, str): + # The low level calls above do not preserve the newline + # characters at the end of the input data, so we + # calculate the difference in newlines and append them + # to the resulting output for parity + # + # Using AnsibleEnvironment's keep_trailing_newline instead would + # result in change in behavior when trailing newlines + # would be kept also for included templates, for example: + # "Hello {% include 'world.txt' %}!" would render as + # "Hello world\n!\n" instead of "Hello world!\n". + data_newlines = self._count_newlines_from_end(template) + res_newlines = self._count_newlines_from_end(result) + + if data_newlines > res_newlines: + newlines = options.overrides.newline_sequence * (data_newlines - res_newlines) + result = AnsibleTagHelper.tag_copy(result, result + newlines) + + # If the input string template was source-tagged and the result is not, propagate the source tag to the new value. + # This provides further contextual information when a template-derived value/var causes an error. + if not Origin.is_tagged_on(result) and (origin := Origin.get_tag(template)): + try: + result = origin.tag(result) + except NotTaggableError: + pass # best effort- if we can't, oh well + + return result + + def is_template(self, data: t.Any, overrides: TemplateOverrides = TemplateOverrides.DEFAULT) -> bool: + """ + Evaluate the input data to determine if it contains a template, even if that template is invalid. Containers will be recursively searched. + Objects subject to template-time transforms that do not yield a template are not considered templates by this method. 
+ Gating a conditional call to `template` with this method is redundant and inefficient -- request templating unconditionally instead. + """ + options = TemplateOptions(overrides=overrides) if overrides is not TemplateOverrides.DEFAULT else TemplateOptions.DEFAULT + + try: + self.template(data, options=options, mode=TemplateMode.STOP_ON_TEMPLATE) + except TemplateEncountered: + return True + else: + return False + + def resolve_to_container(self, variable: t.Any, options: TemplateOptions = TemplateOptions.DEFAULT) -> t.Any: + """ + Recursively resolve scalar string template input, stopping at the first container encountered (if any). + Used for e.g., partial templating of task arguments, where the plugin needs to handle final resolution of some args internally. + """ + return self.template(variable, options=options, mode=TemplateMode.STOP_ON_CONTAINER) + + def evaluate_expression( + self, + expression: str, + *, + local_variables: dict[str, t.Any] | None = None, + escape_backslashes: bool = True, + _render_jinja_const_template: bool = False, + ) -> t.Any: + """ + Evaluate a trusted string expression and return its result. + Optional local variables may be provided, which can only be referenced directly by the given expression. + """ + if not isinstance(expression, str): + raise TypeError(f"Expressions must be {str!r}, got {type(expression)!r}.") + + options = TemplateOptions(escape_backslashes=escape_backslashes, preserve_trailing_newlines=False) + + with ( + TemplateContext(template_value=expression, templar=self, options=options, _render_jinja_const_template=_render_jinja_const_template) as ctx, + DeprecatedAccessAuditContext.when(ctx.is_top_level), + ): + try: + if not TrustedAsTemplate.is_tagged_on(expression): + raise TemplateTrustCheckFailedError(obj=expression) + + template_variables = ChainMap(local_variables, self.available_variables) if local_variables else self.available_variables + compiled_template = self._compile_expression(expression, options) + + template_result = compiled_template(template_variables) + template_result = self._post_render_mutation(expression, template_result, options) + except Exception as ex: + template_result = defer_template_error(ex, expression, is_expression=True) + + return self._finalize_top_level_template_result(expression, options, template_result, is_expression=True) + + _BROKEN_CONDITIONAL_ALLOWED_FRAGMENT = 'Broken conditionals are currently allowed because the `ALLOW_BROKEN_CONDITIONALS` configuration option is enabled.' + _CONDITIONAL_AS_TEMPLATE_MSG = 'Conditionals should not be surrounded by templating delimiters such as {{ }} or {% %}.' + + def _strip_conditional_handle_empty(self, conditional) -> t.Any: + """ + Strips leading/trailing whitespace from the input expression. + If `ALLOW_BROKEN_CONDITIONALS` is enabled, None/empty is coerced to True (legacy behavior, deprecated). + Otherwise, None/empty results in a broken conditional error being raised. + """ + if isinstance(conditional, str): + # Leading/trailing whitespace on conditional expressions is not a problem, except we can't tell if the expression is empty (which *is* a problem). + # Always strip conditional input strings. Neither conditional expressions nor all-template conditionals have legit reasons to preserve + # surrounding whitespace, and they complicate detection and processing of all-template fallback cases. 
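# --- editorial aside: illustrative sketch, not part of this patch ---
# Why conditionals are stripped before the empty check: whitespace-only input is
# indistinguishable from an empty expression until surrounding whitespace is
# removed, and '   ' would otherwise slip past the `in (None, '')` test below.
for raw in (None, '', '   ', '\n', ' x == 1 '):
    stripped = raw.strip() if isinstance(raw, str) else raw
    print(repr(raw), '-> empty' if stripped in (None, '') else '-> expression ' + repr(stripped))
# --- end aside ---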
+ conditional = AnsibleTagHelper.tag_copy(conditional, conditional.strip()) + + if conditional in (None, ''): + # deprecated backward-compatible behavior; None/empty input conditionals are always True + if _TemplateConfig.allow_broken_conditionals: + _display.deprecated( + msg='Empty conditional expression was evaluated as True.', + help_text=self._BROKEN_CONDITIONAL_ALLOWED_FRAGMENT, + obj=conditional, + version='2.23', + ) + + return True + + raise AnsibleBrokenConditionalError("Empty conditional expressions are not allowed.", obj=conditional) + + return conditional + + def _normalize_and_evaluate_conditional(self, conditional: str | bool) -> t.Any: + """Validate and normalize a conditional input value, resolving allowed embedded template cases and evaluating the resulting expression.""" + conditional = self._strip_conditional_handle_empty(conditional) + + # this must follow `_strip_conditional_handle_empty`, since None/empty are coerced to bool (deprecated) + if type(conditional) is bool: # pylint: disable=unidiomatic-typecheck + return conditional + + try: + if not isinstance(conditional, str): + if _TemplateConfig.allow_broken_conditionals: + # because the input isn't a string, the result will never be a bool; the broken conditional warning in the caller will apply on the result + return self.template(conditional, mode=TemplateMode.ALWAYS_FINALIZE) + + raise AnsibleBrokenConditionalError(message="Conditional expressions must be strings.", obj=conditional) + + if is_possibly_all_template(conditional): + # Indirection of trusted expressions is always allowed. If the expression appears to be entirely wrapped in template delimiters, + # we must resolve it. e.g. `when: "{{ some_var_resolving_to_a_trusted_expression_string }}"`. + # Some invalid meta-templating corner cases may sneak through here (e.g., `when: '{{ "foo" }} == {{ "bar" }}'`); these will + # result in an untrusted expression error. + result = self.template(conditional, mode=TemplateMode.ALWAYS_FINALIZE) + result = self._strip_conditional_handle_empty(result) + + if not isinstance(result, str): + _display.deprecated(msg=self._CONDITIONAL_AS_TEMPLATE_MSG, obj=conditional, version='2.23') + + return result # not an expression + + # The only allowed use of templates for conditionals is for indirect usage of an expression. + # Any other usage should simply be an expression, not an attempt at meta templating. + expression = result + else: + expression = conditional + + # Disable escape_backslashes when processing conditionals, to maintain backwards compatibility. + # This is necessary because conditionals were previously evaluated using {% %}, which was *NOT* affected by escape_backslashes. + # Now that conditionals use expressions, they would be affected by escape_backslashes if it was not disabled. + return self.evaluate_expression(expression, escape_backslashes=False, _render_jinja_const_template=True) + + except AnsibleUndefinedVariable as ex: + # DTFIX-FUTURE: we're only augmenting the message for context here; once we have proper contextual tracking, we can dump the re-raise + raise AnsibleUndefinedVariable("Error while evaluating conditional.", obj=conditional) from ex + + def evaluate_conditional(self, conditional: str | bool) -> bool: + """ + Evaluate a trusted string expression or boolean and return its boolean result. A non-boolean result will raise `AnsibleBrokenConditionalError`. + The ALLOW_BROKEN_CONDITIONALS configuration option can temporarily relax this requirement, allowing truthy conditionals to succeed. 
+ """ + result = self._normalize_and_evaluate_conditional(conditional) + + if isinstance(result, bool): + return result + + bool_result = bool(result) + + msg = ( + f'Conditional result was {textwrap.shorten(str(result), width=40)!r} of type {native_type_name(result)!r}, ' + f'which evaluates to {bool_result}. Conditionals must have a boolean result.' + ) + + if _TemplateConfig.allow_broken_conditionals: + _display.deprecated(msg=msg, obj=conditional, help_text=self._BROKEN_CONDITIONAL_ALLOWED_FRAGMENT, version='2.23') + + return bool_result + + raise AnsibleBrokenConditionalError(msg, obj=conditional) + + @staticmethod + def _trust_check(value: str, skip_handler: bool = False) -> bool: + """ + Return True if the given value is trusted for templating, otherwise return False. + When the value is not trusted, a warning or error may be generated, depending on configuration. + """ + if TrustedAsTemplate.is_tagged_on(value): + return True + + if not skip_handler: + with Skippable, _TemplateConfig.untrusted_template_handler.handle(TemplateTrustCheckFailedError, skip_on_ignore=True): + raise TemplateTrustCheckFailedError(obj=value) + + return False diff --git a/lib/ansible/_internal/_templating/_errors.py b/lib/ansible/_internal/_templating/_errors.py new file mode 100644 index 00000000000..587b63f6b25 --- /dev/null +++ b/lib/ansible/_internal/_templating/_errors.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from ansible.errors import AnsibleTemplatePluginError + + +class AnsibleTemplatePluginRuntimeError(AnsibleTemplatePluginError): + """The specified template plugin (lookup/filter/test) raised an exception during execution.""" + + def __init__(self, plugin_type: str, plugin_name: str) -> None: + super().__init__(f'The {plugin_type} plugin {plugin_name!r} failed.') + + +class AnsibleTemplatePluginLoadError(AnsibleTemplatePluginError): + """The specified template plugin (lookup/filter/test) failed to load.""" + + def __init__(self, plugin_type: str, plugin_name: str) -> None: + super().__init__(f'The {plugin_type} plugin {plugin_name!r} failed to load.') + + +class AnsibleTemplatePluginNotFoundError(AnsibleTemplatePluginError, KeyError): + """ + The specified template plugin (lookup/filter/test) was not found. + This exception extends KeyError since Jinja filter/test resolution requires a KeyError to detect missing plugins. + Jinja compilation fails if a non-KeyError is raised for a missing filter/test, even if the plugin will not be invoked (inconsistent with stock Jinja). 
+ """ + + def __init__(self, plugin_type: str, plugin_name: str) -> None: + super().__init__(f'The {plugin_type} plugin {plugin_name!r} was not found.') diff --git a/lib/ansible/_internal/_templating/_jinja_bits.py b/lib/ansible/_internal/_templating/_jinja_bits.py new file mode 100644 index 00000000000..4b05c8870ee --- /dev/null +++ b/lib/ansible/_internal/_templating/_jinja_bits.py @@ -0,0 +1,1066 @@ +from __future__ import annotations + +import ast +import collections.abc as c +import dataclasses +import enum +import pathlib +import tempfile +import types +import typing as t + +from collections import ChainMap + +import jinja2.nodes + +from jinja2 import pass_context, defaults, TemplateSyntaxError, FileSystemLoader +from jinja2.environment import Environment, Template, TemplateModule, TemplateExpression +from jinja2.compiler import Frame +from jinja2.lexer import TOKEN_VARIABLE_BEGIN, TOKEN_VARIABLE_END, TOKEN_STRING, Lexer +from jinja2.nativetypes import NativeCodeGenerator +from jinja2.nodes import Const, EvalContext +from jinja2.runtime import Context +from jinja2.sandbox import ImmutableSandboxedEnvironment +from jinja2.utils import missing, LRUCache + +from ansible.utils.display import Display +from ansible.errors import AnsibleVariableTypeError, AnsibleTemplateSyntaxError, AnsibleTemplateError +from ansible.module_utils.common.text.converters import to_text +from ansible.module_utils._internal._datatag import ( + _AnsibleTaggedDict, + _AnsibleTaggedList, + _AnsibleTaggedTuple, + _AnsibleTaggedStr, + AnsibleTagHelper, +) + +from ansible._internal._errors._handler import ErrorAction +from ansible._internal._datatag._tags import Origin, TrustedAsTemplate + +from ._access import AnsibleAccessContext +from ._datatag import _JinjaConstTemplate +from ._utils import LazyOptions +from ._jinja_common import ( + MarkerError, + Marker, + CapturedExceptionMarker, + UndefinedMarker, + _TemplateConfig, + TruncationMarker, + validate_arg_type, + JinjaCallContext, +) +from ._jinja_plugins import JinjaPluginIntercept, _query, _lookup, _now, _wrap_plugin_output, get_first_marker_arg, _DirectCall, _jinja_const_template_warning +from ._lazy_containers import ( + _AnsibleLazyTemplateMixin, + _AnsibleLazyTemplateDict, + _AnsibleLazyTemplateList, + _AnsibleLazyAccessTuple, + lazify_container_args, + lazify_container_kwargs, + lazify_container, + register_known_types, +) +from ._utils import Omit, TemplateContext, PASS_THROUGH_SCALAR_VAR_TYPES + +from ansible.module_utils._internal._json._profiles import _json_subclassable_scalar_types +from ansible.module_utils import _internal +from ansible.module_utils._internal import _ambient_context, _dataclass_validation +from ansible.plugins.loader import filter_loader, test_loader +from ansible.vars.hostvars import HostVars, HostVarsVars +from ...module_utils.datatag import native_type_name + +JINJA2_OVERRIDE = '#jinja2:' + +display = Display() + + +@dataclasses.dataclass(kw_only=True, slots=True, frozen=True) +class TemplateOverrides: + DEFAULT: t.ClassVar[t.Self] + + block_start_string: str = defaults.BLOCK_START_STRING + block_end_string: str = defaults.BLOCK_END_STRING + variable_start_string: str = defaults.VARIABLE_START_STRING + variable_end_string: str = defaults.VARIABLE_END_STRING + comment_start_string: str = defaults.COMMENT_START_STRING + comment_end_string: str = defaults.COMMENT_END_STRING + line_statement_prefix: str | None = defaults.LINE_STATEMENT_PREFIX + line_comment_prefix: str | None = defaults.LINE_COMMENT_PREFIX + trim_blocks: bool = True 
# AnsibleEnvironment overrides this default, so don't use the Jinja default here + lstrip_blocks: bool = defaults.LSTRIP_BLOCKS + newline_sequence: t.Literal['\n', '\r\n', '\r'] = defaults.NEWLINE_SEQUENCE + keep_trailing_newline: bool = defaults.KEEP_TRAILING_NEWLINE + + def __post_init__(self) -> None: + pass # overridden by _dataclass_validation._inject_post_init_validation + + def _post_validate(self) -> None: + if not (self.block_start_string != self.variable_start_string != self.comment_start_string != self.block_start_string): + raise ValueError('Block, variable and comment start strings must be different.') + + def overlay_kwargs(self) -> dict[str, t.Any]: + """ + Return a dictionary of arguments for passing to Environment.overlay. + The dictionary will be empty if all fields have their default value. + """ + # DTFIX-FUTURE: calculate default/non-default during __post_init__ + fields = [(field, getattr(self, field.name)) for field in dataclasses.fields(self)] + kwargs = {field.name: value for field, value in fields if value != field.default} + + return kwargs + + def _contains_start_string(self, value: str) -> bool: + """Returns True if the given value contains a variable, block or comment start string.""" + # DTFIX-FUTURE: this is inefficient, use a compiled regex instead + + for marker in (self.block_start_string, self.variable_start_string, self.comment_start_string): + if marker in value: + return True + + return False + + def _starts_and_ends_with_jinja_delimiters(self, value: str) -> bool: + """Returns True if the given value starts and ends with Jinja variable, block or comment delimiters.""" + # DTFIX-FUTURE: this is inefficient, use a compiled regex instead + + for marker in (self.block_start_string, self.variable_start_string, self.comment_start_string): + if value.startswith(marker): + break + else: + return False + + for marker in (self.block_end_string, self.variable_end_string, self.comment_end_string): + if value.endswith(marker): + return True + + return False + + def _extract_template_overrides(self, template: str) -> tuple[str, TemplateOverrides]: + if template.startswith(JINJA2_OVERRIDE): + eol = template.find('\n') + + if eol == -1: + raise ValueError(f"Missing newline after {JINJA2_OVERRIDE!r} override.") + + line = template[len(JINJA2_OVERRIDE) : eol] + template = template[eol + 1 :] + override_kwargs = {} + + for pair in line.split(','): + if not pair.strip(): + raise ValueError(f"Empty {JINJA2_OVERRIDE!r} override pair not allowed.") + + if ':' not in pair: + raise ValueError(f"Missing key-value separator `:` in {JINJA2_OVERRIDE!r} override pair {pair!r}.") + + key, val = pair.split(':', 1) + key = key.strip() + + if key not in _TEMPLATE_OVERRIDE_FIELD_NAMES: + raise ValueError(f"Invalid {JINJA2_OVERRIDE!r} override key {key!r}.") + + override_kwargs[key] = ast.literal_eval(val) + + overrides = dataclasses.replace(self, **override_kwargs) + else: + overrides = self + + return template, overrides + + def merge(self, kwargs: dict[str, t.Any] | None, /) -> TemplateOverrides: + """Return a new instance based on the current instance with the given kwargs overridden.""" + if kwargs: + return self.from_kwargs(dataclasses.asdict(self) | kwargs) + + return self + + @classmethod + def from_kwargs(cls, kwargs: dict[str, t.Any] | None, /) -> TemplateOverrides: + """TemplateOverrides instance factory; instances resolving to all default values will instead return the DEFAULT singleton for optimization.""" + if kwargs: + value = cls(**kwargs) + + if value.overlay_kwargs(): 
+ return value + + return cls.DEFAULT + + +_dataclass_validation.inject_post_init_validation(TemplateOverrides, allow_subclasses=True) + +TemplateOverrides.DEFAULT = TemplateOverrides() + +_TEMPLATE_OVERRIDE_FIELD_NAMES: t.Final[tuple[str, ...]] = tuple(sorted(field.name for field in dataclasses.fields(TemplateOverrides))) + + +class AnsibleContext(Context): + """ + A custom context which intercepts resolve_or_missing() calls and + runs them through AnsibleAccessContext. This allows usage of variables + to be tracked. If needed, values can also be modified before being returned. + """ + + environment: AnsibleEnvironment # narrow the type specified by the base + + def __init__(self, *args, **kwargs): + super(AnsibleContext, self).__init__(*args, **kwargs) + + __repr__ = object.__repr__ # prevent Jinja from dumping vars in case this gets repr'd + + def get_all(self): + """ + Override Jinja's default get_all to return all vars in the context as a ChainMap with a mutable layer at the bottom. + This provides some isolation against accidental changes to inherited variable contexts without requiring copies. + """ + layers = [] + + if self.vars: + layers.append(self.vars) + if self.parent: + layers.append(self.parent) + + # HACK: always include a sacrificial plain-dict on the bottom layer, since Jinja's debug and stacktrace rewrite code invokes + # `__setitem__` outside a call context; this will ensure that it always occurs on a plain dict instead of a lazy one. + return ChainMap({}, *layers) + + # noinspection PyShadowingBuiltins + def derived(self, locals: t.Optional[t.Dict[str, t.Any]] = None) -> Context: + # this is a clone of Jinja's impl of derived, but using our lazy-aware _new_context + + context = _new_context( + environment=self.environment, + template_name=self.name, + blocks={}, + shared=True, + jinja_locals=locals, + jinja_vars=self.get_all(), + ) + context.eval_ctx = self.eval_ctx + context.blocks.update((k, list(v)) for k, v in self.blocks.items()) + return context + + def keys(self, *args, **kwargs): + """Base Context delegates to `dict.keys` against `get_all`, which would fail since we return a ChainMap. No known usage.""" + raise NotImplementedError() + + def values(self, *args, **kwargs): + """Base Context delegates to `dict.values` against `get_all`, which would fail since we return a ChainMap. No known usage.""" + raise NotImplementedError() + + def items(self, *args, **kwargs): + """Base Context delegates to built-in `dict.items` against `get_all`, which would fail since we return a ChainMap. No known usage.""" + raise NotImplementedError() + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class ArgSmuggler: + """ + Utility wrapper to wrap/unwrap args passed to Jinja `Template.render` and `TemplateExpression.__call__`. + e.g., see https://github.com/pallets/jinja/blob/3.1.3/src/jinja2/environment.py#L1296 and + https://github.com/pallets/jinja/blob/3.1.3/src/jinja2/environment.py#L1566. 
+ """ + + jinja_vars: c.Mapping[str, t.Any] | None + + @classmethod + def package_jinja_vars(cls, jinja_vars: c.Mapping[str, t.Any]) -> dict[str, ArgSmuggler]: + """Wrap the supplied vars dict in an ArgSmuggler to prevent premature templating from Jinja's internal dict copy.""" + return dict(_smuggled_vars=ArgSmuggler(jinja_vars=jinja_vars)) + + @classmethod + def extract_jinja_vars(cls, maybe_smuggled_vars: c.Mapping[str, t.Any] | None) -> c.Mapping[str, t.Any]: + """ + If the supplied vars dict contains an ArgSmuggler instance with the expected key, unwrap it and return the smuggled value. + Otherwise, return the supplied dict as-is. + """ + if maybe_smuggled_vars and ((smuggler := maybe_smuggled_vars.get('_smuggled_vars')) and isinstance(smuggler, ArgSmuggler)): + return smuggler.jinja_vars + + return maybe_smuggled_vars + + +class AnsibleTemplateExpression: + """ + Wrapper around Jinja's TemplateExpression for converting MarkerError back into Marker. + This is needed to make expression error handling consistent with templates, since Jinja does not support a custom type for Environment.compile_expression. + """ + + def __init__(self, template_expression: TemplateExpression) -> None: + self._template_expression = template_expression + + def __call__(self, jinja_vars: c.Mapping[str, t.Any]) -> t.Any: + try: + return self._template_expression(ArgSmuggler.package_jinja_vars(jinja_vars)) + except MarkerError as ex: + return ex.source + + +class AnsibleTemplate(Template): + """ + A helper class, which prevents Jinja2 from running lazy containers through dict(). + """ + + _python_source_temp_path: pathlib.Path | None = None + + def __del__(self): + # DTFIX-RELEASE: this still isn't working reliably; something else must be keeping the template object alive + if self._python_source_temp_path: + self._python_source_temp_path.unlink(missing_ok=True) + + def __call__(self, jinja_vars: c.Mapping[str, t.Any]) -> t.Any: + return self.render(ArgSmuggler.package_jinja_vars(jinja_vars)) + + # noinspection PyShadowingBuiltins + def new_context( + self, + vars: c.Mapping[str, t.Any] | None = None, + shared: bool = False, + locals: c.Mapping[str, t.Any] | None = None, + ) -> Context: + return _new_context( + environment=self.environment, + template_name=self.name, + blocks=self.blocks, + shared=shared, + jinja_locals=locals, + jinja_vars=ArgSmuggler.extract_jinja_vars(vars), + jinja_globals=self.globals, + ) + + +class AnsibleCodeGenerator(NativeCodeGenerator): + """ + Custom code generation behavior to support deprecated Ansible features and fill in gaps in Jinja native. + This can be removed once the deprecated Ansible features are removed and the native fixes are upstreamed in Jinja. + """ + + def _output_const_repr(self, group: t.Iterable[t.Any]) -> str: + """ + Prevent Jinja's code generation from stringifying single nodes before generating its repr. + This complements the behavioral change in AnsibleEnvironment.concat which returns single nodes without stringifying them. + """ + # DTFIX-FUTURE: contribute this upstream as a fix to Jinja's native support + group_list = list(group) + + if len(group_list) == 1: + return repr(group_list[0]) + + # NB: This is slightly more efficient than Jinja's _output_const_repr, which generates a throw-away list instance to pass to join. + # Before removing this, ensure that upstream Jinja has this change. 
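# --- editorial aside: illustrative sketch, not part of this patch ---
# The single-node special case above matters for native types: repr of the raw
# value round-trips it through the compiled template code with its type intact,
# while joining first would stringify it.
group = [42]
assert repr(group[0]) == '42'                    # compiles back to the int 42
assert repr(''.join(map(str, group))) == "'42'"  # compiles back to the str '42'
# --- end aside ---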
+ return repr("".join(map(str, group_list))) + + def visit_Const(self, node: Const, frame: Frame) -> None: + """ + Override Jinja's visit_Const to inject a runtime call to AnsibleEnvironment._access_const for constant strings that are possibly templates, which + may require special handling at runtime. See that method for details. An example that hits this path: + {{ lookup("file", "{{ output_dir }}/bla") }} + """ + value = node.as_const(frame.eval_ctx) + + if _TemplateConfig.allow_embedded_templates and type(value) is str and is_possibly_template(value): # pylint: disable=unidiomatic-typecheck + # deprecated: description='embedded Jinja constant string template support' core_version='2.23' + self.write(f'environment._access_const({value!r})') + else: + # NB: This is actually more efficient than Jinja's visit_Const, which contains obsolete (as of Py2.7/3.1) float conversion instance checks. Before + # removing this override entirely, ensure that upstream Jinja has removed the obsolete code. + # See https://docs.python.org/release/2.7/whatsnew/2.7.html#python-3-1-features for more details. + self.write(repr(value)) + + +@pass_context +def _ansible_finalize(_ctx: AnsibleContext, value: t.Any) -> t.Any: + """ + This function is called by Jinja with the result of each variable template block (e.g., {{ }}) encountered in a template. + The pass_context decorator prevents finalize from being called on constants at template compile time. + The passed in AnsibleContext is unused -- it is the result of using the pass_context decorator. + The important part for us is that this blocks constant folding, which ensures our custom visit_Const is used. + It also ensures that template results are wrapped in lazy containers. + """ + return lazify_container(value) + + +@dataclasses.dataclass(kw_only=True, slots=True) +class _TemplateCompileContext(_ambient_context.AmbientContextBase): + """ + This context is active during Ansible's explicit compilation of templates/expressions, but not during Jinja's runtime compilation. + Historically, Ansible-specific pre-processing like `escape_backslashes` was not applied to imported/included templates. + """ + + escape_backslashes: bool + + +class _CompileStateSmugglingCtx(_ambient_context.AmbientContextBase): + template_source: str | None = None + python_source: str | None = None + python_source_temp_path: pathlib.Path | None = None + + +class AnsibleLexer(Lexer): + """ + Lexer override to escape backslashes in string constants within Jinja expressions; prevents Jinja from double-escaping them. + + NOTE: This behavior is only applied to string constants within Jinja expressions (eg {{ "c:\newfile" }}), *not* statements ("{% set foo="c:\\newfile" %}"). + + This is useful when templates are sourced from YAML double-quoted strings, as it avoids having backslashes processed twice: first by the + YAML parser, and then again by the Jinja parser. Instead, backslashes are only processed by YAML. + + Example YAML: + + - debug: + msg: "Test Case 1\\3; {{ test1_name | regex_replace('^(.*)_name$', '\\1')}}" + + Since the outermost YAML string is double-quoted, the YAML parser converts the double backslashes to single backslashes. Without escaping, Jinja + would see only a single backslash ('\1') while processing the embedded template expression, interpret it as an escape sequence, and convert it + to '\x01' (ASCII "SOH"). 
This is clearly not the intended `\1` backreference argument to the `regex_replace` filter (which would require the + double-escaped string '\\\\1' to yield the intended result). + + Since the "\\3" in the input YAML was not part of a template expression, the YAML-parsed "\3" remains after Jinja rendering. This would be + confusing for playbook authors, as different escaping rules would be needed inside and outside the template expression. + + When templates are not sourced from YAML, escaping backslashes will prevent use of backslash escape sequences such as "\n" and "\t". + + See relevant Jinja lexer impl at e.g.: https://github.com/pallets/jinja/blob/3.1.2/src/jinja2/lexer.py#L646-L653. + """ + + def tokeniter(self, *args, **kwargs) -> t.Iterator[t.Tuple[int, str, str]]: + """Pre-escape backslashes in expression ({{ }}) raw string constants before Jinja's Lexer.wrap() can interpret them as ASCII escape sequences.""" + token_stream = super().tokeniter(*args, **kwargs) + + # if we have no context, Jinja's doing a nested compile at runtime (eg, import/include); historically, no backslash escaping is performed + if not (tcc := _TemplateCompileContext.current(optional=True)) or not tcc.escape_backslashes: + yield from token_stream + return + + in_variable = False + + for token in token_stream: + token_type = token[1] + + if token_type == TOKEN_VARIABLE_BEGIN: + in_variable = True + elif token_type == TOKEN_VARIABLE_END: + in_variable = False + elif in_variable and token_type == TOKEN_STRING: + token = token[0], token_type, token[2].replace('\\', '\\\\') + + yield token + + +def defer_template_error(ex: Exception, variable: t.Any, *, is_expression: bool) -> Marker: + if not ex.__traceback__: + raise AssertionError('ex must be a previously raised exception') + + if isinstance(ex, MarkerError): + return ex.source + + exception_to_raise = create_template_error(ex, variable, is_expression) + + if exception_to_raise is ex: + return CapturedExceptionMarker(ex) # capture the previously raised exception + + try: + raise exception_to_raise from ex # raise the newly synthesized exception before capturing it + except Exception as captured_ex: + return CapturedExceptionMarker(captured_ex) + + +def create_template_error(ex: Exception, variable: t.Any, is_expression: bool) -> AnsibleTemplateError: + if isinstance(ex, AnsibleTemplateError): + exception_to_raise = ex + else: + kind = "expression" if is_expression else "template" + ex_type = AnsibleTemplateError # always raise an AnsibleTemplateError/subclass + + if isinstance(ex, RecursionError): + msg = f"Recursive loop detected in {kind}." + elif isinstance(ex, TemplateSyntaxError): + msg = f"Syntax error in {kind}." + + if is_expression and is_possibly_template(variable): + msg += " Template delimiters are not supported in expressions." + + ex_type = AnsibleTemplateSyntaxError + else: + msg = f"Error rendering {kind}." + + exception_to_raise = ex_type(msg, obj=variable) + + if exception_to_raise.obj is None: + exception_to_raise.obj = TemplateContext.current().template_value + + # DTFIX-FUTURE: Look through the TemplateContext hierarchy to find the most recent non-template + # caller and use that for origin when no origin is available on obj. This could be useful for situations where the template + # was embedded in a plugin, or a plugin is otherwise responsible for losing the origin and/or trust. We can't just use the first + # non-template caller as that will lead to false positives for re-entrant calls (e.g. template plugins that call into templar). 
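# --- editorial aside: illustrative sketch, not part of this patch ---
# Recapping the AnsibleLexer.tokeniter pre-escaping above: backslashes in
# expression string constants are doubled before Jinja's lexer can treat them as
# escape sequences, so a YAML-produced '\1' survives as backslash-one, not SOH.
yaml_parsed = '\\1'                              # YAML already turned "\\1" into \1
pre_escaped = yaml_parsed.replace('\\', '\\\\')  # what the lexer override emits
assert pre_escaped == '\\\\1'
assert pre_escaped.encode().decode('unicode_escape') == '\\1'  # Jinja-style unescape restores \1
# --- end aside ---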
+ + return exception_to_raise + + +# DTFIX-RELEASE: implement CapturedExceptionMarker deferral support on call (and lookup), filter/test plugins, etc. +# also update the protomatter integration test once this is done (the test was written differently since this wasn't done yet) + +_BUILTIN_FILTER_ALIASES: dict[str, str] = {} +_BUILTIN_TEST_ALIASES: dict[str, str] = { + '!=': 'ne', + '<': 'lt', + '<=': 'le', + '==': 'eq', + '>': 'gt', + '>=': 'ge', +} + +_BUILTIN_FILTERS = filter_loader._wrap_funcs(defaults.DEFAULT_FILTERS, _BUILTIN_FILTER_ALIASES) +_BUILTIN_TESTS = test_loader._wrap_funcs(t.cast(dict[str, t.Callable], defaults.DEFAULT_TESTS), _BUILTIN_TEST_ALIASES) + + +class AnsibleEnvironment(ImmutableSandboxedEnvironment): + """ + Our custom environment, which simply allows us to override the class-level + values for the Template and Context classes used by jinja2 internally. + """ + + context_class = AnsibleContext + template_class = AnsibleTemplate + code_generator_class = AnsibleCodeGenerator + intercepted_binops = frozenset(('eq',)) + + _lexer_cache = LRUCache(50) + + # DTFIX-FUTURE: bikeshed a name/mechanism to control template debugging + _debuggable_template_source = False + _debuggable_template_source_path: pathlib.Path = pathlib.Path(__file__).parent.parent.parent.parent / '.template_debug_source' + + def __init__(self, *args, ansible_basedir: str | None = None, **kwargs) -> None: + if ansible_basedir: + kwargs.update(loader=FileSystemLoader(ansible_basedir)) + + super().__init__(*args, extensions=_TemplateConfig.jinja_extensions, **kwargs) + + self.filters = JinjaPluginIntercept(_BUILTIN_FILTERS, filter_loader) # type: ignore[assignment] + self.tests = JinjaPluginIntercept(_BUILTIN_TESTS, test_loader) # type: ignore[assignment,arg-type] + + # future Jinja releases may default-enable autoescape; force-disable to prevent the problems it could cause + # see https://github.com/pallets/jinja/blob/3.1.2/docs/api.rst?plain=1#L69 + self.autoescape = False + + self.trim_blocks = True + + self.undefined = UndefinedMarker + self.finalize = _ansible_finalize + + self.globals.update( + range=range, # the sandboxed environment limits range in ways that may cause us problems; use the real Python one + now=_now, + undef=_undef, + omit=Omit, + lookup=_lookup, + query=_query, + q=_query, + ) + + # Disabling the optimizer prevents compile-time constant expression folding, which prevents our + # visit_Const recursive inline template expansion tricks from working in many cases where Jinja's + # ignorance of our embedded templates are optimized away as fully-constant expressions, + # eg {{ "{{'hi'}}" == "hi" }}. As of Jinja ~3.1, this specifically avoids cases where the @optimizeconst + # visitor decorator performs constant folding, which bypasses our visit_Const impl and causes embedded + # templates to be lost. 
+ # See also optimizeconst impl: https://github.com/pallets/jinja/blob/3.1.0/src/jinja2/compiler.py#L48-L49 + self.optimized = False + + def get_template( + self, + name: str | Template, + parent: str | None = None, + globals: c.MutableMapping[str, t.Any] | None = None, + ) -> Template: + """Ensures that templates built via `get_template` are also source debuggable.""" + with _CompileStateSmugglingCtx.when(self._debuggable_template_source) as ctx: + template_obj = t.cast(AnsibleTemplate, super().get_template(name, parent, globals)) + + if isinstance(ctx, _CompileStateSmugglingCtx): # only present if debugging is enabled + template_obj._python_source_temp_path = ctx.python_source_temp_path # facilitate deletion of the temp file when template_obj is deleted + + return template_obj + + @property + def lexer(self) -> AnsibleLexer: + """Return/cache an AnsibleLexer with settings from the current AnsibleEnvironment""" + # DTFIX-RELEASE: optimization - we should pre-generate the default cached lexer before forking, not leave it to chance (e.g. simple playbooks) + key = tuple(getattr(self, name) for name in _TEMPLATE_OVERRIDE_FIELD_NAMES) + + lex = self._lexer_cache.get(key) + + if lex is None: + self._lexer_cache[key] = lex = AnsibleLexer(self) + + return lex + + def call_filter( + self, + name: str, + value: t.Any, + args: c.Sequence[t.Any] | None = None, + kwargs: c.Mapping[str, t.Any] | None = None, + context: Context | None = None, + eval_ctx: EvalContext | None = None, + ) -> t.Any: + """ + Ensure that filters directly invoked by plugins will see non-templating lazy containers. + Without this, `_wrap_filter` will wrap `args` and `kwargs` in templating lazy containers. + This provides consistency with plugin output handling by preventing auto-templating of trusted templates passed in native containers. + """ + # DTFIX-RELEASE: need better logic to handle non-list/non-dict inputs for args/kwargs + args = _AnsibleLazyTemplateMixin._try_create(list(args or []), LazyOptions.SKIP_TEMPLATES) + kwargs = _AnsibleLazyTemplateMixin._try_create(kwargs, LazyOptions.SKIP_TEMPLATES) + + return super().call_filter(name, value, args, kwargs, context, eval_ctx) + + def call_test( + self, + name: str, + value: t.Any, + args: c.Sequence[t.Any] | None = None, + kwargs: c.Mapping[str, t.Any] | None = None, + context: Context | None = None, + eval_ctx: EvalContext | None = None, + ) -> t.Any: + """ + Ensure that tests directly invoked by plugins will see non-templating lazy containers. + Without this, `_wrap_test` will wrap `args` and `kwargs` in templating lazy containers. + This provides consistency with plugin output handling by preventing auto-templating of trusted templates passed in native containers. 
+ """ + # DTFIX-RELEASE: need better logic to handle non-list/non-dict inputs for args/kwargs + args = _AnsibleLazyTemplateMixin._try_create(list(args or []), LazyOptions.SKIP_TEMPLATES) + kwargs = _AnsibleLazyTemplateMixin._try_create(kwargs, LazyOptions.SKIP_TEMPLATES) + + return super().call_test(name, value, args, kwargs, context, eval_ctx) + + def compile_expression(self, source: str, *args, **kwargs) -> TemplateExpression: + # compile_expression parses and passes the tree to from_string; for debug support, activate the context here to capture the intermediate results + with _CompileStateSmugglingCtx.when(self._debuggable_template_source) as ctx: + if isinstance(ctx, _CompileStateSmugglingCtx): # only present if debugging is enabled + ctx.template_source = source + + return super().compile_expression(source, *args, **kwargs) + + def from_string(self, source: str | jinja2.nodes.Template, *args, **kwargs) -> AnsibleTemplate: + # if debugging is enabled, use existing context when present (e.g., from compile_expression) + current_ctx = _CompileStateSmugglingCtx.current(optional=True) if self._debuggable_template_source else None + + with _CompileStateSmugglingCtx.when(self._debuggable_template_source and not current_ctx) as new_ctx: + template_obj = t.cast(AnsibleTemplate, super().from_string(source, *args, **kwargs)) + + if isinstance(ctx := current_ctx or new_ctx, _CompileStateSmugglingCtx): # only present if debugging is enabled + template_obj._python_source_temp_path = ctx.python_source_temp_path # facilitate deletion of the temp file when template_obj is deleted + + return template_obj + + def _parse(self, source: str, *args, **kwargs) -> jinja2.nodes.Template: + if csc := _CompileStateSmugglingCtx.current(optional=True): + csc.template_source = source + + return super()._parse(source, *args, **kwargs) + + def _compile(self, source: str, filename: str) -> types.CodeType: + if csc := _CompileStateSmugglingCtx.current(optional=True): + origin = Origin.get_tag(csc.template_source) or Origin.UNKNOWN + + source = '\n'.join( + ( + "import sys; breakpoint() if type(sys.breakpointhook) is not type(breakpoint) else None", + f"# original template source from {str(origin)!r}: ", + '\n'.join(f'# {line}' for line in (csc.template_source or '').splitlines()), + source, + ) + ) + + source_temp_dir = self._debuggable_template_source_path + source_temp_dir.mkdir(parents=True, exist_ok=True) + + with tempfile.NamedTemporaryFile(dir=source_temp_dir, mode='w', suffix='.py', prefix='j2_src_', delete=False) as source_file: + filename = source_file.name + + source_file.write(source) + source_file.flush() + + csc.python_source = source + csc.python_source_temp_path = pathlib.Path(filename) + + res = super()._compile(source, filename) + + return res + + @staticmethod + def concat(nodes: t.Iterable[t.Any]) -> t.Any: # type: ignore[override] + node_list = list(_flatten_nodes(nodes)) + + if not node_list: + return None + + # this code is complemented by our tweaked CodeGenerator _output_const_repr that ensures that literal constants + # in templates aren't double-repr'd in the generated code + if len(node_list) == 1: + # DTFIX-RELEASE: determine if we should do managed access here (we *should* have hit them all during templating/resolve, but ?) + return node_list[0] + + # In order to ensure that all markers are tripped, do a recursive finalize before we repr (otherwise we can end up + # repr'ing a Marker). This requires two passes, but avoids the need for a parallel reimplementation of all repr methods. 
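# --- editorial aside: illustrative sketch, not part of this patch ---
# Native concat semantics in miniature: a template that is a single expression
# yields the object itself; multiple nodes degrade to a string join.
def mini_concat(nodes):
    nodes = list(nodes)
    if not nodes:
        return None
    if len(nodes) == 1:
        return nodes[0]  # preserve the native type
    return ''.join(map(str, nodes))

assert mini_concat([{'a': 1}]) == {'a': 1}
assert mini_concat(['a=', 1]) == 'a=1'
# --- end aside ---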
+ try: + node_list = _finalize_template_result(node_list, FinalizeMode.CONCAT) + except MarkerError as ex: + return ex.source # return the first Marker encountered + + return ''.join([to_text(v) for v in node_list]) + + @staticmethod + def _access_const(const_template: t.LiteralString) -> t.Any: + """ + Called during template rendering on template-looking string constants embedded in the template. + It provides the following functionality: + * Propagates origin from the containing template. + * For backward compatibility when embedded templates are enabled: + * Conditionals - Renders embedded template constants and accesses the result. Warns on each constant immediately. + * Non-conditionals - Tags constants for deferred rendering of templates in lookup terms. Warns on each constant during lookup invocation. + """ + ctx = TemplateContext.current() + + if (tv := ctx.template_value) and (origin := Origin.get_tag(tv)): + const_template = origin.tag(const_template) + + if ctx._render_jinja_const_template: + _jinja_const_template_warning(const_template, is_conditional=True) + + result = ctx.templar.template(TrustedAsTemplate().tag(const_template)) + AnsibleAccessContext.current().access(result) + else: + # warnings will be issued when lookup terms processing occurs, to avoid false positives + result = _JinjaConstTemplate().tag(const_template) + + return result + + def getitem(self, obj: t.Any, argument: t.Any) -> t.Any: + value = super().getitem(obj, argument) + + AnsibleAccessContext.current().access(value) + + return value + + def getattr(self, obj: t.Any, attribute: str) -> t.Any: + """ + Get `attribute` from the attributes of `obj`, falling back to items in `obj`. + If no item was found, return a sandbox-specific `UndefinedMarker` if `attribute` is protected by the sandbox, + otherwise return a normal `UndefinedMarker` instance. + This differs from the built-in Jinja behavior which will not fall back to items if `attribute` is protected by the sandbox. + """ + # example template that uses this: "{{ some.thing }}" -- obj is the "some" dict, attribute is "thing" + + is_safe = True + + try: + value = getattr(obj, attribute) + except AttributeError: + value = _sentinel + else: + if not (is_safe := self.is_safe_attribute(obj, attribute, value)): + value = _sentinel + + if value is _sentinel: + try: + value = obj[attribute] + except (TypeError, LookupError): + return self.undefined(obj=obj, name=attribute) if is_safe else self.unsafe_undefined(obj, attribute) + + AnsibleAccessContext.current().access(value) + + return value + + def call( + self, + __context: Context, + __obj: t.Any, + *args: t.Any, + **kwargs: t.Any, + ) -> t.Any: + if _DirectCall.is_marked(__obj): + # Both `_lookup` and `_query` handle arg proxying and `Marker` args internally. + # Performing either before calling them will interfere with that processing. + return super().call(__context, __obj, *args, **kwargs) + + if (first_marker := get_first_marker_arg(args, kwargs)) is not None: + return first_marker + + try: + with JinjaCallContext(accept_lazy_markers=False): + call_res = super().call(__context, __obj, *lazify_container_args(args), **lazify_container_kwargs(kwargs)) + + if __obj is range: + # Preserve the ability to do `range(1000000000) | random` by not converting range objects to lists. + # Historically, range objects were only converted on Jinja finalize and filter outputs, so they've always been floating around in templating + # code and visible to user plugins. 
+ return call_res
+
+ return _wrap_plugin_output(call_res)
+
+ except MarkerError as ex:
+ return ex.source
+
+
+AnsibleTemplate.environment_class = AnsibleEnvironment
+
+_DEFAULT_UNDEF = UndefinedMarker("Mandatory variable has not been overridden", _no_template_source=True)
+
+_sentinel: t.Final[object] = object()
+
+
+@_DirectCall.mark
+def _undef(hint=None):
+ """Jinja2 global function (undef) for creating an `UndefinedMarker` instance, optionally with a custom hint."""
+ validate_arg_type('hint', hint, (str, type(None)))
+
+ if not hint:
+ return _DEFAULT_UNDEF
+
+ return UndefinedMarker(hint)
+
+
+def _flatten_nodes(nodes: t.Iterable[t.Any]) -> t.Iterable[t.Any]:
+ """
+ Yield nodes from a potentially recursive iterable of nodes.
+ The recursion is required to expand template imports (TemplateModule).
+ Any exception raised while consuming a template node will be yielded as a Marker for that node.
+ """
+ iterator = iter(nodes)
+
+ while True:
+ try:
+ node = next(iterator)
+ except StopIteration:
+ break
+ except Exception as ex:
+ yield defer_template_error(ex, TemplateContext.current().template_value, is_expression=False)
+ # DTFIX-FUTURE: We should be able to determine if truncation occurred by having the code generator smuggle out the number of expected nodes.
+ yield TruncationMarker()
+ else:
+ if type(node) is TemplateModule: # pylint: disable=unidiomatic-typecheck
+ yield from _flatten_nodes(node._body_stream)
+ else:
+ yield node
+
+
+def _flatten_and_lazify_vars(mapping: c.Mapping) -> t.Iterable[c.Mapping]:
+ """Prevent deeply-nested Jinja vars ChainMaps from being created by nested contexts and ensure that all top-level containers support lazy templating."""
+ mapping_type = type(mapping)
+ if mapping_type is ChainMap:
+ # noinspection PyUnresolvedReferences
+ for m in mapping.maps:
+ yield from _flatten_and_lazify_vars(m)
+ elif mapping_type is _AnsibleLazyTemplateDict:
+ if not mapping:
+ # DTFIX-RELEASE: handle or remove?
+ raise Exception("we didn't think it was possible to have an empty lazy here...")
+ yield mapping
+ elif mapping_type in (dict, _AnsibleTaggedDict):
+ # don't propagate empty dictionary layers
+ if mapping:
+ yield _AnsibleLazyTemplateMixin._try_create(mapping)
+ else:
+ raise NotImplementedError(f"unsupported mapping type in Jinja vars: {mapping_type}")
+
+
+def _new_context(
+ *,
+ environment: Environment,
+ template_name: str | None,
+ blocks: dict[str, t.Callable[[Context], c.Iterator[str]]],
+ shared: bool = False,
+ jinja_locals: c.Mapping[str, t.Any] | None = None,
+ jinja_vars: c.Mapping[str, t.Any] | None = None,
+ jinja_globals: c.MutableMapping[str, t.Any] | None = None,
+) -> Context:
+ """Override Jinja's context vars setup to use ChainMaps and containers that support lazy templating."""
+ layers = []
+
+ if jinja_locals:
+ # DTFIX-RELEASE: if we can't trip this in coverage, kill it off?
+ if type(jinja_locals) is not dict: # pylint: disable=unidiomatic-typecheck
+ raise NotImplementedError("locals must be a dict")
+
+ # Omit values set to Jinja's internal `missing` sentinel; they are locals that have not yet been
+ # initialized in the current context, and should not be exposed to child contexts. e.g.: {% import 'a' as b with context %}.
+ # The `b` local will be `missing` in the `a` context and should not be propagated as a local to the child context we're creating.
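# --- editorial aside: illustrative sketch, not part of this patch ---
# The dict comprehension below in miniature: locals still bound to Jinja's
# `missing` sentinel are dropped so they never leak into child contexts.
# (The `l_0_*` names mimic Jinja's generated local naming; illustrative only.)
from jinja2.utils import missing

jinja_locals = {'l_0_b': missing, 'l_0_x': 1}
propagated = {k: v for k, v in jinja_locals.items() if v is not missing}
assert propagated == {'l_0_x': 1}
# --- end aside ---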
+ layers.append(_AnsibleLazyTemplateMixin._try_create({k: v for k, v in jinja_locals.items() if v is not missing})) + + if jinja_vars: + layers.extend(_flatten_and_lazify_vars(jinja_vars)) + + if jinja_globals and not shared: + # Even though we don't currently support templating globals, it's easier to ensure that everything is template-able rather than trying to + # pick apart the ChainMaps to enforce non-template-able globals, or to risk things that *should* be template-able not being lazified. + layers.extend(_flatten_and_lazify_vars(jinja_globals)) + + if not layers: + # ensure we have at least one layer (which should be lazy), since _flatten_and_lazify_vars eliminates most empty layers + layers.append(_AnsibleLazyTemplateMixin._try_create({})) + + # only return a ChainMap if we're combining layers, or we have none + parent = layers[0] if len(layers) == 1 else ChainMap(*layers) + + # the `parent` cast is only to satisfy Jinja's overly-strict type hint + return environment.context_class(environment, t.cast(dict, parent), template_name, blocks, globals=jinja_globals) + + +def is_possibly_template(value: str, overrides: TemplateOverrides = TemplateOverrides.DEFAULT): + """ + A lightweight check to determine if the given string looks like it contains a template, even if that template is invalid. + Returns `True` if the given string starts with a Jinja overrides header or if it contains template start strings. + """ + return value.startswith(JINJA2_OVERRIDE) or overrides._contains_start_string(value) + + +def is_possibly_all_template(value: str, overrides: TemplateOverrides = TemplateOverrides.DEFAULT): + """ + A lightweight check to determine if the given string looks like it contains *only* a template, even if that template is invalid. + Returns `True` if the given string starts with a Jinja overrides header or if it starts and ends with Jinja template delimiters. + """ + return value.startswith(JINJA2_OVERRIDE) or overrides._starts_and_ends_with_jinja_delimiters(value) + + +class FinalizeMode(enum.Enum): + TOP_LEVEL = enum.auto() + CONCAT = enum.auto() + + +_FINALIZE_FAST_PATH_EXACT_MAPPING_TYPES = frozenset( + ( + dict, + _AnsibleTaggedDict, + _AnsibleLazyTemplateDict, + HostVars, + HostVarsVars, + ) +) +"""Fast-path exact mapping types for finalization. These types bypass diagnostic warnings for type conversion.""" + +_FINALIZE_FAST_PATH_EXACT_ITERABLE_TYPES = frozenset( + ( + list, + _AnsibleTaggedList, + _AnsibleLazyTemplateList, + tuple, + _AnsibleTaggedTuple, + _AnsibleLazyAccessTuple, + ) +) +"""Fast-path exact iterable types for finalization. These types bypass diagnostic warnings for type conversion.""" + +_FINALIZE_DISALLOWED_EXACT_TYPES = frozenset((range,)) +"""Exact types that cannot be finalized.""" + +# Jinja passes these into filters/tests via @pass_environment +register_known_types( + AnsibleContext, + AnsibleEnvironment, + EvalContext, +) + + +def _finalize_dict(o: t.Any, mode: FinalizeMode) -> t.Iterator[tuple[t.Any, t.Any]]: + for k, v in o.items(): + if v is not Omit: + yield _finalize_template_result(k, mode), _finalize_template_result(v, mode) + + +def _finalize_list(o: t.Any, mode: FinalizeMode) -> t.Iterator[t.Any]: + for v in o: + if v is not Omit: + yield _finalize_template_result(v, mode) + + +def _maybe_finalize_scalar(o: t.Any) -> t.Any: + # DTFIX-RELEASE: this should check all supported scalar subclasses, not just JSON ones (also, does the JSON serializer handle these cases?) 
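# --- editorial aside: illustrative sketch, not part of this patch ---
# What the scalar finalization loop below accomplishes: subclasses of supported
# scalar types are collapsed back to the plain builtin type before leaving the
# templar. `TrackedStr` is a hypothetical stand-in for such a subclass.
class TrackedStr(str):
    pass

value = TrackedStr('hello')
finalized = str(value) if isinstance(value, str) and type(value) is not str else value
assert type(finalized) is str and finalized == 'hello'
# --- end aside ---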
+ for target_type in _json_subclassable_scalar_types: + if not isinstance(o, target_type): + continue + + match _TemplateConfig.unknown_type_conversion_handler.action: + # we don't want to show the object value, and it can't be Origin-tagged; send the current template value for best effort + case ErrorAction.WARN: + display.warning( + msg=f'Type {native_type_name(o)!r} is unsupported in variable storage, converting to {native_type_name(target_type)!r}.', + obj=TemplateContext.current(optional=True).template_value, + ) + case ErrorAction.FAIL: + raise AnsibleVariableTypeError.from_value(obj=TemplateContext.current(optional=True).template_value) + + return target_type(o) + + return None + + +def _finalize_fallback_collection( + o: t.Any, + mode: FinalizeMode, + finalizer: t.Callable[[t.Any, FinalizeMode], t.Iterator], + target_type: type[list | dict], +) -> t.Collection[t.Any]: + match _TemplateConfig.unknown_type_conversion_handler.action: + # we don't want to show the object value, and it can't be Origin-tagged; send the current template value for best effort + case ErrorAction.WARN: + display.warning( + msg=f'Type {native_type_name(o)!r} is unsupported in variable storage, converting to {native_type_name(target_type)!r}.', + obj=TemplateContext.current(optional=True).template_value, + ) + case ErrorAction.FAIL: + raise AnsibleVariableTypeError.from_value(obj=TemplateContext.current(optional=True).template_value) + + return _finalize_collection(o, mode, finalizer, target_type) + + +def _finalize_collection( + o: t.Any, + mode: FinalizeMode, + finalizer: t.Callable[[t.Any, FinalizeMode], t.Iterator], + target_type: type[list | dict], +) -> t.Collection[t.Any]: + return AnsibleTagHelper.tag(finalizer(o, mode), AnsibleTagHelper.tags(o), value_type=target_type) + + +def _finalize_template_result(o: t.Any, mode: FinalizeMode) -> t.Any: + """Recurse the template result, rendering any encountered templates, converting containers to non-lazy versions.""" + # DTFIX-RELEASE: add tests to ensure this method doesn't drift from allowed types + o_type = type(o) + + # DTFIX-FUTURE: provide an optional way to check for trusted templates leaking out of templating (injected, but not passed through templar.template) + + if o_type is _AnsibleTaggedStr: + return _JinjaConstTemplate.untag(o) # prevent _JinjaConstTemplate from leaking into finalized results + + if o_type in PASS_THROUGH_SCALAR_VAR_TYPES: + return o + + if o_type in _FINALIZE_FAST_PATH_EXACT_MAPPING_TYPES: # silently convert known mapping types to dict + return _finalize_collection(o, mode, _finalize_dict, dict) + + if o_type in _FINALIZE_FAST_PATH_EXACT_ITERABLE_TYPES: # silently convert known sequence types to list + return _finalize_collection(o, mode, _finalize_list, list) + + if o_type in Marker.concrete_subclasses: # this early return assumes handle_marker follows our variable type rules + return TemplateContext.current().templar.marker_behavior.handle_marker(o) + + if mode is not FinalizeMode.TOP_LEVEL: # unsupported type (do not raise) + return o + + if o_type in _FINALIZE_DISALLOWED_EXACT_TYPES: # early abort for disallowed types that would otherwise be handled below + raise AnsibleVariableTypeError.from_value(obj=o) + + if _internal.is_intermediate_mapping(o): # since isinstance checks are slower, this is separate from the exact type check above + return _finalize_fallback_collection(o, mode, _finalize_dict, dict) + + if _internal.is_intermediate_iterable(o): # since isinstance checks are slower, this is separate from the exact type 
check above + return _finalize_fallback_collection(o, mode, _finalize_list, list) + + if (result := _maybe_finalize_scalar(o)) is not None: + return result + + raise AnsibleVariableTypeError.from_value(obj=o) diff --git a/lib/ansible/_internal/_templating/_jinja_common.py b/lib/ansible/_internal/_templating/_jinja_common.py new file mode 100644 index 00000000000..c2b704f8dee --- /dev/null +++ b/lib/ansible/_internal/_templating/_jinja_common.py @@ -0,0 +1,332 @@ +from __future__ import annotations + +import abc +import collections.abc as c +import inspect +import itertools +import typing as t + +from jinja2 import UndefinedError, StrictUndefined, TemplateRuntimeError +from jinja2.utils import missing + +from ansible.module_utils.common.messages import ErrorSummary, Detail +from ansible.constants import config +from ansible.errors import AnsibleUndefinedVariable, AnsibleTypeError +from ansible._internal._errors._handler import ErrorHandler +from ansible.module_utils._internal._datatag import Tripwire, _untaggable_types + +from ._access import NotifiableAccessContextBase +from ._jinja_patches import _patch_jinja +from ._utils import TemplateContext +from .._errors import _captured +from ...module_utils.datatag import native_type_name + +_patch_jinja() # apply Jinja2 patches before types are declared that are dependent on the changes + + +class _TemplateConfig: + allow_embedded_templates: bool = config.get_config_value("ALLOW_EMBEDDED_TEMPLATES") + allow_broken_conditionals: bool = config.get_config_value('ALLOW_BROKEN_CONDITIONALS') + jinja_extensions: list[str] = config.get_config_value('DEFAULT_JINJA2_EXTENSIONS') + + unknown_type_encountered_handler = ErrorHandler.from_config('_TEMPLAR_UNKNOWN_TYPE_ENCOUNTERED') + unknown_type_conversion_handler = ErrorHandler.from_config('_TEMPLAR_UNKNOWN_TYPE_CONVERSION') + untrusted_template_handler = ErrorHandler.from_config('_TEMPLAR_UNTRUSTED_TEMPLATE_BEHAVIOR') + + +class MarkerError(UndefinedError): + """ + An Ansible specific subclass of Jinja's UndefinedError, used to preserve and later restore the original Marker instance that raised the error. + This error is only raised by Marker and should never escape the templating system. + """ + + def __init__(self, message: str, source: Marker) -> None: + super().__init__(message) + + self.source = source + + +class Marker(StrictUndefined, Tripwire): + """ + Extends Jinja's `StrictUndefined`, allowing any kind of error occurring during recursive templating operations to be captured and deferred. + Direct or managed access to most `Marker` attributes will raise a `MarkerError`, which usually ends the current innermost templating + operation and converts the `MarkerError` back to the origin Marker instance (subject to the `MarkerBehavior` in effect at the time). + """ + + __slots__ = ('_marker_template_source',) + + concrete_subclasses: t.ClassVar[set[type[Marker]]] = set() + + def __init__( + self, + hint: t.Optional[str] = None, + obj: t.Any = missing, + name: t.Optional[str] = None, + exc: t.Type[TemplateRuntimeError] = UndefinedError, # Ansible doesn't set this argument or consume the attribute it is stored under. 
+ *args,
+ _no_template_source=False,
+ **kwargs,
+ ) -> None:
+ if not hint and name and obj is not missing:
+ hint = f"object of type {native_type_name(obj)!r} has no attribute {name!r}"
+
+ kwargs.update(
+ hint=hint,
+ obj=obj,
+ name=name,
+ exc=exc,
+ )
+
+ super().__init__(*args, **kwargs)
+
+ if _no_template_source:
+ self._marker_template_source = None
+ else:
+ self._marker_template_source = TemplateContext.current().template_value
+
+ def _as_exception(self) -> Exception:
+ """Return the exception instance to raise in a top-level templating context."""
+ return AnsibleUndefinedVariable(self._undefined_message, obj=self._marker_template_source)
+
+ def _as_message(self) -> str:
+ """Return the error message to show when this marker must be represented as a string, such as for substitutions or warnings."""
+ return self._undefined_message
+
+ def _fail_with_undefined_error(self, *args: t.Any, **kwargs: t.Any) -> t.NoReturn:
+ """Ansible-specific replacement for Jinja's _fail_with_undefined_error tripwire on dunder methods."""
+ self.trip()
+
+ def trip(self) -> t.NoReturn:
+ """Raise an internal exception which can be converted back to this instance."""
+ raise MarkerError(self._undefined_message, self)
+
+ def __setattr__(self, name: str, value: t.Any) -> None:
+ """
+ Any attempt to set an unknown attribute on a `Marker` should invoke the trip method to propagate the original context.
+ This does not protect against mutation of known attributes, but the implementation is fairly simple.
+ """
+ try:
+ super().__setattr__(name, value)
+ except AttributeError:
+ pass
+ else:
+ return
+
+ self.trip()
+
+ def __getattr__(self, name: str) -> t.Any:
+ """Raises AttributeError for dunder-looking accesses, self-propagates otherwise."""
+ if name.startswith('__') and name.endswith('__'):
+ raise AttributeError(name)
+
+ return self
+
+ def __getitem__(self, key):
+ """Self-propagates on all item accesses."""
+ return self
+
+ @classmethod
+ def __init_subclass__(cls, **kwargs) -> None:
+ if not inspect.isabstract(cls):
+ _untaggable_types.add(cls)
+ cls.concrete_subclasses.add(cls)
+
+ @classmethod
+ def _init_class(cls):
+ _untaggable_types.add(cls)
+
+ # These are the methods StrictUndefined already intercepts.
+ jinja_method_names = (
+ '__add__',
+ '__bool__',
+ '__call__',
+ '__complex__',
+ '__contains__',
+ '__div__',
+ '__eq__',
+ '__float__',
+ '__floordiv__',
+ '__ge__',
+ # '__getitem__', # using a custom implementation that propagates self instead
+ '__gt__',
+ '__hash__',
+ '__int__',
+ '__iter__',
+ '__le__',
+ '__len__',
+ '__lt__',
+ '__mod__',
+ '__mul__',
+ '__ne__',
+ '__neg__',
+ '__pos__',
+ '__pow__',
+ '__radd__',
+ '__rdiv__',
+ '__rfloordiv__',
+ '__rmod__',
+ '__rmul__',
+ '__rpow__',
+ '__rsub__',
+ '__rtruediv__',
+ '__str__',
+ '__sub__',
+ '__truediv__',
+ )
+
+ # These additional methods should be intercepted, even though they are not intercepted by StrictUndefined.
+ additional_method_names = (
+ '__aiter__',
+ '__delattr__',
+ '__format__',
+ '__repr__',
+ '__setitem__',
+ )
+
+ for name in jinja_method_names + additional_method_names:
+ setattr(cls, name, cls._fail_with_undefined_error)
+
+
+Marker._init_class()
+
+
+class TruncationMarker(Marker):
+ """
+ A `Marker` value was previously encountered and reported.
+ A subsequent `Marker` value (this instance) indicates the template may have been truncated as a result.
+ It will only be visible if the previous `Marker` was ignored/replaced instead of being tripped, which would raise an exception.
+ """ + + # DTFIX-RELEASE: make this a singleton? + + __slots__ = () + + def __init__(self) -> None: + super().__init__(hint='template potentially truncated') + + +class UndefinedMarker(Marker): + """A `Marker` value that represents an undefined value encountered during templating.""" + + __slots__ = () + + +class ExceptionMarker(Marker, metaclass=abc.ABCMeta): + """Base `Marker` class that represents exceptions encountered and deferred during templating.""" + + __slots__ = () + + @abc.abstractmethod + def _as_exception(self) -> Exception: + pass + + def _as_message(self) -> str: + return str(self._as_exception()) + + def trip(self) -> t.NoReturn: + """Raise an internal exception which can be converted back to this instance while maintaining the cause for callers that follow them.""" + raise MarkerError(self._undefined_message, self) from self._as_exception() + + +class CapturedExceptionMarker(ExceptionMarker): + """A `Marker` value that represents an exception raised during templating.""" + + __slots__ = ('_marker_captured_exception',) + + def __init__(self, exception: Exception) -> None: + super().__init__(hint=f'A captured exception marker was tripped: {exception}') + + self._marker_captured_exception = exception + + def _as_exception(self) -> Exception: + return self._marker_captured_exception + + +class UndecryptableVaultError(_captured.AnsibleCapturedError): + """Template-external error raised by VaultExceptionMarker when an undecryptable variable is accessed.""" + + context = 'vault' + _default_message = "Attempt to use undecryptable variable." + + +class VaultExceptionMarker(ExceptionMarker): + """A `Marker` value that represents an error accessing a vaulted value during templating.""" + + __slots__ = ('_marker_undecryptable_ciphertext', '_marker_undecryptable_reason', '_marker_undecryptable_traceback') + + def __init__(self, ciphertext: str, reason: str, traceback: str | None) -> None: + # DTFIX-RELEASE: when does this show up, should it contain more details? + # see also CapturedExceptionMarker for a similar issue + super().__init__(hint='A vault exception marker was tripped.') + + self._marker_undecryptable_ciphertext = ciphertext + self._marker_undecryptable_reason = reason + self._marker_undecryptable_traceback = traceback + + def _as_exception(self) -> Exception: + return UndecryptableVaultError( + obj=self._marker_undecryptable_ciphertext, + error_summary=ErrorSummary( + details=( + Detail( + msg=self._marker_undecryptable_reason, + ), + ), + formatted_traceback=self._marker_undecryptable_traceback, + ), + ) + + def _disarm(self) -> str: + return self._marker_undecryptable_ciphertext + + +def get_first_marker_arg(args: c.Sequence, kwargs: dict[str, t.Any]) -> Marker | None: + """Utility method to inspect plugin args and return the first `Marker` encountered, otherwise `None`.""" + # DTFIX-RELEASE: this may or may not need to be public API, move back to utils or once usage is wrapped in a decorator? + for arg in iter_marker_args(args, kwargs): + return arg + + return None + + +def iter_marker_args(args: c.Sequence, kwargs: dict[str, t.Any]) -> t.Generator[Marker]: + """Utility method to iterate plugin args and yield any `Marker` encountered.""" + # DTFIX-RELEASE: this may or may not need to be public API, move back to utils or once usage is wrapped in a decorator? 
+    for arg in itertools.chain(args, kwargs.values()):
+        if isinstance(arg, Marker):
+            yield arg
+
+
+class JinjaCallContext(NotifiableAccessContextBase):
+    """
+    An audit context that wraps all Jinja (template/filter/test/lookup/method/function) calls.
+    While active, calls `trip()` on managed access of `Marker` objects unless the callee declares an understanding of markers.
+    """
+
+    _mask = True
+
+    def __init__(self, accept_lazy_markers: bool) -> None:
+        self._type_interest = frozenset() if accept_lazy_markers else frozenset(Marker.concrete_subclasses)
+
+    def _notify(self, o: Marker) -> t.NoReturn:
+        o.trip()
+
+
+def validate_arg_type(name: str, value: t.Any, allowed_type_or_types: type | tuple[type, ...], /) -> None:
+    """Validate the type of the given argument while preserving context for Marker values."""
+    # DTFIX-RELEASE: find a home for this as a general-purpose utility method and expose it after some API review
+    if isinstance(value, allowed_type_or_types):
+        return
+
+    if isinstance(allowed_type_or_types, type):
+        arg_type_description = repr(native_type_name(allowed_type_or_types))
+    else:
+        arg_type_description = ' or '.join(repr(native_type_name(item)) for item in allowed_type_or_types)
+
+    if isinstance(value, Marker):
+        try:
+            value.trip()
+        except Exception as ex:
+            raise AnsibleTypeError(f"The {name!r} argument must be of type {arg_type_description}.", obj=value) from ex
+
+    raise AnsibleTypeError(f"The {name!r} argument must be of type {arg_type_description}, not {native_type_name(value)!r}.", obj=value)
diff --git a/lib/ansible/_internal/_templating/_jinja_patches.py b/lib/ansible/_internal/_templating/_jinja_patches.py
new file mode 100644
index 00000000000..55966793e47
--- /dev/null
+++ b/lib/ansible/_internal/_templating/_jinja_patches.py
@@ -0,0 +1,44 @@
+"""Runtime patches for Jinja bugs affecting Ansible."""
+
+from __future__ import annotations
+
+import jinja2
+import jinja2.utils
+
+
+def _patch_jinja_undefined_slots() -> None:
+    """
+    Fix the broken __slots__ on Jinja's Undefined and StrictUndefined if they're missing in the current version.
+    This will no longer be necessary once the fix is included in the minimum supported Jinja version.
+    See: https://github.com/pallets/jinja/issues/2025
+    """
+    if not hasattr(jinja2.Undefined, '__slots__'):
+        jinja2.Undefined.__slots__ = (
+            "_undefined_hint",
+            "_undefined_obj",
+            "_undefined_name",
+            "_undefined_exception",
+        )
+
+    if not hasattr(jinja2.StrictUndefined, '__slots__'):
+        jinja2.StrictUndefined.__slots__ = ()
+
+
+def _patch_jinja_missing_type() -> None:
+    """
+    Fix the `jinja2.utils.missing` type to support pickling while remaining a singleton.
+    This will no longer be necessary once the fix is included in the minimum supported Jinja version.
+ See: https://github.com/pallets/jinja/issues/2027 + """ + if getattr(jinja2.utils.missing, '__reduce__')() != 'missing': + + def __reduce__(*_args): + return 'missing' + + type(jinja2.utils.missing).__reduce__ = __reduce__ + + +def _patch_jinja() -> None: + """Apply Jinja2 patches.""" + _patch_jinja_undefined_slots() + _patch_jinja_missing_type() diff --git a/lib/ansible/_internal/_templating/_jinja_plugins.py b/lib/ansible/_internal/_templating/_jinja_plugins.py new file mode 100644 index 00000000000..e68d96dcf5d --- /dev/null +++ b/lib/ansible/_internal/_templating/_jinja_plugins.py @@ -0,0 +1,351 @@ +"""Jinja template plugins (filters, tests, lookups) and custom global functions.""" + +from __future__ import annotations + +import collections.abc as c +import dataclasses +import datetime +import functools +import typing as t + +from ansible.errors import ( + AnsibleTemplatePluginError, +) + +from ansible.module_utils._internal._ambient_context import AmbientContextBase +from ansible.module_utils._internal._plugin_exec_context import PluginExecContext +from ansible.module_utils.common.collections import is_sequence +from ansible.module_utils._internal._datatag import AnsibleTagHelper +from ansible._internal._datatag._tags import TrustedAsTemplate +from ansible.plugins import AnsibleJinja2Plugin +from ansible.plugins.loader import lookup_loader, Jinja2Loader +from ansible.plugins.lookup import LookupBase +from ansible.utils.display import Display + +from ._datatag import _JinjaConstTemplate +from ._errors import AnsibleTemplatePluginRuntimeError, AnsibleTemplatePluginLoadError, AnsibleTemplatePluginNotFoundError +from ._jinja_common import MarkerError, _TemplateConfig, get_first_marker_arg, Marker, JinjaCallContext +from ._lazy_containers import lazify_container_kwargs, lazify_container_args, lazify_container, _AnsibleLazyTemplateMixin +from ._utils import LazyOptions, TemplateContext + +_display = Display() + +_TCallable = t.TypeVar("_TCallable", bound=t.Callable) +_ITERATOR_TYPES: t.Final = (c.Iterator, c.ItemsView, c.KeysView, c.ValuesView, range) + + +class JinjaPluginIntercept(c.MutableMapping): + """ + Simulated dict class that loads Jinja2Plugins at request + otherwise all plugins would need to be loaded a priori. + + NOTE: plugin_loader still loads all 'builtin/legacy' at + start so only collection plugins are really at request. + """ + + def __init__(self, jinja_builtins: c.Mapping[str, AnsibleJinja2Plugin], plugin_loader: Jinja2Loader): + super(JinjaPluginIntercept, self).__init__() + + self._plugin_loader = plugin_loader + self._jinja_builtins = jinja_builtins + self._wrapped_funcs: dict[str, t.Callable] = {} + + def _wrap_and_set_func(self, instance: AnsibleJinja2Plugin) -> t.Callable: + if self._plugin_loader.type == 'filter': + plugin_func = self._wrap_filter(instance) + else: + plugin_func = self._wrap_test(instance) + + self._wrapped_funcs[instance._load_name] = plugin_func + + return plugin_func + + def __getitem__(self, key: str) -> t.Callable: + instance: AnsibleJinja2Plugin | None = None + plugin_func: t.Callable[..., t.Any] | None + + if plugin_func := self._wrapped_funcs.get(key): + return plugin_func + + try: + instance = self._plugin_loader.get(key) + except KeyError: + # The plugin name was invalid or no plugin was found by that name. + pass + except Exception as ex: + # An unexpected exception occurred. 
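+            # e.g. a plugin file that raises during import would surface here as a load
+            # error rather than being treated as "not found"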
+ raise AnsibleTemplatePluginLoadError(self._plugin_loader.type, key) from ex + + if not instance: + try: + instance = self._jinja_builtins[key] + except KeyError: + raise AnsibleTemplatePluginNotFoundError(self._plugin_loader.type, key) from None + + plugin_func = self._wrap_and_set_func(instance) + + return plugin_func + + def __setitem__(self, key: str, value: t.Callable) -> None: + self._wrap_and_set_func(self._plugin_loader._wrap_func(key, key, value)) + + def __delitem__(self, key): + raise NotImplementedError() + + def __contains__(self, item: t.Any) -> bool: + try: + self.__getitem__(item) + except AnsibleTemplatePluginLoadError: + return True + except AnsibleTemplatePluginNotFoundError: + return False + + return True + + def __iter__(self): + raise NotImplementedError() # dynamic container + + def __len__(self): + raise NotImplementedError() # dynamic container + + @staticmethod + def _invoke_plugin(instance: AnsibleJinja2Plugin, *args, **kwargs) -> t.Any: + if not instance.accept_args_markers: + if (first_marker := get_first_marker_arg(args, kwargs)) is not None: + return first_marker + + try: + with JinjaCallContext(accept_lazy_markers=instance.accept_lazy_markers), PluginExecContext(executing_plugin=instance): + return instance.j2_function(*lazify_container_args(args), **lazify_container_kwargs(kwargs)) + except MarkerError as ex: + return ex.source + except Exception as ex: + raise AnsibleTemplatePluginRuntimeError(instance.plugin_type, instance.ansible_name) from ex # DTFIX-RELEASE: which name to use? use plugin info? + + def _wrap_test(self, instance: AnsibleJinja2Plugin) -> t.Callable: + """Intercept point for all test plugins to ensure that args are properly templated/lazified.""" + + @functools.wraps(instance.j2_function) + def wrapper(*args, **kwargs) -> bool | Marker: + result = self._invoke_plugin(instance, *args, **kwargs) + + if not isinstance(result, bool): + template = TemplateContext.current().template_value + + # DTFIX-RELEASE: which name to use? use plugin info? + _display.deprecated( + msg=f"The test plugin {instance.ansible_name!r} returned a non-boolean result of type {type(result)!r}. 
" + "Test plugins must have a boolean result.", + obj=template, + version="2.23", + ) + + result = bool(result) + + return result + + return wrapper + + def _wrap_filter(self, instance: AnsibleJinja2Plugin) -> t.Callable: + """Intercept point for all filter plugins to ensure that args are properly templated/lazified.""" + + @functools.wraps(instance.j2_function) + def wrapper(*args, **kwargs) -> t.Any: + result = self._invoke_plugin(instance, *args, **kwargs) + result = _wrap_plugin_output(result) + + return result + + return wrapper + + +class _DirectCall: + """Functions/methods marked `_DirectCall` bypass Jinja Environment checks for `Marker`.""" + + _marker_attr: str = "_directcall" + + @classmethod + def mark(cls, src: _TCallable) -> _TCallable: + setattr(src, cls._marker_attr, True) + return src + + @classmethod + def is_marked(cls, value: t.Callable) -> bool: + return callable(value) and getattr(value, "_directcall", False) + + +@_DirectCall.mark +def _query(plugin_name: str, /, *args, **kwargs) -> t.Any: + """wrapper for lookup, force wantlist true""" + kwargs['wantlist'] = True + return _invoke_lookup(plugin_name=plugin_name, lookup_terms=list(args), lookup_kwargs=kwargs) + + +@_DirectCall.mark +def _lookup(plugin_name: str, /, *args, **kwargs) -> t.Any: + # convert the args tuple to a list, since some plugins make a poor assumption that `run.args` is a list + return _invoke_lookup(plugin_name=plugin_name, lookup_terms=list(args), lookup_kwargs=kwargs) + + +@dataclasses.dataclass +class _LookupContext(AmbientContextBase): + """Ambient context that wraps lookup execution, providing information about how it was invoked.""" + + invoked_as_with: bool + + +@_DirectCall.mark +def _invoke_lookup(*, plugin_name: str, lookup_terms: list, lookup_kwargs: dict[str, t.Any], invoked_as_with: bool = False) -> t.Any: + templar = TemplateContext.current().templar + + from ansible import template as _template + + try: + instance: LookupBase | None = lookup_loader.get(plugin_name, loader=templar._loader, templar=_template.Templar._from_template_engine(templar)) + except Exception as ex: + raise AnsibleTemplatePluginLoadError('lookup', plugin_name) from ex + + if instance is None: + raise AnsibleTemplatePluginNotFoundError('lookup', plugin_name) + + # if the lookup doesn't understand `Marker` and there's at least one in the top level, short-circuit by returning the first one we found + if not instance.accept_args_markers and (first_marker := get_first_marker_arg(lookup_terms, lookup_kwargs)) is not None: + return first_marker + + # don't pass these through to the lookup + wantlist = lookup_kwargs.pop('wantlist', False) + errors = lookup_kwargs.pop('errors', 'strict') + + with ( + JinjaCallContext(accept_lazy_markers=instance.accept_lazy_markers), + PluginExecContext(executing_plugin=instance), + ): + try: + if _TemplateConfig.allow_embedded_templates: + # for backwards compat, only trust constant templates in lookup terms + with JinjaCallContext(accept_lazy_markers=True): + # Force lazy marker support on for this call; the plugin's understanding is irrelevant, as is any existing context, since this backward + # compat code always understands markers. 
+ lookup_terms = [templar.template(value) for value in _trust_jinja_constants(lookup_terms)] + + # since embedded template support is enabled, repeat the check for `Marker` on lookup_terms, since a template may render as a `Marker` + if not instance.accept_args_markers and (first_marker := get_first_marker_arg(lookup_terms, {})) is not None: + return first_marker + else: + lookup_terms = AnsibleTagHelper.tag_copy(lookup_terms, (lazify_container(value) for value in lookup_terms), value_type=list) + + with _LookupContext(invoked_as_with=invoked_as_with): + # The lookup context currently only supports the internal use-case where `first_found` requires extra info when invoked via `with_first_found`. + # The context may be public API in the future, but for now, other plugins should not implement this kind of dynamic behavior, + # though we're stuck with it for backward compatibility on `first_found`. + lookup_res = instance.run(lookup_terms, variables=templar.available_variables, **lazify_container_kwargs(lookup_kwargs)) + + # DTFIX-FUTURE: Consider allowing/requiring lookup plugins to declare how their result should be handled. + # Currently, there are multiple behaviors that are less than ideal and poorly documented (or not at all): + # * When `errors=warn` or `errors=ignore` the result is `None` unless `wantlist=True`, in which case the result is `[]`. + # * The user must specify `wantlist=True` to receive the plugin return value unmodified. + # A plugin can achieve similar results by wrapping its result in a list -- unless of course the user specifies `wantlist=True`. + # * When `wantlist=True` is specified, the result is not guaranteed to be a list as the option implies (except on plugin error). + # * Sequences are munged unless the user specifies `wantlist=True`: + # * len() == 0 - Return an empty sequence. + # * len() == 1 - Return the only element in the sequence. + # * len() >= 2 when all elements are `str` - Return all the values joined into a single comma separated string. + # * len() >= 2 when at least one element is not `str` - Return the sequence as-is. + + if not is_sequence(lookup_res): + # DTFIX-FUTURE: deprecate return types which are not a list + # previously non-Sequence return types were deprecated and then became an error in 2.18 + # however, the deprecation message (and this error) mention `list` specifically rather than `Sequence` + # letting non-list values through will trigger variable type checking warnings/errors + raise TypeError(f'returned {type(lookup_res)} instead of {list}') + + except MarkerError as ex: + return ex.source + except Exception as ex: + # DTFIX-RELEASE: convert this to the new error/warn/ignore context manager + if isinstance(ex, AnsibleTemplatePluginError): + msg = f'Lookup failed but the error is being ignored: {ex}' + else: + msg = f'An unhandled exception occurred while running the lookup plugin {plugin_name!r}. 
Error was a {type(ex)}, original message: {ex}'
+
+            if errors == 'warn':
+                _display.warning(msg)
+            elif errors == 'ignore':
+                _display.display(msg, log_only=True)
+            else:
+                raise AnsibleTemplatePluginRuntimeError('lookup', plugin_name) from ex
+
+            return [] if wantlist else None
+
+    if not wantlist and lookup_res:
+        # when wantlist=False the lookup result is either partially delazified (single element) or fully delazified (multiple elements)
+
+        if len(lookup_res) == 1:
+            lookup_res = lookup_res[0]
+        else:
+            try:
+                lookup_res = ",".join(lookup_res)  # for backwards compatibility, attempt to join `lookup_res` into a single string
+            except TypeError:
+                pass  # for backwards compatibility, return `lookup_res` as-is when the sequence contains non-string values
+
+    return _wrap_plugin_output(lookup_res)
+
+
+def _now(utc=False, fmt=None):
+    """Jinja2 global function (now) to return current datetime, potentially formatted via strftime."""
+    if utc:
+        now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+    else:
+        now = datetime.datetime.now()
+
+    if fmt:
+        return now.strftime(fmt)
+
+    return now
+
+
+def _jinja_const_template_warning(value: object, is_conditional: bool) -> None:
+    """Issue a warning regarding embedded template usage."""
+    help_text = "Use inline expressions, for example: "
+
+    if is_conditional:
+        help_text += """`when: "{{ a_var }}" == 42` becomes `when: a_var == 42`"""
+    else:
+        help_text += """`msg: "{{ lookup('env', '{{ a_var }}') }}"` becomes `msg: "{{ lookup('env', a_var) }}"`"""
+
+    # deprecated: description='disable embedded templates by default and deprecate the feature' core_version='2.23'
+    _display.warning(
+        msg="Jinja constant strings should not contain embedded templates. This feature will be disabled by default in ansible-core 2.23.",
+        obj=value,
+        help_text=help_text,
+    )
+
+
+def _trust_jinja_constants(o: t.Any) -> t.Any:
+    """
+    Recursively apply TrustedAsTemplate to values tagged with _JinjaConstTemplate and remove the tag.
+    Only container types emitted by the Jinja compiler are checked, since others do not contain constants.
+    This is used to provide backwards compatibility with historical lookup behavior for positional arguments.
+ """ + if _JinjaConstTemplate.is_tagged_on(o): + _jinja_const_template_warning(o, is_conditional=False) + + return TrustedAsTemplate().tag(_JinjaConstTemplate.untag(o)) + + o_type = type(o) + + if o_type is dict: + return {k: _trust_jinja_constants(v) for k, v in o.items()} + + if o_type in (list, tuple): + return o_type(_trust_jinja_constants(v) for v in o) + + return o + + +def _wrap_plugin_output(o: t.Any) -> t.Any: + """Utility method to ensure that iterators/generators returned from a plugins are consumed.""" + if isinstance(o, _ITERATOR_TYPES): + o = list(o) + + return _AnsibleLazyTemplateMixin._try_create(o, LazyOptions.SKIP_TEMPLATES) diff --git a/lib/ansible/_internal/_templating/_lazy_containers.py b/lib/ansible/_internal/_templating/_lazy_containers.py new file mode 100644 index 00000000000..b1a7a4f2310 --- /dev/null +++ b/lib/ansible/_internal/_templating/_lazy_containers.py @@ -0,0 +1,633 @@ +from __future__ import annotations + +import copy +import dataclasses +import functools +import types +import typing as t + +from jinja2.environment import TemplateModule + +from ansible.module_utils._internal._datatag import ( + AnsibleTagHelper, + AnsibleTaggedObject, + _AnsibleTaggedDict, + _AnsibleTaggedList, + _AnsibleTaggedTuple, + _NO_INSTANCE_STORAGE, + _try_get_internal_tags_mapping, +) + +from ansible.utils.sentinel import Sentinel +from ansible.errors import AnsibleVariableTypeError +from ansible._internal._errors._handler import Skippable +from ansible.vars.hostvars import HostVarsVars, HostVars + +from ._access import AnsibleAccessContext +from ._jinja_common import Marker, _TemplateConfig +from ._utils import TemplateContext, PASS_THROUGH_SCALAR_VAR_TYPES, LazyOptions + +if t.TYPE_CHECKING: + from ._engine import TemplateEngine + +_KNOWN_TYPES: t.Final[set[type]] = ( + { + HostVars, # example: hostvars + HostVarsVars, # example: hostvars.localhost | select + type, # example: range(20) | list # triggered on retrieval of `range` type from globals + range, # example: range(20) | list # triggered when returning a `range` instance from a call + types.FunctionType, # example: undef() | default("blah") + types.MethodType, # example: ansible_facts.get | type_debug + functools.partial, + type(''.startswith), # example: inventory_hostname.upper | type_debug # using `startswith` to resolve `builtin_function_or_method` + TemplateModule, # example: '{% import "importme.j2" as im %}{{ im | type_debug }}' + } + | set(PASS_THROUGH_SCALAR_VAR_TYPES) + | set(Marker.concrete_subclasses) +) +""" +These types are known to the templating system. +In addition to the statically defined types, additional types will be added at runtime. +When enabled in config, this set will be used to determine if an encountered type should trigger a warning or error. 
+""" + + +def register_known_types(*args: type) -> None: + """Register a type with the template engine so it will not trigger warnings or errors when encountered.""" + _KNOWN_TYPES.update(args) + + +class UnsupportedConstructionMethodError(RuntimeError): + """Error raised when attempting to construct a lazy container with unsupported arguments.""" + + def __init__(self): + super().__init__("Direct construction of lazy containers is not supported.") + + +@t.final +@dataclasses.dataclass(frozen=True, slots=True) +class _LazyValue: + """Wrapper around values to indicate lazy behavior has not yet been applied.""" + + value: t.Any + + +@t.final +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class _LazyValueSource: + """Intermediate value source for lazy-eligible collection copy operations.""" + + source: t.Iterable + templar: TemplateEngine + lazy_options: LazyOptions + + +@t.final +class _NoKeySentinel(Sentinel): + """Sentinel used to indicate a requested key was not found.""" + + +# There are several operations performed by lazy containers, with some variation between types. +# +# Columns: D=dict, L=list, T=tuple +# Cells: l=lazy (upon access), n=non-lazy (__init__/__new__) +# +# D L T Feature Description +# - - - ----------- --------------------------------------------------------------- +# l l n propagation when container items which are containers become lazy instances +# l l n transform when transforms are applied to container items +# l l n templating when templating is performed on container items +# l l l access when access calls are performed on container items + + +class _AnsibleLazyTemplateMixin: + __slots__ = _NO_INSTANCE_STORAGE + + _dispatch_types: t.ClassVar[dict[type, type[_AnsibleLazyTemplateMixin]]] = {} # populated by __init_subclass__ + _container_types: t.ClassVar[set[type]] = set() # populated by __init_subclass__ + + _native_type: t.ClassVar[type] # from AnsibleTaggedObject + + _SLOTS: t.Final = ( + '_templar', + '_lazy_options', + ) + + _templar: TemplateEngine + _lazy_options: LazyOptions + + def __init_subclass__(cls, **kwargs) -> None: + tagged_type = cls.__mro__[1] + native_type = tagged_type.__mro__[1] + + for check_type in (tagged_type, native_type): + if conflicting_type := cls._dispatch_types.get(check_type): + raise TypeError(f"Lazy mixin {cls.__name__!r} type {check_type.__name__!r} conflicts with {conflicting_type.__name__!r}.") + + cls._dispatch_types[native_type] = cls + cls._dispatch_types[tagged_type] = cls + cls._container_types.add(native_type) + cls._empty_tags_as_native = False # never revert to the native type when no tags remain + + register_known_types(cls) + + def __init__(self, contents: t.Iterable | _LazyValueSource) -> None: + if isinstance(contents, _LazyValueSource): + self._templar = contents.templar + self._lazy_options = contents.lazy_options + elif isinstance(contents, _AnsibleLazyTemplateMixin): + self._templar = contents._templar + self._lazy_options = contents._lazy_options + else: + raise UnsupportedConstructionMethodError() + + def __reduce_ex__(self, protocol): + raise NotImplementedError("Pickling of Ansible lazy objects is not permitted.") + + @staticmethod + def _try_create(item: t.Any, lazy_options: LazyOptions = LazyOptions.DEFAULT) -> t.Any: + """ + If `item` is a container type which supports lazy access and/or templating, return a lazy wrapped version -- otherwise return it as-is. + When returning as-is, a warning or error may be generated for unknown types. 
+ The `lazy_options.skip_templates` argument should be set to `True` when `item` is sourced from a plugin instead of Ansible variable storage. + This provides backwards compatibility and reduces lazy overhead, as plugins do not normally introduce templates. + If a plugin needs to introduce templates, the plugin is responsible for invoking the templar and returning the result. + """ + item_type = type(item) + + # Try to use exact type match first to determine which wrapper (if any) to apply; isinstance checks + # are extremely expensive, so try to avoid them for our commonly-supported types. + if (dispatcher := _AnsibleLazyTemplateMixin._dispatch_types.get(item_type)) is not None: + # Create a generator that yields the elements of `item` wrapped in a `_LazyValue` wrapper. + # The wrapper is used to signal to the lazy container that the value must be processed before being returned. + # Values added to the lazy container later through other means will be returned as-is, without any special processing. + lazy_values = dispatcher._lazy_values(item, lazy_options) + tags_mapping = _try_get_internal_tags_mapping(item) + value = t.cast(AnsibleTaggedObject, dispatcher)._instance_factory(lazy_values, tags_mapping) + + return value + + with Skippable, _TemplateConfig.unknown_type_encountered_handler.handle(AnsibleVariableTypeError, skip_on_ignore=True): + if item_type not in _KNOWN_TYPES: + raise AnsibleVariableTypeError( + message=f"Encountered unknown type {item_type.__name__!r} during template operation.", + help_text="Use supported types to avoid unexpected behavior.", + obj=TemplateContext.current().template_value, + ) + + return item + + def _is_not_lazy_combine_candidate(self, other: object) -> bool: + """Returns `True` if `other` cannot be lazily combined with the current instance due to differing templar/options, otherwise returns `False`.""" + return isinstance(other, _AnsibleLazyTemplateMixin) and (self._templar is not other._templar or self._lazy_options != other._lazy_options) + + def _non_lazy_copy(self) -> t.Collection: + """ + Return a non-lazy copy of this collection. + Any remaining lazy wrapped values will be unwrapped without further processing. + Tags on this instance will be preserved on the returned copy. + """ + raise NotImplementedError() # pragma: nocover + + @staticmethod + def _lazy_values(values: t.Any, lazy_options: LazyOptions) -> _LazyValueSource: + """ + Return an iterable that wraps each of the given elements in a lazy wrapper. + Only elements wrapped this way will receive lazy processing when retrieved from the collection. + """ + # DTFIX-RELEASE: check relative performance of method-local vs stored generator expressions on implementations of this method + raise NotImplementedError() # pragma: nocover + + def _proxy_or_render_lazy_value(self, key: t.Any, value: t.Any) -> t.Any: + """ + Ensure that the value is lazy-proxied or rendered, and if a key is provided, replace the original value with the result. 
+ """ + if type(value) is not _LazyValue: # pylint: disable=unidiomatic-typecheck + if self._lazy_options.access: + AnsibleAccessContext.current().access(value) + + return value + + original_value = value.value + + if self._lazy_options.access: + AnsibleAccessContext.current().access(original_value) + + new_value = self._templar.template(original_value, lazy_options=self._lazy_options) + + if new_value is not original_value and self._lazy_options.access: + AnsibleAccessContext.current().access(new_value) + + if key is not _NoKeySentinel: + self._native_type.__setitem__(self, key, new_value) # type: ignore # pylint: disable=unnecessary-dunder-call + + return new_value + + +@t.final # consumers of lazy collections rely heavily on the concrete types being final +class _AnsibleLazyTemplateDict(_AnsibleTaggedDict, _AnsibleLazyTemplateMixin): + __slots__ = _AnsibleLazyTemplateMixin._SLOTS + + def __init__(self, contents: t.Iterable | _LazyValueSource, /, **kwargs) -> None: + _AnsibleLazyTemplateMixin.__init__(self, contents) + + if isinstance(contents, _AnsibleLazyTemplateDict): + super().__init__(dict.items(contents), **kwargs) + elif isinstance(contents, _LazyValueSource): + super().__init__(contents.source, **kwargs) + else: + raise UnsupportedConstructionMethodError() + + def get(self, key: t.Any, default: t.Any = None) -> t.Any: + if (value := super().get(key, _NoKeySentinel)) is _NoKeySentinel: + return default + + return self._proxy_or_render_lazy_value(key, value) + + def __getitem__(self, key: t.Any, /) -> t.Any: + return self._proxy_or_render_lazy_value(key, super().__getitem__(key)) + + def __str__(self): + return str(self.copy()._native_copy()) # inefficient, but avoids mutating the current instance (to make debugging practical) + + def __repr__(self): + return repr(self.copy()._native_copy()) # inefficient, but avoids mutating the current instance (to make debugging practical) + + def __iter__(self): + # We're using the base implementation, but must override `__iter__` to skip `dict` fast-path copy, which would bypass lazy behavior. 
+ # See: https://github.com/python/cpython/blob/ffcc450a9b8b6927549b501eff7ac14abc238448/Objects/dictobject.c#L3861-L3864 + return super().__iter__() + + def setdefault(self, key, default=None, /) -> t.Any: + if (value := self.get(key, _NoKeySentinel)) is not _NoKeySentinel: + return value + + super().__setitem__(key, default) + + return default + + def items(self): + for key, value in super().items(): + yield key, self._proxy_or_render_lazy_value(key, value) + + def values(self): + for _key, value in self.items(): + yield value + + def pop(self, key, default=_NoKeySentinel, /) -> t.Any: + if (value := super().get(key, _NoKeySentinel)) is _NoKeySentinel: + if default is _NoKeySentinel: + raise KeyError(key) + + return default + + value = self._proxy_or_render_lazy_value(_NoKeySentinel, value) + + del self[key] + + return value + + def popitem(self) -> t.Any: + try: + key = next(reversed(self)) + except StopIteration: + raise KeyError("popitem(): dictionary is empty") + + value = self._proxy_or_render_lazy_value(_NoKeySentinel, self[key]) + + del self[key] + + return key, value + + def _native_copy(self) -> dict: + return dict(self.items()) + + @staticmethod + def _item_source(value: dict) -> dict | _LazyValueSource: + if isinstance(value, _AnsibleLazyTemplateDict): + return _LazyValueSource(source=dict.items(value), templar=value._templar, lazy_options=value._lazy_options) + + return value + + def _yield_non_lazy_dict_items(self) -> t.Iterator[tuple[str, t.Any]]: + """ + Delegate to the base collection items iterator to yield the raw contents. + As of Python 3.13, generator functions are significantly faster than inline generator expressions. + """ + for k, v in dict.items(self): + yield k, v.value if type(v) is _LazyValue else v # pylint: disable=unidiomatic-typecheck + + def _non_lazy_copy(self) -> dict: + return AnsibleTagHelper.tag_copy(self, self._yield_non_lazy_dict_items(), value_type=dict) + + @staticmethod + def _lazy_values(values: dict, lazy_options: LazyOptions) -> _LazyValueSource: + return _LazyValueSource(source=((k, _LazyValue(v)) for k, v in values.items()), templar=TemplateContext.current().templar, lazy_options=lazy_options) + + @staticmethod + def _proxy_or_render_other(other: t.Any | None) -> None: + """Call `_proxy_or_render_lazy_values` if `other` is a lazy dict. Used internally by comparison methods.""" + if type(other) is _AnsibleLazyTemplateDict: # pylint: disable=unidiomatic-typecheck + other._proxy_or_render_lazy_values() + + def _proxy_or_render_lazy_values(self) -> None: + """Ensure all `_LazyValue` wrapped values have been processed.""" + for _unused in self.values(): + pass + + def __eq__(self, other): + self._proxy_or_render_lazy_values() + self._proxy_or_render_other(other) + return super().__eq__(other) + + def __ne__(self, other): + self._proxy_or_render_lazy_values() + self._proxy_or_render_other(other) + return super().__ne__(other) + + def __or__(self, other): + # DTFIX-RELEASE: support preservation of laziness when possible like we do for list + # Both sides end up going through _proxy_or_render_lazy_value, so there's no Templar preservation needed. + # In the future this could be made more lazy when both Templar instances are the same, or if per-value Templar tracking was used. + return super().__or__(other) + + def __ror__(self, other): + # DTFIX-RELEASE: support preservation of laziness when possible like we do for list + # Both sides end up going through _proxy_or_render_lazy_value, so there's no Templar preservation needed. 
+ # In the future this could be made more lazy when both Templar instances are the same, or if per-value Templar tracking was used. + return super().__ror__(other) + + def __deepcopy__(self, memo): + return _AnsibleLazyTemplateDict( + _LazyValueSource( + source=((copy.deepcopy(k), copy.deepcopy(v)) for k, v in super().items()), + templar=copy.deepcopy(self._templar), + lazy_options=copy.deepcopy(self._lazy_options), + ) + ) + + +@t.final # consumers of lazy collections rely heavily on the concrete types being final +class _AnsibleLazyTemplateList(_AnsibleTaggedList, _AnsibleLazyTemplateMixin): + __slots__ = _AnsibleLazyTemplateMixin._SLOTS + + def __init__(self, contents: t.Iterable | _LazyValueSource, /) -> None: + _AnsibleLazyTemplateMixin.__init__(self, contents) + + if isinstance(contents, _AnsibleLazyTemplateList): + super().__init__(list.__iter__(contents)) + elif isinstance(contents, _LazyValueSource): + super().__init__(contents.source) + else: + raise UnsupportedConstructionMethodError() + + def __getitem__(self, key: t.SupportsIndex | slice, /) -> t.Any: + if type(key) is slice: # pylint: disable=unidiomatic-typecheck + return _AnsibleLazyTemplateList(_LazyValueSource(source=super().__getitem__(key), templar=self._templar, lazy_options=self._lazy_options)) + + return self._proxy_or_render_lazy_value(key, super().__getitem__(key)) + + def __iter__(self): + for key, value in enumerate(super().__iter__()): + yield self._proxy_or_render_lazy_value(key, value) + + def pop(self, idx: t.SupportsIndex = -1, /) -> t.Any: + if not self: + raise IndexError('pop from empty list') + + try: + value = self[idx] + except IndexError: + raise IndexError('pop index out of range') + + value = self._proxy_or_render_lazy_value(_NoKeySentinel, value) + + del self[idx] + + return value + + def __str__(self): + return str(self.copy()._native_copy()) # inefficient, but avoids mutating the current instance (to make debugging practical) + + def __repr__(self): + return repr(self.copy()._native_copy()) # inefficient, but avoids mutating the current instance (to make debugging practical) + + @staticmethod + def _item_source(value: list) -> list | _LazyValueSource: + if isinstance(value, _AnsibleLazyTemplateList): + return _LazyValueSource(source=list.__iter__(value), templar=value._templar, lazy_options=value._lazy_options) + + return value + + def _yield_non_lazy_list_items(self): + """ + Delegate to the base collection iterator to yield the raw contents. + As of Python 3.13, generator functions are significantly faster than inline generator expressions. + """ + for v in list.__iter__(self): + yield v.value if type(v) is _LazyValue else v # pylint: disable=unidiomatic-typecheck + + def _non_lazy_copy(self) -> list: + return AnsibleTagHelper.tag_copy(self, self._yield_non_lazy_list_items(), value_type=list) + + @staticmethod + def _lazy_values(values: list, lazy_options: LazyOptions) -> _LazyValueSource: + return _LazyValueSource(source=(_LazyValue(v) for v in values), templar=TemplateContext.current().templar, lazy_options=lazy_options) + + @staticmethod + def _proxy_or_render_other(other: t.Any | None) -> None: + """Call `_proxy_or_render_lazy_values` if `other` is a lazy list. 
Used internally by comparison methods.""" + if type(other) is _AnsibleLazyTemplateList: # pylint: disable=unidiomatic-typecheck + other._proxy_or_render_lazy_values() + + def _proxy_or_render_lazy_values(self) -> None: + """Ensure all `_LazyValue` wrapped values have been processed.""" + for _unused in self: + pass + + def __eq__(self, other): + self._proxy_or_render_lazy_values() + self._proxy_or_render_other(other) + return super().__eq__(other) + + def __ne__(self, other): + self._proxy_or_render_lazy_values() + self._proxy_or_render_other(other) + return super().__ne__(other) + + def __gt__(self, other): + self._proxy_or_render_lazy_values() + self._proxy_or_render_other(other) + return super().__gt__(other) + + def __ge__(self, other): + self._proxy_or_render_lazy_values() + self._proxy_or_render_other(other) + return super().__ge__(other) + + def __lt__(self, other): + self._proxy_or_render_lazy_values() + self._proxy_or_render_other(other) + return super().__lt__(other) + + def __le__(self, other): + self._proxy_or_render_lazy_values() + self._proxy_or_render_other(other) + return super().__le__(other) + + def __contains__(self, item): + self._proxy_or_render_lazy_values() + return super().__contains__(item) + + def __reversed__(self): + for idx in range(self.__len__() - 1, -1, -1): + yield self[idx] + + def __add__(self, other): + if self._is_not_lazy_combine_candidate(other): + # When other is lazy with a different templar/options, it cannot be lazily combined with self and a plain list must be returned. + # If other is a list, de-lazify both, otherwise just let the operation fail. + + if isinstance(other, _AnsibleLazyTemplateList): + self._proxy_or_render_lazy_values() + other._proxy_or_render_lazy_values() + + return super().__add__(other) + + # For all other cases, the new list inherits our templar and all values stay lazy. + # We use list.__add__ to avoid implementing all its error behavior. 
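+        # Illustrative example: given a lazy list holding '{{ x }}', `lazy + ['literal']`
+        # returns a new lazy list sharing this templar; '{{ x }}' still renders only on access.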
+        return _AnsibleLazyTemplateList(_LazyValueSource(source=super().__add__(other), templar=self._templar, lazy_options=self._lazy_options))
+
+    def __radd__(self, other):
+        if not (other_add := getattr(other, '__add__', None)):
+            raise TypeError(f'unsupported operand type(s) for +: {type(other).__name__!r} and {type(self).__name__!r}') from None
+
+        return _AnsibleLazyTemplateList(_LazyValueSource(source=other_add(self), templar=self._templar, lazy_options=self._lazy_options))
+
+    def __mul__(self, other):
+        return _AnsibleLazyTemplateList(_LazyValueSource(source=super().__mul__(other), templar=self._templar, lazy_options=self._lazy_options))
+
+    def __rmul__(self, other):
+        return _AnsibleLazyTemplateList(_LazyValueSource(source=super().__rmul__(other), templar=self._templar, lazy_options=self._lazy_options))
+
+    def index(self, *args, **kwargs) -> int:
+        self._proxy_or_render_lazy_values()
+        return super().index(*args, **kwargs)
+
+    def remove(self, *args, **kwargs) -> None:
+        self._proxy_or_render_lazy_values()
+        super().remove(*args, **kwargs)
+
+    def sort(self, *args, **kwargs) -> None:
+        self._proxy_or_render_lazy_values()
+        super().sort(*args, **kwargs)
+
+    def __deepcopy__(self, memo):
+        return _AnsibleLazyTemplateList(
+            _LazyValueSource(
+                source=(copy.deepcopy(v) for v in super().__iter__()),
+                templar=copy.deepcopy(self._templar),
+                lazy_options=copy.deepcopy(self._lazy_options),
+            )
+        )
+
+
+@t.final  # consumers of lazy collections rely heavily on the concrete types being final
+class _AnsibleLazyAccessTuple(_AnsibleTaggedTuple, _AnsibleLazyTemplateMixin):
+    """
+    A tagged tuple subclass that provides only managed access for existing lazy values.
+
+    Since tuples are immutable, they cannot support lazy templating (which would change the tuple's value as templates were resolved).
+    When this type is created, each value in the source tuple is lazified:
+
+    * template strings are templated immediately (possibly resulting in lazy containers)
+    * non-tuple containers are lazy-wrapped
+    * tuples are immediately recursively lazy-wrapped
+    * transformations are applied immediately
+
+    The resulting object provides only managed access to its values (e.g., deprecation warnings, tripwires), and propagates to new lazy containers
+    created as a result of managed access.
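+    For example (illustrative): a tuple variable `('{{ x }}', ['{{ y }}'])` has its string rendered immediately and its inner list lazy-wrapped, per the rules above.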
+ """ + + # DTFIX-RELEASE: ensure we have tests that explicitly verify this behavior + + # nonempty __slots__ not supported for subtype of 'tuple' + + def __new__(cls, contents: t.Iterable | _LazyValueSource, /) -> t.Self: + if isinstance(contents, _AnsibleLazyAccessTuple): + return super().__new__(cls, tuple.__iter__(contents)) + + if isinstance(contents, _LazyValueSource): + return super().__new__(cls, contents.source) + + raise UnsupportedConstructionMethodError() + + def __init__(self, contents: t.Iterable | _LazyValueSource, /) -> None: + _AnsibleLazyTemplateMixin.__init__(self, contents) + + def __getitem__(self, key: t.SupportsIndex | slice, /) -> t.Any: + if type(key) is slice: # pylint: disable=unidiomatic-typecheck + return _AnsibleLazyAccessTuple(super().__getitem__(key)) + + value = super().__getitem__(key) + + if self._lazy_options.access: + AnsibleAccessContext.current().access(value) + + return value + + @staticmethod + def _item_source(value: tuple) -> tuple | _LazyValueSource: + if isinstance(value, _AnsibleLazyAccessTuple): + return _LazyValueSource(source=tuple.__iter__(value), templar=value._templar, lazy_options=value._lazy_options) + + return value + + @staticmethod + def _lazy_values(values: t.Any, lazy_options: LazyOptions) -> _LazyValueSource: + templar = TemplateContext.current().templar + + return _LazyValueSource(source=(templar.template(value, lazy_options=lazy_options) for value in values), templar=templar, lazy_options=lazy_options) + + def _non_lazy_copy(self) -> tuple: + return AnsibleTagHelper.tag_copy(self, self, value_type=tuple) + + def __deepcopy__(self, memo): + return _AnsibleLazyAccessTuple( + _LazyValueSource( + source=(copy.deepcopy(v) for v in super().__iter__()), + templar=copy.deepcopy(self._templar), + lazy_options=copy.deepcopy(self._lazy_options), + ) + ) + + +def lazify_container(value: t.Any) -> t.Any: + """ + If the given value is a supported container type, return its lazy version, otherwise return the value as-is. + This is used to ensure that managed access and templating occur on args and kwargs to a callable, even if they were sourced from Jinja constants. + + Since both variable access and plugin output are already lazified, this mostly affects Jinja constant containers. + However, plugins that directly invoke other plugins (e.g., `Environment.call_filter`) are another potential source of non-lazy containers. + In these cases, templating will occur for trusted templates automatically upon access. + + Sets, tuples, and dictionary keys cannot be lazy, since their correct operation requires hashability and equality. + These properties are mutually exclusive with the following lazy features: + + - managed access on encrypted strings - may raise errors on both operations when decryption fails + - managed access on markers - must raise errors on both operations + - templating - mutates values + + That leaves non-raising managed access as the only remaining feature, which is insufficient to warrant lazy support. 
+ """ + return _AnsibleLazyTemplateMixin._try_create(value) + + +def lazify_container_args(item: tuple) -> tuple: + """Return the given args with values converted to lazy containers as needed.""" + return tuple(lazify_container(value) for value in item) + + +def lazify_container_kwargs(item: dict[str, t.Any]) -> dict[str, t.Any]: + """Return the given kwargs with values converted to lazy containers as needed.""" + return {key: lazify_container(value) for key, value in item.items()} diff --git a/lib/ansible/_internal/_templating/_marker_behaviors.py b/lib/ansible/_internal/_templating/_marker_behaviors.py new file mode 100644 index 00000000000..71df1a6e1f4 --- /dev/null +++ b/lib/ansible/_internal/_templating/_marker_behaviors.py @@ -0,0 +1,103 @@ +"""Handling of `Marker` values.""" + +from __future__ import annotations + +import abc +import contextlib +import dataclasses +import itertools +import typing as t + +from ansible.utils.display import Display + +from ._jinja_common import Marker + + +class MarkerBehavior(metaclass=abc.ABCMeta): + """Base class to support custom handling of `Marker` values encountered during concatenation or finalization.""" + + @abc.abstractmethod + def handle_marker(self, value: Marker) -> t.Any: + """Handle the given `Marker` value.""" + + +class FailingMarkerBehavior(MarkerBehavior): + """ + The default behavior when encountering a `Marker` value during concatenation or finalization. + This always raises the template-internal `MarkerError` exception. + """ + + def handle_marker(self, value: Marker) -> t.Any: + value.trip() + + +# FAIL_ON_MARKER_BEHAVIOR +# _DETONATE_MARKER_BEHAVIOR - internal singleton since it's the default and nobody should need to reference it, or make it an actual singleton +FAIL_ON_UNDEFINED: t.Final = FailingMarkerBehavior() # no sense in making many instances... + + +@dataclasses.dataclass(kw_only=True, slots=True, frozen=True) +class _MarkerTracker: + """A numbered occurrence of a `Marker` value for later conversion to a warning.""" + + number: int + value: Marker + + +class ReplacingMarkerBehavior(MarkerBehavior): + """All `Marker` values are replaced with a numbered string placeholder and the message from the value.""" + + def __init__(self) -> None: + self._trackers: list[_MarkerTracker] = [] + + def record_marker(self, value: Marker) -> t.Any: + """Assign a sequence number to the given value and record it for later generation of warnings.""" + number = len(self._trackers) + 1 + + self._trackers.append(_MarkerTracker(number=number, value=value)) + + return number + + def emit_warnings(self) -> None: + """Emit warning messages caused by Marker values, aggregated by unique template.""" + + display = Display() + grouped_templates = itertools.groupby(self._trackers, key=lambda tracker: tracker.value._marker_template_source) + + for template, items in grouped_templates: + item_list = list(items) + + msg = f'Encountered {len(item_list)} template error{"s" if len(item_list) > 1 else ""}.' 
+ + for item in item_list: + msg += f'\nerror {item.number} - {item.value._as_message()}' + + display.warning(msg=msg, obj=template) + + @classmethod + @contextlib.contextmanager + def warning_context(cls) -> t.Generator[t.Self, None, None]: + """Collect warnings for `Marker` values and emit warnings when the context exits.""" + instance = cls() + + try: + yield instance + finally: + instance.emit_warnings() + + def handle_marker(self, value: Marker) -> t.Any: + number = self.record_marker(value) + + return f"<< error {number} - {value._as_message()} >>" + + +class RoutingMarkerBehavior(MarkerBehavior): + """Routes instances of Marker (by type reference) to another MarkerBehavior, defaulting to FailingMarkerBehavior.""" + + def __init__(self, dispatch_table: dict[type[Marker], MarkerBehavior]) -> None: + self._dispatch_table = dispatch_table + + def handle_marker(self, value: Marker) -> t.Any: + behavior = self._dispatch_table.get(type(value), FAIL_ON_UNDEFINED) + + return behavior.handle_marker(value) diff --git a/lib/ansible/_internal/_templating/_transform.py b/lib/ansible/_internal/_templating/_transform.py new file mode 100644 index 00000000000..346c646a131 --- /dev/null +++ b/lib/ansible/_internal/_templating/_transform.py @@ -0,0 +1,63 @@ +"""Runtime projections to provide template/var-visible views of objects that are not natively allowed in Ansible's type system.""" + +from __future__ import annotations + +import dataclasses +import typing as t + +from ansible.module_utils._internal import _traceback +from ansible.module_utils.common.messages import PluginInfo, ErrorSummary, WarningSummary, DeprecationSummary +from ansible.parsing.vault import EncryptedString, VaultHelper +from ansible.utils.display import Display + +from ._jinja_common import VaultExceptionMarker +from .._errors import _captured, _utils + +display = Display() + + +def plugin_info(value: PluginInfo) -> dict[str, str]: + """Render PluginInfo as a dictionary.""" + return dataclasses.asdict(value) + + +def error_summary(value: ErrorSummary) -> str: + """Render ErrorSummary as a formatted traceback for backward-compatibility with pre-2.19 TaskResult.exception.""" + return value.formatted_traceback or '(traceback unavailable)' + + +def warning_summary(value: WarningSummary) -> str: + """Render WarningSummary as a simple message string for backward-compatibility with pre-2.19 TaskResult.warnings.""" + return value._format() + + +def deprecation_summary(value: DeprecationSummary) -> dict[str, t.Any]: + """Render DeprecationSummary as dict values for backward-compatibility with pre-2.19 TaskResult.deprecations.""" + # DTFIX-RELEASE: reconsider which deprecation fields should be exposed here, taking into account that collection_name is to be deprecated + result = value._as_simple_dict() + result.pop('details') + + return result + + +def encrypted_string(value: EncryptedString) -> str | VaultExceptionMarker: + """Decrypt an encrypted string and return its value, or a VaultExceptionMarker if decryption fails.""" + try: + return value._decrypt() + except Exception as ex: + return VaultExceptionMarker( + ciphertext=VaultHelper.get_ciphertext(value, with_tags=True), + reason=_utils.get_chained_message(ex), + traceback=_traceback.maybe_extract_traceback(ex, _traceback.TracebackEvent.ERROR), + ) + + +_type_transform_mapping: dict[type, t.Callable[[t.Any], t.Any]] = { + _captured.CapturedErrorSummary: error_summary, + PluginInfo: plugin_info, + ErrorSummary: error_summary, + WarningSummary: warning_summary, + DeprecationSummary: 
deprecation_summary, + EncryptedString: encrypted_string, +} +"""This mapping is consulted by `Templar.template` to provide custom views of some objects.""" diff --git a/lib/ansible/_internal/_templating/_utils.py b/lib/ansible/_internal/_templating/_utils.py new file mode 100644 index 00000000000..1f77075dae7 --- /dev/null +++ b/lib/ansible/_internal/_templating/_utils.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import dataclasses +import typing as t + +from ansible.module_utils._internal import _ambient_context, _datatag + +if t.TYPE_CHECKING: + from ._engine import TemplateEngine, TemplateOptions + + +@dataclasses.dataclass(kw_only=True, slots=True, frozen=True) +class LazyOptions: + """Templating options that apply to lazy containers, which are inherited by descendent lazy containers.""" + + DEFAULT: t.ClassVar[t.Self] + """A shared instance with the default options to minimize instance creation for arg defaults.""" + SKIP_TEMPLATES: t.ClassVar[t.Self] + """A shared instance with only `template=False` set to minimize instance creation for arg defaults.""" + SKIP_TEMPLATES_AND_ACCESS: t.ClassVar[t.Self] + """A shared instance with both `template=False` and `access=False` set to minimize instance creation for arg defaults.""" + + template: bool = True + """Enable/disable templating.""" + + access: bool = True + """Enable/disables access calls.""" + + unmask_type_names: frozenset[str] = frozenset() + """Disables template transformations for the provided type names.""" + + +LazyOptions.DEFAULT = LazyOptions() +LazyOptions.SKIP_TEMPLATES = LazyOptions(template=False) +LazyOptions.SKIP_TEMPLATES_AND_ACCESS = LazyOptions(template=False, access=False) + + +class TemplateContext(_ambient_context.AmbientContextBase): + def __init__( + self, + *, + template_value: t.Any, + templar: TemplateEngine, + options: TemplateOptions, + stop_on_template: bool = False, + _render_jinja_const_template: bool = False, + ): + self._template_value = template_value + self._templar = templar + self._options = options + self._stop_on_template = stop_on_template + self._parent_ctx = TemplateContext.current(optional=True) + self._render_jinja_const_template = _render_jinja_const_template + + @property + def is_top_level(self) -> bool: + return not self._parent_ctx + + @property + def template_value(self) -> t.Any: + return self._template_value + + @property + def templar(self) -> TemplateEngine: + return self._templar + + @property + def options(self) -> TemplateOptions: + return self._options + + @property + def stop_on_template(self) -> bool: + return self._stop_on_template + + +class _OmitType: + """ + A placeholder singleton used to dynamically omit items from a dict/list/tuple/set when the value is `Omit`. + + The `Omit` singleton is accessible from all Ansible templating contexts via the Jinja global name `omit`. + The `Omit` placeholder value will be visible to Jinja plugins during templating. + Jinja plugins requiring omit behavior are responsible for handling encountered `Omit` values. + `Omit` values remaining in template results will be automatically dropped during template finalization. + When a finalized template renders to a scalar `Omit`, `AnsibleValueOmittedError` will be raised. + Passing a value other than `Omit` for `value_for_omit` to the `template` call allows that value to be substituted instead of raising. 
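+    For example (illustrative): `engine.template(value, value_for_omit=None)` substitutes `None` rather than raising when the result is `Omit`.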
+ """ + + __slots__ = () + + def __new__(cls): + return Omit + + def __repr__(self): + return "<>" + + +Omit = object.__new__(_OmitType) + +_datatag._untaggable_types.add(_OmitType) + + +# DTFIX-RELEASE: review these type sets to ensure they're not overly permissive/dynamic +IGNORE_SCALAR_VAR_TYPES = {value for value in _datatag._ANSIBLE_ALLOWED_SCALAR_VAR_TYPES if not issubclass(value, str)} + +PASS_THROUGH_SCALAR_VAR_TYPES = _datatag._ANSIBLE_ALLOWED_SCALAR_VAR_TYPES | { + _OmitType, # allow pass through of omit for later handling after top-level finalize completes +} diff --git a/lib/ansible/_internal/_wrapt.py b/lib/ansible/_internal/_wrapt.py new file mode 100644 index 00000000000..d493baaa717 --- /dev/null +++ b/lib/ansible/_internal/_wrapt.py @@ -0,0 +1,1052 @@ +# Copyright (c) 2013-2023, Graham Dumpleton +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +# copied from https://github.com/GrahamDumpleton/wrapt/blob/1.15.0/src/wrapt/wrappers.py + +# LOCAL PATCHES: +# - disabled optional relative import of the _wrappers C extension; we shouldn't need it + +from __future__ import annotations + +# The following makes it easier for us to script updates of the bundled code +_BUNDLED_METADATA = {"pypi_name": "wrapt", "version": "1.15.0"} + +import os +import sys +import functools +import operator +import weakref +import inspect + +PY2 = sys.version_info[0] == 2 + +if PY2: + string_types = basestring, +else: + string_types = str, + +def with_metaclass(meta, *bases): + """Create a base class with a metaclass.""" + return meta("NewBase", bases, {}) + +class _ObjectProxyMethods(object): + + # We use properties to override the values of __module__ and + # __doc__. If we add these in ObjectProxy, the derived class + # __dict__ will still be setup to have string variants of these + # attributes and the rules of descriptors means that they appear to + # take precedence over the properties in the base class. To avoid + # that, we copy the properties into the derived class type itself + # via a meta class. In that way the properties will always take + # precedence. 
+ + @property + def __module__(self): + return self.__wrapped__.__module__ + + @__module__.setter + def __module__(self, value): + self.__wrapped__.__module__ = value + + @property + def __doc__(self): + return self.__wrapped__.__doc__ + + @__doc__.setter + def __doc__(self, value): + self.__wrapped__.__doc__ = value + + # We similar use a property for __dict__. We need __dict__ to be + # explicit to ensure that vars() works as expected. + + @property + def __dict__(self): + return self.__wrapped__.__dict__ + + # Need to also propagate the special __weakref__ attribute for case + # where decorating classes which will define this. If do not define + # it and use a function like inspect.getmembers() on a decorator + # class it will fail. This can't be in the derived classes. + + @property + def __weakref__(self): + return self.__wrapped__.__weakref__ + +class _ObjectProxyMetaType(type): + def __new__(cls, name, bases, dictionary): + # Copy our special properties into the class so that they + # always take precedence over attributes of the same name added + # during construction of a derived class. This is to save + # duplicating the implementation for them in all derived classes. + + dictionary.update(vars(_ObjectProxyMethods)) + + return type.__new__(cls, name, bases, dictionary) + +class ObjectProxy(with_metaclass(_ObjectProxyMetaType)): + + __slots__ = '__wrapped__' + + def __init__(self, wrapped): + object.__setattr__(self, '__wrapped__', wrapped) + + # Python 3.2+ has the __qualname__ attribute, but it does not + # allow it to be overridden using a property and it must instead + # be an actual string object instead. + + try: + object.__setattr__(self, '__qualname__', wrapped.__qualname__) + except AttributeError: + pass + + # Python 3.10 onwards also does not allow itself to be overridden + # using a property and it must instead be set explicitly. 
+ + try: + object.__setattr__(self, '__annotations__', wrapped.__annotations__) + except AttributeError: + pass + + @property + def __name__(self): + return self.__wrapped__.__name__ + + @__name__.setter + def __name__(self, value): + self.__wrapped__.__name__ = value + + @property + def __class__(self): + return self.__wrapped__.__class__ + + @__class__.setter + def __class__(self, value): + self.__wrapped__.__class__ = value + + def __dir__(self): + return dir(self.__wrapped__) + + def __str__(self): + return str(self.__wrapped__) + + if not PY2: + def __bytes__(self): + return bytes(self.__wrapped__) + + def __repr__(self): + return '<{} at 0x{:x} for {} at 0x{:x}>'.format( + type(self).__name__, id(self), + type(self.__wrapped__).__name__, + id(self.__wrapped__)) + + def __reversed__(self): + return reversed(self.__wrapped__) + + if not PY2: + def __round__(self): + return round(self.__wrapped__) + + if sys.hexversion >= 0x03070000: + def __mro_entries__(self, bases): + return (self.__wrapped__,) + + def __lt__(self, other): + return self.__wrapped__ < other + + def __le__(self, other): + return self.__wrapped__ <= other + + def __eq__(self, other): + return self.__wrapped__ == other + + def __ne__(self, other): + return self.__wrapped__ != other + + def __gt__(self, other): + return self.__wrapped__ > other + + def __ge__(self, other): + return self.__wrapped__ >= other + + def __hash__(self): + return hash(self.__wrapped__) + + def __nonzero__(self): + return bool(self.__wrapped__) + + def __bool__(self): + return bool(self.__wrapped__) + + def __setattr__(self, name, value): + if name.startswith('_self_'): + object.__setattr__(self, name, value) + + elif name == '__wrapped__': + object.__setattr__(self, name, value) + try: + object.__delattr__(self, '__qualname__') + except AttributeError: + pass + try: + object.__setattr__(self, '__qualname__', value.__qualname__) + except AttributeError: + pass + try: + object.__delattr__(self, '__annotations__') + except AttributeError: + pass + try: + object.__setattr__(self, '__annotations__', value.__annotations__) + except AttributeError: + pass + + elif name == '__qualname__': + setattr(self.__wrapped__, name, value) + object.__setattr__(self, name, value) + + elif name == '__annotations__': + setattr(self.__wrapped__, name, value) + object.__setattr__(self, name, value) + + elif hasattr(type(self), name): + object.__setattr__(self, name, value) + + else: + setattr(self.__wrapped__, name, value) + + def __getattr__(self, name): + # If we are being to lookup '__wrapped__' then the + # '__init__()' method cannot have been called. 
+ + if name == '__wrapped__': + raise ValueError('wrapper has not been initialised') + + return getattr(self.__wrapped__, name) + + def __delattr__(self, name): + if name.startswith('_self_'): + object.__delattr__(self, name) + + elif name == '__wrapped__': + raise TypeError('__wrapped__ must be an object') + + elif name == '__qualname__': + object.__delattr__(self, name) + delattr(self.__wrapped__, name) + + elif hasattr(type(self), name): + object.__delattr__(self, name) + + else: + delattr(self.__wrapped__, name) + + def __add__(self, other): + return self.__wrapped__ + other + + def __sub__(self, other): + return self.__wrapped__ - other + + def __mul__(self, other): + return self.__wrapped__ * other + + def __div__(self, other): + return operator.div(self.__wrapped__, other) + + def __truediv__(self, other): + return operator.truediv(self.__wrapped__, other) + + def __floordiv__(self, other): + return self.__wrapped__ // other + + def __mod__(self, other): + return self.__wrapped__ % other + + def __divmod__(self, other): + return divmod(self.__wrapped__, other) + + def __pow__(self, other, *args): + return pow(self.__wrapped__, other, *args) + + def __lshift__(self, other): + return self.__wrapped__ << other + + def __rshift__(self, other): + return self.__wrapped__ >> other + + def __and__(self, other): + return self.__wrapped__ & other + + def __xor__(self, other): + return self.__wrapped__ ^ other + + def __or__(self, other): + return self.__wrapped__ | other + + def __radd__(self, other): + return other + self.__wrapped__ + + def __rsub__(self, other): + return other - self.__wrapped__ + + def __rmul__(self, other): + return other * self.__wrapped__ + + def __rdiv__(self, other): + return operator.div(other, self.__wrapped__) + + def __rtruediv__(self, other): + return operator.truediv(other, self.__wrapped__) + + def __rfloordiv__(self, other): + return other // self.__wrapped__ + + def __rmod__(self, other): + return other % self.__wrapped__ + + def __rdivmod__(self, other): + return divmod(other, self.__wrapped__) + + def __rpow__(self, other, *args): + return pow(other, self.__wrapped__, *args) + + def __rlshift__(self, other): + return other << self.__wrapped__ + + def __rrshift__(self, other): + return other >> self.__wrapped__ + + def __rand__(self, other): + return other & self.__wrapped__ + + def __rxor__(self, other): + return other ^ self.__wrapped__ + + def __ror__(self, other): + return other | self.__wrapped__ + + def __iadd__(self, other): + self.__wrapped__ += other + return self + + def __isub__(self, other): + self.__wrapped__ -= other + return self + + def __imul__(self, other): + self.__wrapped__ *= other + return self + + def __idiv__(self, other): + self.__wrapped__ = operator.idiv(self.__wrapped__, other) + return self + + def __itruediv__(self, other): + self.__wrapped__ = operator.itruediv(self.__wrapped__, other) + return self + + def __ifloordiv__(self, other): + self.__wrapped__ //= other + return self + + def __imod__(self, other): + self.__wrapped__ %= other + return self + + def __ipow__(self, other): + self.__wrapped__ **= other + return self + + def __ilshift__(self, other): + self.__wrapped__ <<= other + return self + + def __irshift__(self, other): + self.__wrapped__ >>= other + return self + + def __iand__(self, other): + self.__wrapped__ &= other + return self + + def __ixor__(self, other): + self.__wrapped__ ^= other + return self + + def __ior__(self, other): + self.__wrapped__ |= other + return self + + def __neg__(self): + return 
-self.__wrapped__ + + def __pos__(self): + return +self.__wrapped__ + + def __abs__(self): + return abs(self.__wrapped__) + + def __invert__(self): + return ~self.__wrapped__ + + def __int__(self): + return int(self.__wrapped__) + + def __long__(self): + return long(self.__wrapped__) + + def __float__(self): + return float(self.__wrapped__) + + def __complex__(self): + return complex(self.__wrapped__) + + def __oct__(self): + return oct(self.__wrapped__) + + def __hex__(self): + return hex(self.__wrapped__) + + def __index__(self): + return operator.index(self.__wrapped__) + + def __len__(self): + return len(self.__wrapped__) + + def __contains__(self, value): + return value in self.__wrapped__ + + def __getitem__(self, key): + return self.__wrapped__[key] + + def __setitem__(self, key, value): + self.__wrapped__[key] = value + + def __delitem__(self, key): + del self.__wrapped__[key] + + def __getslice__(self, i, j): + return self.__wrapped__[i:j] + + def __setslice__(self, i, j, value): + self.__wrapped__[i:j] = value + + def __delslice__(self, i, j): + del self.__wrapped__[i:j] + + def __enter__(self): + return self.__wrapped__.__enter__() + + def __exit__(self, *args, **kwargs): + return self.__wrapped__.__exit__(*args, **kwargs) + + def __iter__(self): + return iter(self.__wrapped__) + + def __copy__(self): + raise NotImplementedError('object proxy must define __copy__()') + + def __deepcopy__(self, memo): + raise NotImplementedError('object proxy must define __deepcopy__()') + + def __reduce__(self): + raise NotImplementedError( + 'object proxy must define __reduce_ex__()') + + def __reduce_ex__(self, protocol): + raise NotImplementedError( + 'object proxy must define __reduce_ex__()') + +class CallableObjectProxy(ObjectProxy): + + def __call__(*args, **kwargs): + def _unpack_self(self, *args): + return self, args + + self, args = _unpack_self(*args) + + return self.__wrapped__(*args, **kwargs) + +class PartialCallableObjectProxy(ObjectProxy): + + def __init__(*args, **kwargs): + def _unpack_self(self, *args): + return self, args + + self, args = _unpack_self(*args) + + if len(args) < 1: + raise TypeError('partial type takes at least one argument') + + wrapped, args = args[0], args[1:] + + if not callable(wrapped): + raise TypeError('the first argument must be callable') + + super(PartialCallableObjectProxy, self).__init__(wrapped) + + self._self_args = args + self._self_kwargs = kwargs + + def __call__(*args, **kwargs): + def _unpack_self(self, *args): + return self, args + + self, args = _unpack_self(*args) + + _args = self._self_args + args + + _kwargs = dict(self._self_kwargs) + _kwargs.update(kwargs) + + return self.__wrapped__(*_args, **_kwargs) + +class _FunctionWrapperBase(ObjectProxy): + + __slots__ = ('_self_instance', '_self_wrapper', '_self_enabled', + '_self_binding', '_self_parent') + + def __init__(self, wrapped, instance, wrapper, enabled=None, + binding='function', parent=None): + + super(_FunctionWrapperBase, self).__init__(wrapped) + + object.__setattr__(self, '_self_instance', instance) + object.__setattr__(self, '_self_wrapper', wrapper) + object.__setattr__(self, '_self_enabled', enabled) + object.__setattr__(self, '_self_binding', binding) + object.__setattr__(self, '_self_parent', parent) + + def __get__(self, instance, owner): + # This method is actually doing double duty for both unbound and + # bound derived wrapper classes. It should possibly be broken up + # and the distinct functionality moved into the derived classes. 
+ # Can't do that straight away due to some legacy code which is + # relying on it being here in this base class. + # + # The distinguishing attribute which determines whether we are + # being called in an unbound or bound wrapper is the parent + # attribute. If binding has never occurred, then the parent will + # be None. + # + # First therefore, is if we are called in an unbound wrapper. In + # this case we perform the binding. + # + # We have one special case to worry about here. This is where we + # are decorating a nested class. In this case the wrapped class + # would not have a __get__() method to call. In that case we + # simply return self. + # + # Note that we otherwise still do binding even if instance is + # None and accessing an unbound instance method from a class. + # This is because we need to be able to later detect that + # specific case as we will need to extract the instance from the + # first argument of those passed in. + + if self._self_parent is None: + if not inspect.isclass(self.__wrapped__): + descriptor = self.__wrapped__.__get__(instance, owner) + + return self.__bound_function_wrapper__(descriptor, instance, + self._self_wrapper, self._self_enabled, + self._self_binding, self) + + return self + + # Now we have the case of binding occurring a second time on what + # was already a bound function. In this case we would usually + # return ourselves again. This mirrors what Python does. + # + # The special case this time is where we were originally bound + # with an instance of None and we were likely an instance + # method. In that case we rebind against the original wrapped + # function from the parent again. + + if self._self_instance is None and self._self_binding == 'function': + descriptor = self._self_parent.__wrapped__.__get__( + instance, owner) + + return self._self_parent.__bound_function_wrapper__( + descriptor, instance, self._self_wrapper, + self._self_enabled, self._self_binding, + self._self_parent) + + return self + + def __call__(*args, **kwargs): + def _unpack_self(self, *args): + return self, args + + self, args = _unpack_self(*args) + + # If enabled has been specified, then evaluate it at this point + # and if the wrapper is not to be executed, then simply return + # the bound function rather than a bound wrapper for the bound + # function. When evaluating enabled, if it is callable we call + # it, otherwise we evaluate it as a boolean. + + if self._self_enabled is not None: + if callable(self._self_enabled): + if not self._self_enabled(): + return self.__wrapped__(*args, **kwargs) + elif not self._self_enabled: + return self.__wrapped__(*args, **kwargs) + + # This can occur where initial function wrapper was applied to + # a function that was already bound to an instance. In that case + # we want to extract the instance from the function and use it. + + if self._self_binding in ('function', 'classmethod'): + if self._self_instance is None: + instance = getattr(self.__wrapped__, '__self__', None) + if instance is not None: + return self._self_wrapper(self.__wrapped__, instance, + args, kwargs) + + # This is generally invoked when the wrapped function is being + # called as a normal function and is not bound to a class as an + # instance method. This is also invoked in the case where the + # wrapped function was a method, but this wrapper was in turn + # wrapped using the staticmethod decorator. 
+ + return self._self_wrapper(self.__wrapped__, self._self_instance, + args, kwargs) + + def __set_name__(self, owner, name): + # This is a special method use to supply information to + # descriptors about what the name of variable in a class + # definition is. Not wanting to add this to ObjectProxy as not + # sure of broader implications of doing that. Thus restrict to + # FunctionWrapper used by decorators. + + if hasattr(self.__wrapped__, "__set_name__"): + self.__wrapped__.__set_name__(owner, name) + + def __instancecheck__(self, instance): + # This is a special method used by isinstance() to make checks + # instance of the `__wrapped__`. + return isinstance(instance, self.__wrapped__) + + def __subclasscheck__(self, subclass): + # This is a special method used by issubclass() to make checks + # about inheritance of classes. We need to upwrap any object + # proxy. Not wanting to add this to ObjectProxy as not sure of + # broader implications of doing that. Thus restrict to + # FunctionWrapper used by decorators. + + if hasattr(subclass, "__wrapped__"): + return issubclass(subclass.__wrapped__, self.__wrapped__) + else: + return issubclass(subclass, self.__wrapped__) + +class BoundFunctionWrapper(_FunctionWrapperBase): + + def __call__(*args, **kwargs): + def _unpack_self(self, *args): + return self, args + + self, args = _unpack_self(*args) + + # If enabled has been specified, then evaluate it at this point + # and if the wrapper is not to be executed, then simply return + # the bound function rather than a bound wrapper for the bound + # function. When evaluating enabled, if it is callable we call + # it, otherwise we evaluate it as a boolean. + + if self._self_enabled is not None: + if callable(self._self_enabled): + if not self._self_enabled(): + return self.__wrapped__(*args, **kwargs) + elif not self._self_enabled: + return self.__wrapped__(*args, **kwargs) + + # We need to do things different depending on whether we are + # likely wrapping an instance method vs a static method or class + # method. + + if self._self_binding == 'function': + if self._self_instance is None: + # This situation can occur where someone is calling the + # instancemethod via the class type and passing the instance + # as the first argument. We need to shift the args before + # making the call to the wrapper and effectively bind the + # instance to the wrapped function using a partial so the + # wrapper doesn't see anything as being different. + + if not args: + raise TypeError('missing 1 required positional argument') + + instance, args = args[0], args[1:] + wrapped = PartialCallableObjectProxy(self.__wrapped__, instance) + return self._self_wrapper(wrapped, instance, args, kwargs) + + return self._self_wrapper(self.__wrapped__, self._self_instance, + args, kwargs) + + else: + # As in this case we would be dealing with a classmethod or + # staticmethod, then _self_instance will only tell us whether + # when calling the classmethod or staticmethod they did it via an + # instance of the class it is bound to and not the case where + # done by the class type itself. We thus ignore _self_instance + # and use the __self__ attribute of the bound function instead. + # For a classmethod, this means instance will be the class type + # and for a staticmethod it will be None. This is probably the + # more useful thing we can pass through even though we loose + # knowledge of whether they were called on the instance vs the + # class type, as it reflects what they have available in the + # decoratored function. 
+ + instance = getattr(self.__wrapped__, '__self__', None) + + return self._self_wrapper(self.__wrapped__, instance, args, + kwargs) + +class FunctionWrapper(_FunctionWrapperBase): + + __bound_function_wrapper__ = BoundFunctionWrapper + + def __init__(self, wrapped, wrapper, enabled=None): + # What it is we are wrapping here could be anything. We need to + # try and detect specific cases though. In particular, we need + # to detect when we are given something that is a method of a + # class. Further, we need to know when it is likely an instance + # method, as opposed to a class or static method. This can + # become problematic though as there isn't strictly a fool proof + # method of knowing. + # + # The situations we could encounter when wrapping a method are: + # + # 1. The wrapper is being applied as part of a decorator which + # is a part of the class definition. In this case what we are + # given is the raw unbound function, classmethod or staticmethod + # wrapper objects. + # + # The problem here is that we will not know we are being applied + # in the context of the class being set up. This becomes + # important later for the case of an instance method, because in + # that case we just see it as a raw function and can't + # distinguish it from wrapping a normal function outside of + # a class context. + # + # 2. The wrapper is being applied when performing monkey + # patching of the class type afterwards and the method to be + # wrapped was retrieved direct from the __dict__ of the class + # type. This is effectively the same as (1) above. + # + # 3. The wrapper is being applied when performing monkey + # patching of the class type afterwards and the method to be + # wrapped was retrieved from the class type. In this case + # binding will have been performed where the instance against + # which the method is bound will be None at that point. + # + # This case is a problem because we can no longer tell if the + # method was a static method, plus if using Python3, we cannot + # tell if it was an instance method as the concept of an + # unnbound method no longer exists. + # + # 4. The wrapper is being applied when performing monkey + # patching of an instance of a class. In this case binding will + # have been perfomed where the instance was not None. + # + # This case is a problem because we can no longer tell if the + # method was a static method. + # + # Overall, the best we can do is look at the original type of the + # object which was wrapped prior to any binding being done and + # see if it is an instance of classmethod or staticmethod. In + # the case where other decorators are between us and them, if + # they do not propagate the __class__ attribute so that the + # isinstance() checks works, then likely this will do the wrong + # thing where classmethod and staticmethod are used. + # + # Since it is likely to be very rare that anyone even puts + # decorators around classmethod and staticmethod, likelihood of + # that being an issue is very small, so we accept it and suggest + # that those other decorators be fixed. It is also only an issue + # if a decorator wants to actually do things with the arguments. + # + # As to not being able to identify static methods properly, we + # just hope that that isn't something people are going to want + # to wrap, or if they do suggest they do it the correct way by + # ensuring that it is decorated in the class definition itself, + # or patch it in the __dict__ of the class type. 
+ # + # So to get the best outcome we can, whenever we aren't sure what + # it is, we label it as a 'function'. If it was already bound and + # that is rebound later, we assume that it will be an instance + # method and try an cope with the possibility that the 'self' + # argument it being passed as an explicit argument and shuffle + # the arguments around to extract 'self' for use as the instance. + + if isinstance(wrapped, classmethod): + binding = 'classmethod' + + elif isinstance(wrapped, staticmethod): + binding = 'staticmethod' + + elif hasattr(wrapped, '__self__'): + if inspect.isclass(wrapped.__self__): + binding = 'classmethod' + else: + binding = 'function' + + else: + binding = 'function' + + super(FunctionWrapper, self).__init__(wrapped, None, wrapper, + enabled, binding) + +# disabled support for native extension; we likely don't need it +# try: +# if not os.environ.get('WRAPT_DISABLE_EXTENSIONS'): +# from ._wrappers import (ObjectProxy, CallableObjectProxy, +# PartialCallableObjectProxy, FunctionWrapper, +# BoundFunctionWrapper, _FunctionWrapperBase) +# except ImportError: +# pass + +# Helper functions for applying wrappers to existing functions. + +def resolve_path(module, name): + if isinstance(module, string_types): + __import__(module) + module = sys.modules[module] + + parent = module + + path = name.split('.') + attribute = path[0] + + # We can't just always use getattr() because in doing + # that on a class it will cause binding to occur which + # will complicate things later and cause some things not + # to work. For the case of a class we therefore access + # the __dict__ directly. To cope though with the wrong + # class being given to us, or a method being moved into + # a base class, we need to walk the class hierarchy to + # work out exactly which __dict__ the method was defined + # in, as accessing it from __dict__ will fail if it was + # not actually on the class given. Fallback to using + # getattr() if we can't find it. If it truly doesn't + # exist, then that will fail. + + def lookup_attribute(parent, attribute): + if inspect.isclass(parent): + for cls in inspect.getmro(parent): + if attribute in vars(cls): + return vars(cls)[attribute] + else: + return getattr(parent, attribute) + else: + return getattr(parent, attribute) + + original = lookup_attribute(parent, attribute) + + for attribute in path[1:]: + parent = original + original = lookup_attribute(parent, attribute) + + return (parent, attribute, original) + +def apply_patch(parent, attribute, replacement): + setattr(parent, attribute, replacement) + +def wrap_object(module, name, factory, args=(), kwargs={}): + (parent, attribute, original) = resolve_path(module, name) + wrapper = factory(original, *args, **kwargs) + apply_patch(parent, attribute, wrapper) + return wrapper + +# Function for applying a proxy object to an attribute of a class +# instance. The wrapper works by defining an attribute of the same name +# on the class which is a descriptor and which intercepts access to the +# instance attribute. Note that this cannot be used on attributes which +# are themselves defined by a property object. 
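+# For example (illustrative names only): wrap_object_attribute('pkg.mod', 'SomeClass.attr', factory) replaces SomeClass.attr with an AttributeWrapper descriptor, so reading instance.attr returns factory(value) while writes and deletes fall through to the instance __dict__.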
+ +class AttributeWrapper(object): + + def __init__(self, attribute, factory, args, kwargs): + self.attribute = attribute + self.factory = factory + self.args = args + self.kwargs = kwargs + + def __get__(self, instance, owner): + value = instance.__dict__[self.attribute] + return self.factory(value, *self.args, **self.kwargs) + + def __set__(self, instance, value): + instance.__dict__[self.attribute] = value + + def __delete__(self, instance): + del instance.__dict__[self.attribute] + +def wrap_object_attribute(module, name, factory, args=(), kwargs={}): + path, attribute = name.rsplit('.', 1) + parent = resolve_path(module, path)[2] + wrapper = AttributeWrapper(attribute, factory, args, kwargs) + apply_patch(parent, attribute, wrapper) + return wrapper + +# Functions for creating a simple decorator using a FunctionWrapper, +# plus short cut functions for applying wrappers to functions. These are +# for use when doing monkey patching. For a more featured way of +# creating decorators see the decorator decorator instead. + +def function_wrapper(wrapper): + def _wrapper(wrapped, instance, args, kwargs): + target_wrapped = args[0] + if instance is None: + target_wrapper = wrapper + elif inspect.isclass(instance): + target_wrapper = wrapper.__get__(None, instance) + else: + target_wrapper = wrapper.__get__(instance, type(instance)) + return FunctionWrapper(target_wrapped, target_wrapper) + return FunctionWrapper(wrapper, _wrapper) + +def wrap_function_wrapper(module, name, wrapper): + return wrap_object(module, name, FunctionWrapper, (wrapper,)) + +def patch_function_wrapper(module, name): + def _wrapper(wrapper): + return wrap_object(module, name, FunctionWrapper, (wrapper,)) + return _wrapper + +def transient_function_wrapper(module, name): + def _decorator(wrapper): + def _wrapper(wrapped, instance, args, kwargs): + target_wrapped = args[0] + if instance is None: + target_wrapper = wrapper + elif inspect.isclass(instance): + target_wrapper = wrapper.__get__(None, instance) + else: + target_wrapper = wrapper.__get__(instance, type(instance)) + def _execute(wrapped, instance, args, kwargs): + (parent, attribute, original) = resolve_path(module, name) + replacement = FunctionWrapper(original, target_wrapper) + setattr(parent, attribute, replacement) + try: + return wrapped(*args, **kwargs) + finally: + setattr(parent, attribute, original) + return FunctionWrapper(target_wrapped, _execute) + return FunctionWrapper(wrapper, _wrapper) + return _decorator + +# A weak function proxy. This will work on instance methods, class +# methods, static methods and regular functions. Special treatment is +# needed for the method types because the bound method is effectively a +# transient object and applying a weak reference to one will immediately +# result in it being destroyed and the weakref callback called. The weak +# reference is therefore applied to the instance the method is bound to +# and the original function. The function is then rebound at the point +# of a call via the weak function proxy. + +def _weak_function_proxy_callback(ref, proxy, callback): + if proxy._self_expired: + return + + proxy._self_expired = True + + # This could raise an exception. We let it propagate back and let + # the weakref.proxy() deal with it, at which point it generally + # prints out a short error message direct to stderr and keeps going. 
+ + if callback is not None: + callback(proxy) + +class WeakFunctionProxy(ObjectProxy): + + __slots__ = ('_self_expired', '_self_instance') + + def __init__(self, wrapped, callback=None): + # We need to determine if the wrapped function is actually a + # bound method. In the case of a bound method, we need to keep a + # reference to the original unbound function and the instance. + # This is necessary because if we hold a reference to the bound + # function, it will be the only reference and given it is a + # temporary object, it will almost immediately expire and + # the weakref callback triggered. So what is done is that we + # hold a reference to the instance and unbound function and + # when called bind the function to the instance once again and + # then call it. Note that we avoid using a nested function for + # the callback here so as not to cause any odd reference cycles. + + _callback = callback and functools.partial( + _weak_function_proxy_callback, proxy=self, + callback=callback) + + self._self_expired = False + + if isinstance(wrapped, _FunctionWrapperBase): + self._self_instance = weakref.ref(wrapped._self_instance, + _callback) + + if wrapped._self_parent is not None: + super(WeakFunctionProxy, self).__init__( + weakref.proxy(wrapped._self_parent, _callback)) + + else: + super(WeakFunctionProxy, self).__init__( + weakref.proxy(wrapped, _callback)) + + return + + try: + self._self_instance = weakref.ref(wrapped.__self__, _callback) + + super(WeakFunctionProxy, self).__init__( + weakref.proxy(wrapped.__func__, _callback)) + + except AttributeError: + self._self_instance = None + + super(WeakFunctionProxy, self).__init__( + weakref.proxy(wrapped, _callback)) + + def __call__(*args, **kwargs): + def _unpack_self(self, *args): + return self, args + + self, args = _unpack_self(*args) + + # We perform a boolean check here on the instance and wrapped + # function as that will trigger the reference error prior to + # calling if the reference had expired. + + instance = self._self_instance and self._self_instance() + function = self.__wrapped__ and self.__wrapped__ + + # If the wrapped function was originally a bound function, for + # which we retained a reference to the instance and the unbound + # function we need to rebind the function and then call it. If + # not just called the wrapped function. 
+ + if instance is None: + return self.__wrapped__(*args, **kwargs) + + return function.__get__(instance, type(instance))(*args, **kwargs) \ No newline at end of file diff --git a/lib/ansible/_internal/_yaml/__init__.py b/lib/ansible/_internal/_yaml/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/lib/ansible/_internal/_yaml/_constructor.py b/lib/ansible/_internal/_yaml/_constructor.py new file mode 100644 index 00000000000..dd72d37de32 --- /dev/null +++ b/lib/ansible/_internal/_yaml/_constructor.py @@ -0,0 +1,240 @@ +from __future__ import annotations + +import abc +import copy +import typing as t + +from yaml import Node +from yaml.constructor import SafeConstructor +from yaml.resolver import BaseResolver + +from ansible import constants as C +from ansible.module_utils.common.text.converters import to_text +from ansible.module_utils._internal._datatag import AnsibleTagHelper +from ansible._internal._datatag._tags import Origin, TrustedAsTemplate +from ansible.parsing.vault import EncryptedString +from ansible.utils.display import Display + +from ._errors import AnsibleConstructorError + +display = Display() + +_TRUSTED_AS_TEMPLATE: t.Final[TrustedAsTemplate] = TrustedAsTemplate() + + +class _BaseConstructor(SafeConstructor, metaclass=abc.ABCMeta): + """Base class for Ansible YAML constructors.""" + + @classmethod + @abc.abstractmethod + def _register_constructors(cls) -> None: + """Method used to register constructors to derived types during class initialization.""" + + def __init_subclass__(cls, **kwargs) -> None: + """Initialization for derived types.""" + cls._register_constructors() + + +class AnsibleInstrumentedConstructor(_BaseConstructor): + """Ansible constructor which supports Ansible custom behavior such as `Origin` tagging, but no Ansible-specific YAML tags.""" + + name: t.Any # provided by the YAML parser, which retrieves it from the stream + + def __init__(self, origin: Origin, trusted_as_template: bool) -> None: + if not origin.line_num: + origin = origin.replace(line_num=1) + + self._origin = origin + self._trusted_as_template = trusted_as_template + self._duplicate_key_mode = C.config.get_config_value('DUPLICATE_YAML_DICT_KEY') + + super().__init__() + + @property + def trusted_as_template(self) -> bool: + return self._trusted_as_template + + def construct_yaml_map(self, node): + data = self._node_position_info(node).tag({}) # always an ordered dictionary on py3.7+ + yield data + value = self.construct_mapping(node) + data.update(value) + + def construct_mapping(self, node, deep=False): + # Delegate to built-in implementation to construct the mapping. + # This is done before checking for duplicates to leverage existing error checking on the input node. + mapping = super().construct_mapping(node, deep) + keys = set() + + # Now that the node is known to be a valid mapping, handle any duplicate keys. + for key_node, _value_node in node.value: + if (key := self.construct_object(key_node, deep=deep)) in keys: + msg = f'Found duplicate mapping key {key!r}.' 
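+ # The DUPLICATE_YAML_DICT_KEY config value selects the behavior below: 'error' fails the parse, 'warn' reports the duplicate, and any other mode accepts it silently; in the non-error modes the last value wins (standard PyYAML behavior).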
+ + if self._duplicate_key_mode == 'error': + raise AnsibleConstructorError(problem=msg, problem_mark=key_node.start_mark) + + if self._duplicate_key_mode == 'warn': + display.warning(msg=msg, obj=key, help_text='Using last defined value only.') + + keys.add(key) + + return mapping + + def construct_yaml_int(self, node): + value = super().construct_yaml_int(node) + return self._node_position_info(node).tag(value) + + def construct_yaml_float(self, node): + value = super().construct_yaml_float(node) + return self._node_position_info(node).tag(value) + + def construct_yaml_timestamp(self, node): + value = super().construct_yaml_timestamp(node) + return self._node_position_info(node).tag(value) + + def construct_yaml_omap(self, node): + origin = self._node_position_info(node) + display.deprecated( + msg='Use of the YAML `!!omap` tag is deprecated.', + version='2.23', + obj=origin, + help_text='Use a standard mapping instead, as key order is always preserved.', + ) + items = list(super().construct_yaml_omap(node))[0] + items = [origin.tag(item) for item in items] + yield origin.tag(items) + + def construct_yaml_pairs(self, node): + origin = self._node_position_info(node) + display.deprecated( + msg='Use of the YAML `!!pairs` tag is deprecated.', + version='2.23', + obj=origin, + help_text='Use a standard mapping instead.', + ) + items = list(super().construct_yaml_pairs(node))[0] + items = [origin.tag(item) for item in items] + yield origin.tag(items) + + def construct_yaml_str(self, node): + # Override the default string handling function + # to always return unicode objects + # DTFIX-FUTURE: is this to_text conversion still necessary under Py3? + value = to_text(self.construct_scalar(node)) + + tags = [self._node_position_info(node)] + + if self.trusted_as_template: + # NB: since we're not context aware, this will happily add trust to dictionary keys; this is actually necessary for + # certain backward compat scenarios, though might be accomplished in other ways if we wanted to avoid trusting keys in + # the general scenario + tags.append(_TRUSTED_AS_TEMPLATE) + + return AnsibleTagHelper.tag(value, tags) + + def construct_yaml_binary(self, node): + value = super().construct_yaml_binary(node) + + return AnsibleTagHelper.tag(value, self._node_position_info(node)) + + def construct_yaml_set(self, node): + data = AnsibleTagHelper.tag(set(), self._node_position_info(node)) + yield data + value = self.construct_mapping(node) + data.update(value) + + def construct_yaml_seq(self, node): + data = self._node_position_info(node).tag([]) + yield data + data.extend(self.construct_sequence(node)) + + def _resolve_and_construct_object(self, node): + # use a copied node to avoid mutating existing node and tripping the recursion check in construct_object + copied_node = copy.copy(node) + # repeat implicit resolution process to determine the proper tag for the value in the unsafe node + copied_node.tag = t.cast(BaseResolver, self).resolve(type(node), node.value, (True, False)) + + # re-entrant call using the correct tag + # non-deferred construction of hierarchical nodes so the result is a fully realized object, and so our stateful unsafe propagation behavior works + return self.construct_object(copied_node, deep=True) + + def _node_position_info(self, node) -> Origin: + # the line number where the previous token has ended (plus empty lines) + # Add one so that the first line is line 1 rather than line 0 + return self._origin.replace(line_num=node.start_mark.line + self._origin.line_num, 
col_num=node.start_mark.column + 1) + + @classmethod + def _register_constructors(cls) -> None: + constructors: dict[str, t.Callable] = { + 'tag:yaml.org,2002:binary': cls.construct_yaml_binary, + 'tag:yaml.org,2002:float': cls.construct_yaml_float, + 'tag:yaml.org,2002:int': cls.construct_yaml_int, + 'tag:yaml.org,2002:map': cls.construct_yaml_map, + 'tag:yaml.org,2002:omap': cls.construct_yaml_omap, + 'tag:yaml.org,2002:pairs': cls.construct_yaml_pairs, + 'tag:yaml.org,2002:python/dict': cls.construct_yaml_map, + 'tag:yaml.org,2002:python/unicode': cls.construct_yaml_str, + 'tag:yaml.org,2002:seq': cls.construct_yaml_seq, + 'tag:yaml.org,2002:set': cls.construct_yaml_set, + 'tag:yaml.org,2002:str': cls.construct_yaml_str, + 'tag:yaml.org,2002:timestamp': cls.construct_yaml_timestamp, + } + + for tag, constructor in constructors.items(): + cls.add_constructor(tag, constructor) + + +class AnsibleConstructor(AnsibleInstrumentedConstructor): + """Ansible constructor which supports Ansible custom behavior such as `Origin` tagging, as well as Ansible-specific YAML tags.""" + + def __init__(self, origin: Origin, trusted_as_template: bool) -> None: + self._unsafe_depth = 0 # volatile state var used during recursive construction of a value tagged unsafe + + super().__init__(origin=origin, trusted_as_template=trusted_as_template) + + @property + def trusted_as_template(self) -> bool: + return self._trusted_as_template and not self._unsafe_depth + + def construct_yaml_unsafe(self, node): + self._unsafe_depth += 1 + + try: + return self._resolve_and_construct_object(node) + finally: + self._unsafe_depth -= 1 + + def construct_yaml_vault(self, node: Node) -> EncryptedString: + ciphertext = self._resolve_and_construct_object(node) + + if not isinstance(ciphertext, str): + raise AnsibleConstructorError(problem=f"the {node.tag!r} tag requires a string value", problem_mark=node.start_mark) + + encrypted_string = AnsibleTagHelper.tag_copy(ciphertext, EncryptedString(ciphertext=AnsibleTagHelper.untag(ciphertext))) + + return encrypted_string + + def construct_yaml_vault_encrypted(self, node: Node) -> EncryptedString: + origin = self._node_position_info(node) + display.deprecated( + msg='Use of the YAML `!vault-encrypted` tag is deprecated.', + version='2.23', + obj=origin, + help_text='Use the `!vault` tag instead.', + ) + + return self.construct_yaml_vault(node) + + @classmethod + def _register_constructors(cls) -> None: + super()._register_constructors() + + constructors: dict[str, t.Callable] = { + '!unsafe': cls.construct_yaml_unsafe, + '!vault': cls.construct_yaml_vault, + '!vault-encrypted': cls.construct_yaml_vault_encrypted, + } + + for tag, constructor in constructors.items(): + cls.add_constructor(tag, constructor) diff --git a/lib/ansible/_internal/_yaml/_dumper.py b/lib/ansible/_internal/_yaml/_dumper.py new file mode 100644 index 00000000000..dc54ae8ee3a --- /dev/null +++ b/lib/ansible/_internal/_yaml/_dumper.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import abc +import collections.abc as c +import typing as t + +from yaml.representer import SafeRepresenter + +from ansible.module_utils._internal._datatag import AnsibleTaggedObject, Tripwire, AnsibleTagHelper +from ansible.parsing.vault import VaultHelper +from ansible.module_utils.common.yaml import HAS_LIBYAML + +if HAS_LIBYAML: + from yaml.cyaml import CSafeDumper as SafeDumper +else: + from yaml import SafeDumper # type: ignore[assignment] + + +class _BaseDumper(SafeDumper, metaclass=abc.ABCMeta): + """Base class for Ansible 
YAML dumpers.""" + + @classmethod + @abc.abstractmethod + def _register_representers(cls) -> None: + """Method used to register representers to derived types during class initialization.""" + + def __init_subclass__(cls, **kwargs) -> None: + """Initialization for derived types.""" + cls._register_representers() + + +class AnsibleDumper(_BaseDumper): + """A simple stub class that allows us to add representers for our custom types.""" + + # DTFIX-RELEASE: need a better way to handle serialization controls during YAML dumping + def __init__(self, *args, dump_vault_tags: bool | None = None, **kwargs): + super().__init__(*args, **kwargs) + + self._dump_vault_tags = dump_vault_tags + + @classmethod + def _register_representers(cls) -> None: + cls.add_multi_representer(AnsibleTaggedObject, cls.represent_ansible_tagged_object) + cls.add_multi_representer(Tripwire, cls.represent_tripwire) + cls.add_multi_representer(c.Mapping, SafeRepresenter.represent_dict) + cls.add_multi_representer(c.Sequence, SafeRepresenter.represent_list) + + def represent_ansible_tagged_object(self, data): + if self._dump_vault_tags is not False and (ciphertext := VaultHelper.get_ciphertext(data, with_tags=False)): + # deprecated: description='enable the deprecation warning below' core_version='2.23' + # if self._dump_vault_tags is None: + # Display().deprecated( + # msg="Implicit YAML dumping of vaulted value ciphertext is deprecated. Set `dump_vault_tags` to explicitly specify the desired behavior", + # version="2.27", + # ) + + return self.represent_scalar('!vault', ciphertext, style='|') + + return self.represent_data(AnsibleTagHelper.as_native_type(data)) # automatically decrypts encrypted strings + + def represent_tripwire(self, data: Tripwire) -> t.NoReturn: + data.trip() diff --git a/lib/ansible/_internal/_yaml/_errors.py b/lib/ansible/_internal/_yaml/_errors.py new file mode 100644 index 00000000000..75acdb7a30c --- /dev/null +++ b/lib/ansible/_internal/_yaml/_errors.py @@ -0,0 +1,166 @@ +from __future__ import annotations + +import re + +import typing as t + +from yaml import MarkedYAMLError +from yaml.constructor import ConstructorError + +from ansible._internal._errors import _utils +from ansible.errors import AnsibleParserError +from ansible._internal._datatag._tags import Origin + + +class AnsibleConstructorError(ConstructorError): + """Ansible-specific ConstructorError used to bypass exception analysis during wrapping in AnsibleYAMLParserError.""" + + +class AnsibleYAMLParserError(AnsibleParserError): + """YAML-specific parsing failure wrapping an exception raised by the YAML parser.""" + + _default_message = 'YAML parsing failed.' + + _include_cause_message = False # hide the underlying cause message, it's included by `handle_exception` as needed + + _formatted_source_context_value: str | None = None + + @property + def _formatted_source_context(self) -> str | None: + return self._formatted_source_context_value + + @classmethod + def handle_exception(cls, exception: Exception, origin: Origin) -> t.NoReturn: + if isinstance(exception, MarkedYAMLError): + origin = origin.replace(line_num=exception.problem_mark.line + 1, col_num=exception.problem_mark.column + 1) + + source_context = _utils.SourceContext.from_origin(origin) + + target_line = source_context.target_line or '' # for these cases, we don't need to distinguish between None and empty string + + message: str | None = None + help_text = None + + # FIXME: Do all this by walking the parsed YAML doc stream. 
Using regexes is a dead-end; YAML's just too flexible to not have a + # raft of false-positives and corner cases. If we directly consume either the YAML parse stream or override the YAML composer, we can + # better catch these things without worrying about duplicating YAML's scalar parsing logic around quoting/escaping. At first, we can + # replace the regex logic below with tiny special-purpose parse consumers to catch specific issues, but ideally, we could do a lot of this + # inline with the actual doc parse, since our rules are a lot more strict than YAML's (e.g., no support for non-scalar keys), and a lot of the + # problem cases where that comes into play are around expression quoting and Jinja {{ syntax looking like weird YAML values we don't support. + # Some common examples, where -> is "what YAML actually sees": + # foo: {{ bar }} -> {"foo": {{"bar": None}: None}} - a mapping with a mapping as its key (legal YAML, but not legal Python/Ansible) + # + # - copy: src=foo.txt # kv syntax (kv could be on following line(s), too- implicit multi-line block scalar) + # dest: bar.txt # orphaned mapping, since the value of `copy` is the scalar "src=foo.txt" + # + # - msg == "Error: 'dude' was not found" # unquoted scalar has a : in it -> {'msg == "Error"': 'dude'} [ was not found" ] is garbage orphan scalar + + # noinspection PyUnboundLocalVariable + if not isinstance(exception, MarkedYAMLError): + pass # unexpected exception, don't use special analysis of exception + + elif isinstance(exception, AnsibleConstructorError): + pass # raised internally by ansible code, don't use special analysis of exception + + # Check for tabs. + # There may be cases where there is a valid tab in a line that has other errors. + # That's OK, users should "fix" their tab usage anyway -- at which point later error handling logic will hopefully find the real issue. + elif (tab_idx := target_line.find('\t')) >= 0: + source_context = _utils.SourceContext.from_origin(origin.replace(col_num=tab_idx + 1)) + message = "Tabs are usually invalid in YAML." + + # Check for unquoted templates. + elif match := re.search(r'^\s*(?:-\s+)*(?:[\w\s]+:\s+)?(?P<value>\{\{.*}})', target_line): + source_context = _utils.SourceContext.from_origin(origin.replace(col_num=match.start('value') + 1)) + message = 'This may be an issue with missing quotes around a template block.' + # FIXME: Use the captured value to show the actual fix required. + help_text = """ +For example: + + raw: {{ some_var }} + +Should be: + + raw: "{{ some_var }}" +""" + + # Check for common unquoted colon mistakes. + elif ( + # ignore lines starting with only whitespace and a colon + not target_line.lstrip().startswith(':') + # find the value after list/dict preamble + and (value_match := re.search(r'^\s*(?:-\s+)*(?:[\w\s\[\]{}]+:\s+)?(?P<value>.*)$', target_line)) + # ignore properly quoted values + and (target_fragment := _replace_quoted_value(value_match.group('value'))) + # look for an unquoted colon in the value + and (colon_match := re.search(r':($| )', target_fragment)) + ): + source_context = _utils.SourceContext.from_origin(origin.replace(col_num=value_match.start('value') + colon_match.start() + 1)) + message = 'Colons in unquoted values must be followed by a non-space character.' + # FIXME: Use the captured value to show the actual fix required. + help_text = """ +For example: + + raw: echo 'name: ansible' + +Should be: + + raw: "echo 'name: ansible'" +""" + + # Check for common quoting mistakes.
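+ # For example: raw: "foo" in bar -- YAML requires a scalar that opens with a quote to end at the matching close quote, so the trailing text is a parse error; the checks below turn that into actionable guidance.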
+ elif match := re.search(r'^\s*(?:-\s+)*(?:[\w\s]+:\s+)?(?P<value>[\"\'].*?\s*)$', target_line): + suspected_value = match.group('value') + first, last = suspected_value[0], suspected_value[-1] + + if first != last: # "foo" in bar + source_context = _utils.SourceContext.from_origin(origin.replace(col_num=match.start('value') + 1)) + message = 'Values starting with a quote must end with the same quote.' + # FIXME: Use the captured value to show the actual fix required, and use that same logic to improve the origin further. + help_text = """ +For example: + + raw: "foo" in bar + +Should be: + + raw: '"foo" in bar' +""" + elif first == last and target_line.count(first) > 2: # "foo" and "bar" + source_context = _utils.SourceContext.from_origin(origin.replace(col_num=match.start('value') + 1)) + message = 'Values starting with a quote must end with the same quote, and not contain that quote.' + # FIXME: Use the captured value to show the actual fix required, and use that same logic to improve the origin further. + help_text = """ +For example: + + raw: "foo" in "bar" + +Should be: + + raw: '"foo" in "bar"' +""" + + if not message: + if isinstance(exception, MarkedYAMLError): + # marked YAML error, pull out the useful messages while omitting the noise + message = ' '.join(filter(None, (exception.context, exception.problem, exception.note))) + message = message.strip() + message = f'{message[0].upper()}{message[1:]}' + + if not message.endswith('.'): + message += '.' + else: + # unexpected error, use the exception message (normally hidden by overriding include_cause_message) + message = str(exception) + + message = re.sub(r'\s+', ' ', message).strip() + + error = cls(message, obj=source_context.origin) + error._formatted_source_context_value = str(source_context) + error._help_text = help_text + + raise error from exception + + +def _replace_quoted_value(value: str, replacement='.') -> str: + return re.sub(r"""^\s*('[^']*'|"[^"]*")\s*$""", lambda match: replacement * len(match.group(0)), value) diff --git a/lib/ansible/_internal/_yaml/_loader.py b/lib/ansible/_internal/_yaml/_loader.py new file mode 100644 index 00000000000..fa14006c0f8 --- /dev/null +++ b/lib/ansible/_internal/_yaml/_loader.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +import io as _io + +from yaml.resolver import Resolver + +from ansible.module_utils._internal._datatag import AnsibleTagHelper +from ansible.module_utils.common.yaml import HAS_LIBYAML +from ansible._internal._datatag import _tags + +from ._constructor import AnsibleConstructor, AnsibleInstrumentedConstructor + +if HAS_LIBYAML: + from yaml.cyaml import CParser + + class _YamlParser(CParser): + def __init__(self, stream: str | bytes | _io.IOBase) -> None: + if isinstance(stream, (str, bytes)): + stream = AnsibleTagHelper.untag(stream) # PyYAML + libyaml barfs on str/bytes subclasses + + CParser.__init__(self, stream) + + self.name = getattr(stream, 'name', None) # provide feature parity with the Python implementation (yaml.reader.Reader provides name) + +else: + from yaml.composer import Composer + from yaml.reader import Reader + from yaml.scanner import Scanner + from yaml.parser import Parser + + class _YamlParser(Reader, Scanner, Parser, Composer): # type: ignore[no-redef] + def __init__(self, stream: str | bytes | _io.IOBase) -> None: + Reader.__init__(self, stream) + Scanner.__init__(self) + Parser.__init__(self) + Composer.__init__(self) + + +class AnsibleInstrumentedLoader(_YamlParser, AnsibleInstrumentedConstructor, Resolver): + """Ansible YAML
loader which supports Ansible custom behavior such as `Origin` tagging, but no Ansible-specific YAML tags.""" + + def __init__(self, stream: str | bytes | _io.IOBase) -> None: + _YamlParser.__init__(self, stream) + + AnsibleInstrumentedConstructor.__init__( + self, + origin=_tags.Origin.get_or_create_tag(stream, self.name), + trusted_as_template=_tags.TrustedAsTemplate.is_tagged_on(stream), + ) + + Resolver.__init__(self) + + +class AnsibleLoader(_YamlParser, AnsibleConstructor, Resolver): + """Ansible loader which supports Ansible custom behavior such as `Origin` tagging, as well as Ansible-specific YAML tags.""" + + def __init__(self, stream: str | bytes | _io.IOBase) -> None: + _YamlParser.__init__(self, stream) + + AnsibleConstructor.__init__( + self, + origin=_tags.Origin.get_or_create_tag(stream, self.name), + trusted_as_template=_tags.TrustedAsTemplate.is_tagged_on(stream), + ) + + Resolver.__init__(self) diff --git a/lib/ansible/_internal/ansible_collections/ansible/_protomatter/README.md b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/README.md new file mode 100644 index 00000000000..9ec03246d23 --- /dev/null +++ b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/README.md @@ -0,0 +1,11 @@ +"Protomatter - an unstable substance which every ethical scientist in the galaxy has denounced as dangerously unpredictable." + +"But it was the only way to solve certain problems..." + +This Ansible Collection is embedded within ansible-core. +It contains plugins useful for ansible-core's own integration tests. +They have been made available, completely unsupported, +in case they prove useful for debugging and troubleshooting purposes. + +> CAUTION: This collection is not supported, and may be changed or removed in any version without prior notice. +Use of these plugins outside ansible-core is highly discouraged. 
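A minimal usage sketch of the two loaders defined in _loader.py above (not part of the diff; it assumes the internal import path added by this change, internal APIs may change without notice, and the `!unsafe` document is illustrative input):

    import yaml

    from ansible._internal._yaml._loader import AnsibleInstrumentedLoader, AnsibleLoader

    doc = "greeting: !unsafe '{{ not_a_template }}'"

    # AnsibleLoader registers the Ansible-specific tags, so '!unsafe' is handled by
    # construct_yaml_unsafe and the resulting string is never trusted as a template.
    data = yaml.load(doc, Loader=AnsibleLoader)

    # AnsibleInstrumentedLoader registers no Ansible-specific tags, so the same call
    # with Loader=AnsibleInstrumentedLoader fails with a ConstructorError on '!unsafe'.

Both loaders Origin-tag the values they construct, so the string under data['greeting'] carries the line and column (plus the stream name, when one is available) from which it was parsed.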
diff --git a/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/action/debug.py b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/action/debug.py new file mode 100644 index 00000000000..60d7c64ec9c --- /dev/null +++ b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/action/debug.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import typing as t + +from ansible.module_utils.common.validation import _check_type_str_no_conversion, _check_type_list_strict +from ansible.plugins.action import ActionBase +from ansible._internal._templating._engine import TemplateEngine +from ansible._internal._templating._marker_behaviors import ReplacingMarkerBehavior + + +class ActionModule(ActionBase): + TRANSFERS_FILES = False + _requires_connection = False + + @classmethod + def finalize_task_arg(cls, name: str, value: t.Any, templar: TemplateEngine, context: t.Any) -> t.Any: + if name == 'expression': + return value + + return super().finalize_task_arg(name, value, templar, context) + + def run(self, tmp=None, task_vars=None): + # accepts a list of literal expressions (no templating), evaluates with no failure on undefined, returns all results + _vr, args = self.validate_argument_spec( + argument_spec=dict( + expression=dict(type=_check_type_list_strict, elements=_check_type_str_no_conversion, required=True), + ), + ) + + with ReplacingMarkerBehavior.warning_context() as replacing_behavior: + templar = self._templar._engine.extend(marker_behavior=replacing_behavior) + + return dict( + _ansible_verbose_always=True, + expression_result=[templar.evaluate_expression(expression) for expression in args['expression']], + ) diff --git a/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/apply_trust.py b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/apply_trust.py new file mode 100644 index 00000000000..22f8aa43c94 --- /dev/null +++ b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/apply_trust.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +import typing as t + +from ansible._internal._datatag._tags import TrustedAsTemplate + + +def apply_trust(value: object) -> object: + """ + Filter that returns a tagged copy of the input string with TrustedAsTemplate. + Containers and other non-string values are returned unmodified. 
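+ Trust marks a string as eligible for templating when the engine encounters it; tagging does not itself evaluate the template.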
+ """ + return TrustedAsTemplate().tag(value) if isinstance(value, str) else value + + +class FilterModule: + @staticmethod + def filters() -> dict[str, t.Callable]: + return dict(apply_trust=apply_trust) diff --git a/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/dump_object.py b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/dump_object.py new file mode 100644 index 00000000000..9b8a88427c2 --- /dev/null +++ b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/dump_object.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +import dataclasses +import typing as t + + +def dump_object(value: t.Any) -> object: + """Internal filter to convert objects not supported by JSON to types which are.""" + if dataclasses.is_dataclass(value): + return dataclasses.asdict(value) # type: ignore[arg-type] + + return value + + +class FilterModule(object): + @staticmethod + def filters() -> dict[str, t.Callable]: + return dict(dump_object=dump_object) diff --git a/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/finalize.py b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/finalize.py new file mode 100644 index 00000000000..88f847fb9c8 --- /dev/null +++ b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/finalize.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +import typing as t + +from ansible._internal._templating._engine import _finalize_template_result, FinalizeMode + + +def finalize(value: t.Any) -> t.Any: + """Perform an explicit top-level template finalize operation on the supplied value.""" + return _finalize_template_result(value, mode=FinalizeMode.TOP_LEVEL) + + +class FilterModule: + @staticmethod + def filters() -> dict[str, t.Callable]: + return dict(finalize=finalize) diff --git a/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/origin.py b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/origin.py new file mode 100644 index 00000000000..528bb96c626 --- /dev/null +++ b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/origin.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +import typing as t + +from ansible._internal._datatag._tags import Origin + + +def origin(value: object) -> str | None: + """Return the origin of the value, if any, otherwise `None`.""" + origin_tag = Origin.get_tag(value) + + return str(origin_tag) if origin_tag else None + + +class FilterModule: + @staticmethod + def filters() -> dict[str, t.Callable]: + return dict(origin=origin) diff --git a/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/python_literal_eval.py b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/python_literal_eval.py new file mode 100644 index 00000000000..416c391e75c --- /dev/null +++ b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/python_literal_eval.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +import ast + +from ansible.errors import AnsibleTypeError + + +def python_literal_eval(value: object, ignore_errors=False) -> object: + try: + if isinstance(value, str): + return ast.literal_eval(value) + + raise AnsibleTypeError("The `value` to eval must be a string.", obj=value) + except Exception: + if ignore_errors: + return value + + raise + + +class FilterModule(object): + @staticmethod + def filters(): + return 
dict(python_literal_eval=python_literal_eval) diff --git a/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/python_literal_eval.yml b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/python_literal_eval.yml new file mode 100644 index 00000000000..8d20b835c43 --- /dev/null +++ b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/python_literal_eval.yml @@ -0,0 +1,33 @@ +DOCUMENTATION: + name: python_literal_eval + version_added: "2.19" + short_description: evaluate a Python literal expression string + description: + - Evaluates the input string as a Python literal expression, returning the resulting data structure. + - Previous versions of Ansible applied this behavior to all template results in non-native Jinja mode. + - This filter provides a way to emulate the previous behavior. + notes: + - Directly calls Python's C(ast.literal_eval). + positional: _input + options: + _input: + description: Python literal string expression. + type: str + required: true + ignore_errors: + description: Whether to ignore all errors resulting from the literal_eval operation. If true, the input is silently returned unmodified when an error occurs. + type: bool + default: false + +EXAMPLES: | + - name: evaluate an expression composed only of Python literals + assert: + that: (another_var | ansible._protomatter.python_literal_eval)[1] == 2 # in 2.19 and later, the explicit python_literal_eval emulates the old templating behavior + vars: + another_var: "{{ some_var }}" # in 2.18 and earlier, indirection through templating caused implicit literal_eval, converting the value to a list + some_var: "[1, 2]" # a value that looks like a Python list literal embedded in a string + +RETURN: + _value: + description: Resulting data structure.
+ type: raw diff --git a/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/tag_names.py b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/tag_names.py new file mode 100644 index 00000000000..92525c8d332 --- /dev/null +++ b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/tag_names.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +import typing as t + +from ansible.module_utils._internal._datatag import AnsibleTagHelper + + +def tag_names(value: object) -> list[str]: + """Return a list of tag type names (if any) present on the given object.""" + return sorted(tag_type.__name__ for tag_type in AnsibleTagHelper.tag_types(value)) + + +class FilterModule: + @staticmethod + def filters() -> dict[str, t.Callable]: + return dict(tag_names=tag_names) diff --git a/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/true_type.py b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/true_type.py new file mode 100644 index 00000000000..a07a4d1ddd9 --- /dev/null +++ b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/true_type.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +import typing as t + +from ansible.plugins import accept_args_markers + + +@accept_args_markers +def true_type(obj: object) -> str: + """Internal filter to show the true type name of the given object, not just the base type name like the `debug` filter.""" + return obj.__class__.__name__ + + +class FilterModule(object): + @staticmethod + def filters() -> dict[str, t.Callable]: + return dict(true_type=true_type) diff --git a/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/unmask.py b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/unmask.py new file mode 100644 index 00000000000..8a07bc79393 --- /dev/null +++ b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/filter/unmask.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import copy +import dataclasses +import typing as t + +from ansible._internal._templating._jinja_common import validate_arg_type +from ansible._internal._templating._lazy_containers import _AnsibleLazyTemplateMixin +from ansible._internal._templating._transform import _type_transform_mapping +from ansible.errors import AnsibleError + + +def unmask(value: object, type_names: str | list[str]) -> object: + """ + Internal filter to suppress automatic type transformation in Jinja (e.g., WarningMessageDetail, DeprecationMessageDetail, ErrorDetail). + Lazy collection caching is in play - the first attempt to access a value in a given lazy container must be with unmasking in place, or the transformed value + will already be cached. 
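+ For example, to inspect raw WarningMessageDetail objects inside a lazy list, apply unmask to the list before accessing any element; accessing an element first caches its transformed representation instead.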
+ """ + validate_arg_type("type_names", type_names, (str, list)) + + if isinstance(type_names, str): + check_type_names = [type_names] + else: + check_type_names = type_names + + valid_type_names = {key.__name__ for key in _type_transform_mapping} + invalid_type_names = [type_name for type_name in check_type_names if type_name not in valid_type_names] + + if invalid_type_names: + raise AnsibleError(f'Unknown type name(s): {", ".join(invalid_type_names)}', obj=type_names) + + result: object + + if isinstance(value, _AnsibleLazyTemplateMixin): + result = copy.copy(value) + result._lazy_options = dataclasses.replace( + result._lazy_options, + unmask_type_names=result._lazy_options.unmask_type_names | frozenset(check_type_names), + ) + else: + result = value + + return result + + +class FilterModule(object): + @staticmethod + def filters() -> dict[str, t.Callable]: + return dict(unmask=unmask) diff --git a/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/lookup/config.py b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/lookup/config.py new file mode 100644 index 00000000000..c4229320963 --- /dev/null +++ b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/lookup/config.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from ansible.plugins.lookup import LookupBase + + +class LookupModule(LookupBase): + """Specialized config lookup that applies data transformations on values that config cannot.""" + + def run(self, terms, variables=None, **kwargs): + if not terms or not (config_name := terms[0]): + raise ValueError("config name is required") + + match config_name: + case 'DISPLAY_TRACEBACK': + # since config can't expand this yet, we need the post-processed version + from ansible.module_utils._internal._traceback import traceback_for + + return traceback_for() + # DTFIX-FUTURE: plumb through normal config fallback + case _: + raise ValueError(f"Unknown config name {config_name!r}.") diff --git a/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/lookup/config.yml b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/lookup/config.yml new file mode 100644 index 00000000000..5aa954617d2 --- /dev/null +++ b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/lookup/config.yml @@ -0,0 +1,2 @@ +DOCUMENTATION: + name: config diff --git a/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/test/tagged.py b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/test/tagged.py new file mode 100644 index 00000000000..a13b90d4c86 --- /dev/null +++ b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/test/tagged.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +import typing as t + +from ansible.module_utils._internal import _datatag + + +def tagged(value: t.Any) -> bool: + return bool(_datatag.AnsibleTagHelper.tag_types(value)) + + +class TestModule: + @staticmethod + def tests() -> dict[str, t.Callable]: + return dict(tagged=tagged) diff --git a/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/test/tagged.yml b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/test/tagged.yml new file mode 100644 index 00000000000..921c03a1513 --- /dev/null +++ b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/test/tagged.yml @@ -0,0 +1,19 @@ +DOCUMENTATION: + name: tagged + author: Ansible Core + version_added: "2.19" + short_description: does the value have a 
data tag + description: + - Check if the provided value has a data tag. + options: + _input: + description: A value. + type: raw + +EXAMPLES: | + is_data_tagged: "{{ my_variable is ansible._protomatter.tagged }}" + +RETURN: + _value: + description: Returns C(True) if the value has one or more data tags, otherwise C(False). + type: boolean diff --git a/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/test/tagged_with.py b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/test/tagged_with.py new file mode 100644 index 00000000000..ef59edcab7e --- /dev/null +++ b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/test/tagged_with.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +import typing as t + +from ansible.module_utils._internal import _datatag + + +def tagged_with(value: t.Any, tag_name: str) -> bool: + if tag_type := _datatag._known_tag_type_map.get(tag_name): + return tag_type.is_tagged_on(value) + + raise ValueError(f"Unknown tag name {tag_name!r}.") + + +class TestModule: + @staticmethod + def tests() -> dict[str, t.Callable]: + return dict(tagged_with=tagged_with) diff --git a/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/test/tagged_with.yml b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/test/tagged_with.yml new file mode 100644 index 00000000000..f455ae919a9 --- /dev/null +++ b/lib/ansible/_internal/ansible_collections/ansible/_protomatter/plugins/test/tagged_with.yml @@ -0,0 +1,19 @@ +DOCUMENTATION: + name: tagged_with + author: Ansible Core + version_added: "2.19" + short_description: does the value have the specified data tag + description: + - Check if the provided value has the specified data tag. + options: + _input: + description: A value. + type: raw + +EXAMPLES: | + is_data_tagged: "{{ my_variable is ansible._protomatter.tagged_with('Origin') }}" + +RETURN: + _value: + description: Returns C(True) if the value has the specified data tag, otherwise C(False). + type: boolean diff --git a/lib/ansible/cli/__init__.py b/lib/ansible/cli/__init__.py index 5076fd61acb..462393868e0 100644 --- a/lib/ansible/cli/__init__.py +++ b/lib/ansible/cli/__init__.py @@ -77,18 +77,6 @@ def initialize_locale(): initialize_locale() -from importlib.metadata import version -from ansible.module_utils.compat.version import LooseVersion - -# Used for determining if the system is running a new enough Jinja2 version -# and should only restrict on our documented minimum versions -jinja2_version = version('jinja2') -if jinja2_version < LooseVersion('3.1'): - raise SystemExit( - 'ERROR: Ansible requires Jinja2 3.1 or newer on the controller. 
' - 'Current version: %s' % jinja2_version - ) - import atexit import errno import getpass @@ -97,17 +85,22 @@ import traceback from abc import ABC, abstractmethod from pathlib import Path +from ansible import _internal # do not remove or defer; ensures controller-specific state is set early + +_internal.setup() + try: from ansible import constants as C from ansible.utils.display import Display display = Display() -except Exception as e: - print('ERROR: %s' % e, file=sys.stderr) +except Exception as ex: + print(f'ERROR: {ex}\n\n{"".join(traceback.format_exception(ex))}', file=sys.stderr) sys.exit(5) + from ansible import context from ansible.cli.arguments import option_helpers as opt_help -from ansible.errors import AnsibleError, AnsibleOptionsError, AnsibleParserError +from ansible.errors import AnsibleError, ExitCode from ansible.inventory.manager import InventoryManager from ansible.module_utils.six import string_types from ansible.module_utils.common.text.converters import to_bytes, to_text @@ -115,14 +108,13 @@ from ansible.module_utils.common.collections import is_sequence from ansible.module_utils.common.file import is_executable from ansible.module_utils.common.process import get_bin_path from ansible.parsing.dataloader import DataLoader -from ansible.parsing.vault import PromptVaultSecret, get_file_vault_secret +from ansible.parsing.vault import PromptVaultSecret, get_file_vault_secret, VaultSecretsContext from ansible.plugins.loader import add_all_plugin_dirs, init_plugin_loader from ansible.release import __version__ from ansible.utils._ssh_agent import SshAgentClient from ansible.utils.collection_loader import AnsibleCollectionConfig from ansible.utils.collection_loader._collection_finder import _get_collection_name_from_path from ansible.utils.path import unfrackpath -from ansible.utils.unsafe_proxy import to_unsafe_text from ansible.vars.manager import VariableManager try: @@ -226,6 +218,9 @@ class CLI(ABC): self.parser = None self.callback = callback + self.show_devel_warning() + + def show_devel_warning(self) -> None: if C.DEVEL_WARNING and __version__.endswith('dev0'): display.warning( 'You are running the development version of Ansible. 
You should only run Ansible from "devel" if ' @@ -297,7 +292,7 @@ class CLI(ABC): @staticmethod def setup_vault_secrets(loader, vault_ids, vault_password_files=None, ask_vault_pass=None, create_new_password=False, - auto_prompt=True): + auto_prompt=True, initialize_context=True): # list of tuples vault_secrets = [] @@ -394,15 +389,14 @@ class CLI(ABC): if last_exception and not found_vault_secret: raise last_exception + if initialize_context: + VaultSecretsContext.initialize(VaultSecretsContext(vault_secrets)) + return vault_secrets @staticmethod - def _get_secret(prompt): - - secret = getpass.getpass(prompt=prompt) - if secret: - secret = to_unsafe_text(secret) - return secret + def _get_secret(prompt: str) -> str: + return getpass.getpass(prompt=prompt) @staticmethod def ask_passwords(): @@ -411,7 +405,6 @@ class CLI(ABC): op = context.CLIARGS sshpass = None becomepass = None - become_prompt = '' become_prompt_method = "BECOME" if C.AGNOSTIC_BECOME_PROMPT else op['become_method'].upper() @@ -433,7 +426,7 @@ class CLI(ABC): except EOFError: pass - return (sshpass, becomepass) + return sshpass, becomepass def validate_conflicts(self, op, runas_opts=False, fork_opts=False): """ check for conflicting options """ @@ -680,10 +673,9 @@ class CLI(ABC): return hosts @staticmethod - def get_password_from_file(pwd_file): - + def get_password_from_file(pwd_file: str) -> str: b_pwd_file = to_bytes(pwd_file) - secret = None + if b_pwd_file == b'-': # ensure its read as bytes secret = sys.stdin.buffer.read() @@ -703,13 +695,13 @@ class CLI(ABC): stdout, stderr = p.communicate() if p.returncode != 0: - raise AnsibleError("The password script %s returned an error (rc=%s): %s" % (pwd_file, p.returncode, stderr)) + raise AnsibleError("The password script %s returned an error (rc=%s): %s" % (pwd_file, p.returncode, to_text(stderr))) secret = stdout else: try: - with open(b_pwd_file, "rb") as f: - secret = f.read().strip() + with open(b_pwd_file, "rb") as password_file: + secret = password_file.read().strip() except (OSError, IOError) as e: raise AnsibleError("Could not read password file %s: %s" % (pwd_file, e)) @@ -718,7 +710,7 @@ class CLI(ABC): if not secret: raise AnsibleError('Empty password was provided from file (%s)' % pwd_file) - return to_unsafe_text(secret) + return to_text(secret) @classmethod def cli_executor(cls, args=None): @@ -739,54 +731,22 @@ class CLI(ABC): else: display.debug("Created the '%s' directory" % ansible_dir) - try: - args = [to_text(a, errors='surrogate_or_strict') for a in args] - except UnicodeError: - display.error('Command line args are not in utf-8, unable to continue. 
Ansible currently only understands utf-8') - display.display(u"The full traceback was:\n\n%s" % to_text(traceback.format_exc())) - exit_code = 6 - else: - cli = cls(args) - exit_code = cli.run() - - except AnsibleOptionsError as e: - cli.parser.print_help() - display.error(to_text(e), wrap_text=False) - exit_code = 5 - except AnsibleParserError as e: - display.error(to_text(e), wrap_text=False) - exit_code = 4 - # TQM takes care of these, but leaving comment to reserve the exit codes - # except AnsibleHostUnreachable as e: - # display.error(str(e)) - # exit_code = 3 - # except AnsibleHostFailed as e: - # display.error(str(e)) - # exit_code = 2 - except AnsibleError as e: - display.error(to_text(e), wrap_text=False) - exit_code = 1 + cli = cls(args) + exit_code = cli.run() + except AnsibleError as ex: + display.error(ex) + exit_code = ex._exit_code except KeyboardInterrupt: display.error("User interrupted execution") - exit_code = 99 - except Exception as e: - if C.DEFAULT_DEBUG: - # Show raw stacktraces in debug mode, It also allow pdb to - # enter post mortem mode. - raise - have_cli_options = bool(context.CLIARGS) - display.error("Unexpected Exception, this is probably a bug: %s" % to_text(e), wrap_text=False) - if not have_cli_options or have_cli_options and context.CLIARGS['verbosity'] > 2: - log_only = False - if hasattr(e, 'orig_exc'): - display.vvv('\nexception type: %s' % to_text(type(e.orig_exc))) - why = to_text(e.orig_exc) - if to_text(e) != why: - display.vvv('\noriginal msg: %s' % why) - else: - display.display("to see the full traceback, use -vvv") - log_only = True - display.display(u"the full traceback was:\n\n%s" % to_text(traceback.format_exc()), log_only=log_only) - exit_code = 250 + exit_code = ExitCode.KEYBOARD_INTERRUPT + except Exception as ex: + try: + raise AnsibleError("Unexpected Exception, this is probably a bug.") from ex + except AnsibleError as ex2: + # DTFIX-RELEASE: clean this up so we're not hacking the internals- re-wrap in an AnsibleCLIUnhandledError that always shows TB, or? 
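+ # The override below forces _is_traceback_enabled to return True so display.error(ex2) renders the full chained traceback for this unexpected exception, regardless of the DISPLAY_TRACEBACK configuration.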
+ from ansible.module_utils._internal import _traceback + _traceback._is_traceback_enabled = lambda *_args, **_kwargs: True + display.error(ex2) + exit_code = ExitCode.UNKNOWN_ERROR sys.exit(exit_code) diff --git a/lib/ansible/cli/adhoc.py b/lib/ansible/cli/adhoc.py index 438ad7dd08d..04d4a276037 100755 --- a/lib/ansible/cli/adhoc.py +++ b/lib/ansible/cli/adhoc.py @@ -6,6 +6,8 @@ from __future__ import annotations +import json + # ansible.cli needs to be imported first, to ensure the source bin/* scripts run that code first from ansible.cli import CLI from ansible import constants as C @@ -15,10 +17,11 @@ from ansible.errors import AnsibleError, AnsibleOptionsError, AnsibleParserError from ansible.executor.task_queue_manager import TaskQueueManager from ansible.module_utils.common.text.converters import to_text from ansible.parsing.splitter import parse_kv -from ansible.parsing.utils.yaml import from_yaml from ansible.playbook import Playbook from ansible.playbook.play import Play +from ansible._internal._datatag._tags import Origin from ansible.utils.display import Display +from ansible._internal._json._profiles import _legacy display = Display() @@ -78,7 +81,7 @@ class AdHocCLI(CLI): module_args = None if module_args_raw and module_args_raw.startswith('{') and module_args_raw.endswith('}'): try: - module_args = from_yaml(module_args_raw.strip(), json_only=True) + module_args = json.loads(module_args_raw, cls=_legacy.Decoder) except AnsibleParserError: pass @@ -88,6 +91,8 @@ class AdHocCLI(CLI): mytask = {'action': {'module': context.CLIARGS['module_name'], 'args': module_args}, 'timeout': context.CLIARGS['task_timeout']} + mytask = Origin(description=f'').tag(mytask) + # avoid adding to tasks that don't support it, unless set, then give user an error if context.CLIARGS['module_name'] not in C._ACTION_ALL_INCLUDE_ROLE_TASKS and any(frozenset((async_val, poll))): mytask['async_val'] = async_val diff --git a/lib/ansible/cli/arguments/option_helpers.py b/lib/ansible/cli/arguments/option_helpers.py index 18adc16455a..f43d62adb75 100644 --- a/lib/ansible/cli/arguments/option_helpers.py +++ b/lib/ansible/cli/arguments/option_helpers.py @@ -4,12 +4,17 @@ from __future__ import annotations import copy +import dataclasses +import inspect import operator import argparse import os import os.path import sys import time +import typing as t + +import yaml from jinja2 import __version__ as j2_version @@ -20,6 +25,8 @@ from ansible.module_utils.common.yaml import HAS_LIBYAML, yaml_load from ansible.release import __version__ from ansible.utils.path import unfrackpath +from ansible._internal._datatag._tags import TrustedAsTemplate, Origin + # # Special purpose OptionParsers @@ -30,13 +37,115 @@ class SortingHelpFormatter(argparse.HelpFormatter): super(SortingHelpFormatter, self).add_arguments(actions) +@dataclasses.dataclass(frozen=True, kw_only=True) +class DeprecatedArgument: + version: str + """The Ansible version that will remove the deprecated argument.""" + + option: str | None = None + """The specific option string that is deprecated; None applies to all options for this argument.""" + + def is_deprecated(self, option: str) -> bool: + """Return True if the given option is deprecated, otherwise False.""" + return self.option is None or option == self.option + + def check(self, option: str) -> None: + """Display a deprecation warning if the given option is deprecated.""" + if not self.is_deprecated(option): + return + + from ansible.utils.display import Display + + Display().deprecated(f'The 
{option!r} argument is deprecated.', version=self.version) + + class ArgumentParser(argparse.ArgumentParser): - def add_argument(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: + self.__actions: dict[str | None, type[argparse.Action]] = {} + + super().__init__(*args, **kwargs) + + def register(self, registry_name, value, object): + """Track registration of actions so that they can be resolved later by name, without depending on the internals of ArgumentParser.""" + if registry_name == 'action': + self.__actions[value] = object + + super().register(registry_name, value, object) + + def _patch_argument(self, args: tuple[str, ...], kwargs: dict[str, t.Any]) -> None: + """ + Patch `kwargs` for an `add_argument` call using the given `args` and `kwargs`. + This is used to apply tags to entire categories of CLI arguments. + """ + name = args[0] + action = kwargs.get('action') + resolved_action = self.__actions.get(action, action) # get the action by name, or use as-is (assume it's a subclass of argparse.Action) + action_signature = inspect.signature(resolved_action.__init__) + + if action_signature.parameters.get('type'): + arg_type = kwargs.get('type', str) + + if not callable(arg_type): + raise ValueError(f'Argument {name!r} requires a callable for the {"type"!r} parameter, not {arg_type!r}.') + + wrapped_arg_type = _tagged_type_factory(name, arg_type) + + kwargs.update(type=wrapped_arg_type) + + def _patch_parser(self, parser): + """Patch and return the given parser to intercept the `add_argument` method for further patching.""" + parser_add_argument = parser.add_argument + + def add_argument(*ag_args, **ag_kwargs): + self._patch_argument(ag_args, ag_kwargs) + + parser_add_argument(*ag_args, **ag_kwargs) + + parser.add_argument = add_argument + + return parser + + def add_subparsers(self, *args, **kwargs): + sub = super().add_subparsers(*args, **kwargs) + sub_add_parser = sub.add_parser + + def add_parser(*sub_args, **sub_kwargs): + return self._patch_parser(sub_add_parser(*sub_args, **sub_kwargs)) + + sub.add_parser = add_parser + + return sub + + def add_argument_group(self, *args, **kwargs): + return self._patch_parser(super().add_argument_group(*args, **kwargs)) + + def add_mutually_exclusive_group(self, *args, **kwargs): + return self._patch_parser(super().add_mutually_exclusive_group(*args, **kwargs)) + + def add_argument(self, *args, **kwargs) -> argparse.Action: action = kwargs.get('action') help = kwargs.get('help') if help and action in {'append', 'append_const', 'count', 'extend', PrependListAction}: help = f'{help.rstrip(".")}. This argument may be specified multiple times.' 
kwargs['help'] = help + + self._patch_argument(args, kwargs) + + deprecated: DeprecatedArgument | None + + if deprecated := kwargs.pop('deprecated', None): + action_type = self.__actions.get(action, action) + + class DeprecatedAction(action_type): # type: ignore[misc, valid-type] + """A wrapper around an action which handles deprecation warnings.""" + + def __call__(self, parser, namespace, values, option_string=None) -> t.Any: + deprecated.check(option_string) + + return super().__call__(parser, namespace, values, option_string) + + kwargs['action'] = DeprecatedAction + return super().add_argument(*args, **kwargs) @@ -182,13 +291,28 @@ def version(prog=None): cpath = "Default w/o overrides" else: cpath = C.DEFAULT_MODULE_PATH + + if HAS_LIBYAML: + libyaml_fragment = "with libyaml" + + # noinspection PyBroadException + try: + from yaml._yaml import get_version_string + + libyaml_fragment += f" v{get_version_string()}" + except Exception: # pylint: disable=broad-except + libyaml_fragment += ", version unknown" + else: + libyaml_fragment = "without libyaml" + result.append(" configured module search path = %s" % cpath) result.append(" ansible python module location = %s" % ':'.join(ansible.__path__)) result.append(" ansible collection location = %s" % ':'.join(C.COLLECTIONS_PATHS)) result.append(" executable location = %s" % sys.argv[0]) result.append(" python version = %s (%s)" % (''.join(sys.version.splitlines()), to_native(sys.executable))) result.append(" jinja version = %s" % j2_version) - result.append(" libyaml = %s" % HAS_LIBYAML) + result.append(f" pyyaml version = {yaml.__version__} ({libyaml_fragment})") + return "\n".join(result) @@ -292,7 +416,8 @@ def add_fork_options(parser): def add_inventory_options(parser): """Add options for commands that utilize inventory""" parser.add_argument('-i', '--inventory', '--inventory-file', dest='inventory', action="append", - help="specify inventory host path or comma separated host list. --inventory-file is deprecated") + help="specify inventory host path or comma separated host list", + deprecated=DeprecatedArgument(version='2.23', option='--inventory-file')) parser.add_argument('--list-hosts', dest='listhosts', action='store_true', help='outputs a list of matching hosts; does not execute anything else') parser.add_argument('-l', '--limit', default=C.DEFAULT_SUBSET, dest='subset', @@ -318,9 +443,9 @@ def add_module_options(parser): def add_output_options(parser): """Add options for commands which can change their output""" parser.add_argument('-o', '--one-line', dest='one_line', action='store_true', - help='condense output') + help='condense output', deprecated=DeprecatedArgument(version='2.23')) parser.add_argument('-t', '--tree', dest='tree', default=None, - help='log output to this directory') + help='log output to this directory', deprecated=DeprecatedArgument(version='2.23')) def add_runas_options(parser): @@ -396,3 +521,25 @@ def add_vault_options(parser): help='ask for vault password') base_group.add_argument('--vault-password-file', '--vault-pass-file', default=[], dest='vault_password_files', help="vault password file", type=unfrack_path(follow=False), action='append') + + +def _tagged_type_factory(name: str, func: t.Callable[[str], object], /) -> t.Callable[[str], object]: + """ + Return a callable that wraps the given function. + The result of the wrapped function will be tagged with Origin. + It will also be tagged with TrustedAsTemplate if it is equal to the original input string. 
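+ For example, when func is str, the parsed value is returned as the same object and receives both tags; a func such as int produces a new object, which is tagged only with Origin.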
+ """ + def tag_value(value: str) -> object: + result = func(value) + + if result is value: + # Values which are not mutated are automatically trusted for templating. + # The `is` reference equality is critically important, as other types may only alter the tags, so object equality is + # not sufficient to prevent them being tagged as trusted when they should not. + result = TrustedAsTemplate().tag(result) + + return Origin(description=f'').tag(result) + + tag_value._name = name # simplify debugging by attaching the argument name to the function + + return tag_value diff --git a/lib/ansible/cli/config.py b/lib/ansible/cli/config.py index a88beb7b1ea..ed42545df47 100755 --- a/lib/ansible/cli/config.py +++ b/lib/ansible/cli/config.py @@ -10,7 +10,6 @@ from ansible.cli import CLI import os import shlex -import subprocess import sys import yaml @@ -24,7 +23,7 @@ from ansible.cli.arguments import option_helpers as opt_help from ansible.config.manager import ConfigManager from ansible.errors import AnsibleError, AnsibleOptionsError, AnsibleRequiredOptionError from ansible.module_utils.common.text.converters import to_native, to_text, to_bytes -from ansible.module_utils.common.json import json_dump +from ansible._internal import _json from ansible.module_utils.six import string_types from ansible.parsing.quoting import is_quoted from ansible.parsing.yaml.dumper import AnsibleDumper @@ -178,8 +177,6 @@ class ConfigCLI(CLI): except Exception: if context.CLIARGS['action'] in ['view']: raise - elif context.CLIARGS['action'] in ['edit', 'update']: - display.warning("File does not exist, used empty file: %s" % self.config_file) elif context.CLIARGS['action'] == 'view': raise AnsibleError('Invalid or no config file was supplied') @@ -187,30 +184,6 @@ class ConfigCLI(CLI): # run the requested action context.CLIARGS['func']() - def execute_update(self): - """ - Updates a single setting in the specified ansible.cfg - """ - raise AnsibleError("Option not implemented yet") - - # pylint: disable=unreachable - if context.CLIARGS['setting'] is None: - raise AnsibleOptionsError("update option requires a setting to update") - - (entry, value) = context.CLIARGS['setting'].split('=') - if '.' 
in entry: - (section, option) = entry.split('.') - else: - section = 'defaults' - option = entry - subprocess.call([ - 'ansible', - '-m', 'ini_file', - 'localhost', - '-c', 'local', - '-a', '"dest=%s section=%s option=%s value=%s backup=yes"' % (self.config_file, section, option, value) - ]) - def execute_view(self): """ Displays the current config file @@ -221,20 +194,6 @@ class ConfigCLI(CLI): except Exception as e: raise AnsibleError("Failed to open config file: %s" % to_native(e)) - def execute_edit(self): - """ - Opens ansible.cfg in the default EDITOR - """ - raise AnsibleError("Option not implemented yet") - - # pylint: disable=unreachable - try: - editor = shlex.split(C.config.get_config_value('EDITOR')) - editor.append(self.config_file) - subprocess.call(editor) - except Exception as e: - raise AnsibleError("Failed to open editor: %s" % to_native(e)) - def _list_plugin_settings(self, ptype, plugins=None): entries = {} loader = getattr(plugin_loader, '%s_loader' % ptype) @@ -302,7 +261,7 @@ class ConfigCLI(CLI): if context.CLIARGS['format'] == 'yaml': output = yaml_dump(config_entries) elif context.CLIARGS['format'] == 'json': - output = json_dump(config_entries) + output = _json.json_dumps_formatted(config_entries) self.pager(to_text(output, errors='surrogate_or_strict')) @@ -495,16 +454,17 @@ class ConfigCLI(CLI): # Add base config = self.config.get_configuration_definitions(ignore_private=True) # convert to settings + settings = {} for setting in config.keys(): v, o = C.config.get_config_value_and_origin(setting, cfile=self.config_file, variables=get_constants()) - config[setting] = { + settings[setting] = { 'name': setting, 'value': v, 'origin': o, 'type': None } - return self._render_settings(config) + return self._render_settings(settings) def _get_plugin_configs(self, ptype, plugins): @@ -659,7 +619,7 @@ class ConfigCLI(CLI): if context.CLIARGS['format'] == 'yaml': text = yaml_dump(output) elif context.CLIARGS['format'] == 'json': - text = json_dump(output) + text = _json.json_dumps_formatted(output) self.pager(to_text(text, errors='surrogate_or_strict')) diff --git a/lib/ansible/cli/console.py b/lib/ansible/cli/console.py index 8ab08c5baab..19a844d5217 100755 --- a/lib/ansible/cli/console.py +++ b/lib/ansible/cli/console.py @@ -29,6 +29,7 @@ from ansible.plugins.list import list_plugins from ansible.plugins.loader import module_loader, fragment_loader from ansible.utils import plugin_docs from ansible.utils.color import stringc +from ansible._internal._datatag._tags import TrustedAsTemplate from ansible.utils.display import Display display = Display() @@ -181,6 +182,8 @@ class ConsoleCLI(CLI, cmd.Cmd): else: module_args = '' + module_args = TrustedAsTemplate().tag(module_args) + if self.callback: cb = self.callback elif C.DEFAULT_LOAD_CALLBACK_PLUGINS and C.DEFAULT_STDOUT_CALLBACK != 'default': @@ -239,11 +242,8 @@ class ConsoleCLI(CLI, cmd.Cmd): except KeyboardInterrupt: display.error('User interrupted execution') return False - except Exception as e: - if self.verbosity >= 3: - import traceback - display.v(traceback.format_exc()) - display.error(to_text(e)) + except Exception as ex: + display.error(ex) return False def emptyline(self): diff --git a/lib/ansible/cli/doc.py b/lib/ansible/cli/doc.py index 6efe0319e5f..4835785deb6 100755 --- a/lib/ansible/cli/doc.py +++ b/lib/ansible/cli/doc.py @@ -15,7 +15,8 @@ import os import os.path import re import textwrap -import traceback + +import yaml import ansible.plugins.loader as plugin_loader @@ -28,12 +29,12 @@ from 
ansible.collections.list import list_collection_dirs from ansible.errors import AnsibleError, AnsibleOptionsError, AnsibleParserError, AnsiblePluginNotFound from ansible.module_utils.common.text.converters import to_native, to_text from ansible.module_utils.common.collections import is_sequence -from ansible.module_utils.common.json import json_dump from ansible.module_utils.common.yaml import yaml_dump from ansible.module_utils.six import string_types from ansible.parsing.plugin_docs import read_docstub -from ansible.parsing.utils.yaml import from_yaml from ansible.parsing.yaml.dumper import AnsibleDumper +from ansible.parsing.yaml.loader import AnsibleLoader +from ansible._internal._yaml._loader import AnsibleInstrumentedLoader from ansible.plugins.list import list_plugins from ansible.plugins.loader import action_loader, fragment_loader from ansible.utils.collection_loader import AnsibleCollectionConfig, AnsibleCollectionRef @@ -41,6 +42,8 @@ from ansible.utils.collection_loader._collection_finder import _get_collection_n from ansible.utils.color import stringc from ansible.utils.display import Display from ansible.utils.plugin_docs import get_plugin_docs, get_docstring, get_versioned_doclink +from ansible.template import trust_as_template +from ansible._internal import _json display = Display() @@ -83,10 +86,9 @@ ref_style = { def jdump(text): try: - display.display(json_dump(text)) - except TypeError as e: - display.vvv(traceback.format_exc()) - raise AnsibleError('We could not convert all the documentation into JSON as there was a conversion issue: %s' % to_native(e)) + display.display(_json.json_dumps_formatted(text)) + except TypeError as ex: + raise AnsibleError('We could not convert all the documentation into JSON as there was a conversion issue.') from ex class RoleMixin(object): @@ -129,11 +131,11 @@ class RoleMixin(object): try: with open(path, 'r') as f: - data = from_yaml(f.read(), file_name=path) + data = yaml.load(trust_as_template(f), Loader=AnsibleLoader) if data is None: data = {} - except (IOError, OSError) as e: - raise AnsibleParserError("Could not read the role '%s' (at %s)" % (role_name, path), orig_exc=e) + except (IOError, OSError) as ex: + raise AnsibleParserError(f"Could not read the role {role_name!r} (at {path}).") from ex return data @@ -697,16 +699,16 @@ class DocCLI(CLI, RoleMixin): display.warning("Skipping role '%s' due to: %s" % (role, role_json[role]['error']), True) continue text += self.get_role_man_text(role, role_json[role]) - except AnsibleParserError as e: + except AnsibleError as ex: # TODO: warn and skip role? 
- raise AnsibleParserError("Role '%s" % (role), orig_exc=e) + raise AnsibleParserError(f"Error extracting role docs from {role!r}.") from ex # display results DocCLI.pager("\n".join(text)) @staticmethod def _list_keywords(): - return from_yaml(pkgutil.get_data('ansible', 'keyword_desc.yml')) + return yaml.load(pkgutil.get_data('ansible', 'keyword_desc.yml'), Loader=AnsibleInstrumentedLoader) @staticmethod def _get_keywords_docs(keys): @@ -769,10 +771,8 @@ class DocCLI(CLI, RoleMixin): data[key] = kdata - except (AttributeError, KeyError) as e: - display.warning("Skipping Invalid keyword '%s' specified: %s" % (key, to_text(e))) - if display.verbosity >= 3: - display.verbose(traceback.format_exc()) + except (AttributeError, KeyError) as ex: + display.error_as_warning(f'Skipping invalid keyword {key!r}.', ex) return data @@ -820,16 +820,19 @@ class DocCLI(CLI, RoleMixin): except AnsiblePluginNotFound as e: display.warning(to_native(e)) continue - except Exception as e: + except Exception as ex: + msg = "Missing documentation (or could not parse documentation)" + if not fail_on_errors: - plugin_docs[plugin] = {'error': 'Missing documentation or could not parse documentation: %s' % to_native(e)} + plugin_docs[plugin] = {'error': f'{msg}: {ex}.'} continue - display.vvv(traceback.format_exc()) - msg = "%s %s missing documentation (or could not parse documentation): %s\n" % (plugin_type, plugin, to_native(e)) + + msg = f"{plugin_type} {plugin} {msg}" + if fail_ok: - display.warning(msg) + display.warning(f'{msg}: {ex}') else: - raise AnsibleError(msg) + raise AnsibleError(f'{msg}.') from ex if not doc: # The doc section existed but was empty @@ -841,9 +844,9 @@ class DocCLI(CLI, RoleMixin): if not fail_on_errors: # Check whether JSON serialization would break try: - json_dump(docs) - except Exception as e: # pylint:disable=broad-except - plugin_docs[plugin] = {'error': 'Cannot serialize documentation as JSON: %s' % to_native(e)} + _json.json_dumps_formatted(docs) + except Exception as ex: # pylint:disable=broad-except + plugin_docs[plugin] = {'error': f'Cannot serialize documentation as JSON: {ex}'} continue plugin_docs[plugin] = docs @@ -1016,9 +1019,8 @@ class DocCLI(CLI, RoleMixin): try: doc, __, __, __ = get_docstring(filename, fragment_loader, verbose=(context.CLIARGS['verbosity'] > 0), collection_name=collection_name, plugin_type=plugin_type) - except Exception: - display.vvv(traceback.format_exc()) - raise AnsibleError("%s %s at %s has a documentation formatting error or is missing documentation." 
% (plugin_type, plugin_name, filename)) + except Exception as ex: + raise AnsibleError(f"{plugin_type} {plugin_name} at {filename!r} has a documentation formatting error or is missing documentation.") from ex if doc is None: # Removed plugins don't have any documentation @@ -1094,9 +1096,8 @@ class DocCLI(CLI, RoleMixin): try: text = DocCLI.get_man_text(doc, collection_name, plugin_type) - except Exception as e: - display.vvv(traceback.format_exc()) - raise AnsibleError("Unable to retrieve documentation from '%s'" % (plugin), orig_exc=e) + except Exception as ex: + raise AnsibleError(f"Unable to retrieve documentation from {plugin!r}.") from ex return text @@ -1508,8 +1509,8 @@ class DocCLI(CLI, RoleMixin): else: try: text.append(yaml_dump(doc.pop('plainexamples'), indent=2, default_flow_style=False)) - except Exception as e: - raise AnsibleParserError("Unable to parse examples section", orig_exc=e) + except Exception as ex: + raise AnsibleParserError("Unable to parse examples section.") from ex if doc.get('returndocs', False): text.append('') diff --git a/lib/ansible/cli/galaxy.py b/lib/ansible/cli/galaxy.py index 76e566f4a5c..6c8c749f9b4 100755 --- a/lib/ansible/cli/galaxy.py +++ b/lib/ansible/cli/galaxy.py @@ -53,10 +53,12 @@ from ansible.module_utils.ansible_release import __version__ as ansible_version from ansible.module_utils.common.collections import is_iterable from ansible.module_utils.common.yaml import yaml_dump, yaml_load from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text +from ansible._internal._datatag._tags import TrustedAsTemplate from ansible.module_utils import six from ansible.parsing.dataloader import DataLoader from ansible.playbook.role.requirement import RoleRequirement -from ansible.template import Templar +from ansible._internal._templating._engine import TemplateEngine +from ansible.template import trust_as_template from ansible.utils.collection_loader import AnsibleCollectionConfig from ansible.utils.display import Display from ansible.utils.plugin_docs import get_versioned_doclink @@ -915,8 +917,8 @@ class GalaxyCLI(CLI): @staticmethod def _get_skeleton_galaxy_yml(template_path, inject_data): - with open(to_bytes(template_path, errors='surrogate_or_strict'), 'rb') as template_obj: - meta_template = to_text(template_obj.read(), errors='surrogate_or_strict') + with open(to_bytes(template_path, errors='surrogate_or_strict'), 'r') as template_obj: + meta_template = TrustedAsTemplate().tag(to_text(template_obj.read(), errors='surrogate_or_strict')) galaxy_meta = get_collections_galaxy_meta_info() @@ -952,7 +954,7 @@ class GalaxyCLI(CLI): return textwrap.fill(v, width=117, initial_indent="# ", subsequent_indent="# ", break_on_hyphens=False) loader = DataLoader() - templar = Templar(loader, variables={'required_config': required_config, 'optional_config': optional_config}) + templar = TemplateEngine(loader, variables={'required_config': required_config, 'optional_config': optional_config}) templar.environment.filters['comment_ify'] = comment_ify meta_value = templar.template(meta_template) @@ -1154,7 +1156,7 @@ class GalaxyCLI(CLI): loader = DataLoader() inject_data.update(load_extra_vars(loader)) - templar = Templar(loader, variables=inject_data) + templar = TemplateEngine(loader, variables=inject_data) # create role directory if not os.path.exists(b_obj_path): @@ -1196,7 +1198,7 @@ class GalaxyCLI(CLI): elif ext == ".j2" and not in_templates_dir: src_template = os.path.join(root, f) dest_file = os.path.join(obj_path, rel_root, 
filename) - template_data = to_text(loader._get_file_contents(src_template)[0], errors='surrogate_or_strict') + template_data = trust_as_template(loader.get_text_file_contents(src_template)) try: b_rendered = to_bytes(templar.template(template_data), errors='surrogate_or_strict') except AnsibleError as e: @@ -1764,6 +1766,8 @@ class GalaxyCLI(CLI): return 0 + _task_check_delay_sec = 10 # allows unit test override + def execute_import(self): """ used to import a role into Ansible Galaxy """ @@ -1817,7 +1821,7 @@ class GalaxyCLI(CLI): rc = ['SUCCESS', 'FAILED'].index(state) finished = True else: - time.sleep(10) + time.sleep(self._task_check_delay_sec) return rc diff --git a/lib/ansible/cli/inventory.py b/lib/ansible/cli/inventory.py index 5d99d24ed68..8033b2e0f95 100755 --- a/lib/ansible/cli/inventory.py +++ b/lib/ansible/cli/inventory.py @@ -9,15 +9,19 @@ from __future__ import annotations # ansible.cli needs to be imported first, to ensure the source bin/* scripts run that code first from ansible.cli import CLI +import json import sys +import typing as t import argparse +import functools from ansible import constants as C from ansible import context from ansible.cli.arguments import option_helpers as opt_help -from ansible.errors import AnsibleError, AnsibleOptionsError +from ansible.errors import AnsibleError, AnsibleOptionsError, AnsibleRuntimeError from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text +from ansible._internal._json._profiles import _inventory_legacy from ansible.utils.vars import combine_vars from ansible.utils.display import Display from ansible.vars.plugins import get_vars_from_inventory_sources, get_vars_from_path @@ -156,34 +160,17 @@ class InventoryCLI(CLI): @staticmethod def dump(stuff): - if context.CLIARGS['yaml']: import yaml from ansible.parsing.yaml.dumper import AnsibleDumper - results = to_text(yaml.dump(stuff, Dumper=AnsibleDumper, default_flow_style=False, allow_unicode=True)) + + # DTFIX-RELEASE: need shared infra to smuggle custom kwargs to dumpers, since yaml.dump cannot (as of PyYAML 6.0.1) + dumper = functools.partial(AnsibleDumper, dump_vault_tags=True) + results = to_text(yaml.dump(stuff, Dumper=dumper, default_flow_style=False, allow_unicode=True)) elif context.CLIARGS['toml']: - from ansible.plugins.inventory.toml import toml_dumps - try: - results = toml_dumps(stuff) - except TypeError as e: - raise AnsibleError( - 'The source inventory contains a value that cannot be represented in TOML: %s' % e - ) - except KeyError as e: - raise AnsibleError( - 'The source inventory contains a non-string key (%s) which cannot be represented in TOML. ' - 'The specified key will need to be converted to a string. Be aware that if your playbooks ' - 'expect this key to be non-string, your playbooks will need to be modified to support this ' - 'change.' 
% e.args[0] - ) + results = toml_dumps(stuff) else: - import json - from ansible.parsing.ajson import AnsibleJSONEncoder - try: - results = json.dumps(stuff, cls=AnsibleJSONEncoder, sort_keys=True, indent=4, preprocess_unsafe=True, ensure_ascii=False) - except TypeError as e: - results = json.dumps(stuff, cls=AnsibleJSONEncoder, sort_keys=False, indent=4, preprocess_unsafe=True, ensure_ascii=False) - display.warning("Could not sort JSON output due to issues while sorting keys: %s" % to_native(e)) + results = json.dumps(stuff, cls=_inventory_legacy.Encoder, sort_keys=True, indent=4) return results @@ -306,7 +293,11 @@ class InventoryCLI(CLI): results = format_group(top, frozenset(h.name for h in hosts)) # populate meta - results['_meta'] = {'hostvars': {}} + results['_meta'] = { + 'hostvars': {}, + 'profile': _inventory_legacy.Encoder.profile_name, + } + for host in hosts: hvars = self._get_host_variables(host) if hvars: @@ -409,6 +400,17 @@ class InventoryCLI(CLI): return results +def toml_dumps(data: t.Any) -> str: + try: + from tomli_w import dumps as _tomli_w_dumps + except ImportError: + pass + else: + return _tomli_w_dumps(data) + + raise AnsibleRuntimeError('The Python library "tomli-w" is required when using the TOML output format.') + + def main(args=None): InventoryCLI.cli_executor(args) diff --git a/lib/ansible/cli/scripts/ansible_connection_cli_stub.py b/lib/ansible/cli/scripts/ansible_connection_cli_stub.py index 0c8baa9871f..adaaedc669d 100644 --- a/lib/ansible/cli/scripts/ansible_connection_cli_stub.py +++ b/lib/ansible/cli/scripts/ansible_connection_cli_stub.py @@ -21,7 +21,7 @@ from ansible.cli.arguments import option_helpers as opt_help from ansible.module_utils.common.text.converters import to_bytes, to_text from ansible.module_utils.connection import Connection, ConnectionError, send_data, recv_data from ansible.module_utils.service import fork_process -from ansible.parsing.ajson import AnsibleJSONEncoder, AnsibleJSONDecoder +from ansible.module_utils._internal._json._profiles import _tagless from ansible.playbook.play_context import PlayContext from ansible.plugins.loader import connection_loader, init_plugin_loader from ansible.utils.path import unfrackpath, makedirs_safe @@ -110,7 +110,7 @@ class ConnectionProcess(object): result['exception'] = traceback.format_exc() finally: result['messages'] = messages - self.fd.write(json.dumps(result, cls=AnsibleJSONEncoder)) + self.fd.write(json.dumps(result, cls=_tagless.Encoder)) self.fd.close() def run(self): @@ -292,7 +292,7 @@ def main(args=None): else: os.close(w) rfd = os.fdopen(r, 'r') - data = json.loads(rfd.read(), cls=AnsibleJSONDecoder) + data = json.loads(rfd.read(), cls=_tagless.Decoder) messages.extend(data.pop('messages')) result.update(data) @@ -330,10 +330,10 @@ def main(args=None): sys.stdout = saved_stdout if 'exception' in result: rc = 1 - sys.stderr.write(json.dumps(result, cls=AnsibleJSONEncoder)) + sys.stderr.write(json.dumps(result, cls=_tagless.Encoder)) else: rc = 0 - sys.stdout.write(json.dumps(result, cls=AnsibleJSONEncoder)) + sys.stdout.write(json.dumps(result, cls=_tagless.Encoder)) sys.exit(rc) diff --git a/lib/ansible/cli/vault.py b/lib/ansible/cli/vault.py index 898548e62b4..6e3b56d002a 100755 --- a/lib/ansible/cli/vault.py +++ b/lib/ansible/cli/vault.py @@ -228,6 +228,7 @@ class VaultCLI(CLI): vault_ids=new_vault_ids, vault_password_files=new_vault_password_files, ask_vault_pass=context.CLIARGS['ask_vault_pass'], + initialize_context=False, create_new_password=True) if not new_vault_secrets: 
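The `initialize_context=False` added above exists because `setup_vault_secrets` now registers the secrets it gathers in the global `VaultSecretsContext` (see the `lib/ansible/cli/__init__.py` hunk earlier in this diff); the rekey path collects a second set of secrets that must not displace the ones registered for decryption. A minimal sketch of that context lifecycle, assuming `VaultSecretsContext.current()` is the accessor (illustrative only, not part of this diff):

from ansible.parsing.vault import VaultSecretsContext

vault_secrets: list = []  # stand-in for the (vault_id, VaultSecret) pairs that setup_vault_secrets assembles
VaultSecretsContext.initialize(VaultSecretsContext(vault_secrets))  # done once, at CLI startup
secrets = VaultSecretsContext.current().secrets  # assumed accessor; later consumers read the registered secrets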
@@ -259,7 +260,7 @@ class VaultCLI(CLI): display.display("Reading plaintext input from stdin", stderr=True) for f in context.CLIARGS['args'] or ['-']: - # Fixme: use the correct vau + # FIXME: use the correct vau self.editor.encrypt_file(f, self.encrypt_secret, vault_id=self.encrypt_vault_id, output_file=context.CLIARGS['output_file']) diff --git a/lib/ansible/config/base.yml b/lib/ansible/config/base.yml index 414a817d312..b3cd67607fa 100644 --- a/lib/ansible/config/base.yml +++ b/lib/ansible/config/base.yml @@ -9,6 +9,38 @@ _ANSIBLE_CONNECTION_PATH: - For internal use only. type: path version_added: "2.18" +ALLOW_BROKEN_CONDITIONALS: + # This config option will be deprecated once it no longer has any effect (2.23). + name: Allow broken conditionals + default: false + description: + - When enabled, this option allows conditionals with non-boolean results to be used. + - A deprecation warning will be emitted in these cases. + - By default, non-boolean conditionals result in an error. + - Such results often indicate unintentional use of templates where they are not supported, resulting in a conditional that is always true. + - When this option is enabled, conditional expressions which are a literal ``None`` or empty string will evaluate as true for backwards compatibility. + env: [{name: ANSIBLE_ALLOW_BROKEN_CONDITIONALS}] + ini: + - {key: allow_broken_conditionals, section: defaults} + type: boolean + version_added: "2.19" +ALLOW_EMBEDDED_TEMPLATES: + name: Allow embedded templates + default: true + description: + - When enabled, this option allows embedded templates to be used for specific backward compatibility scenarios. + - A deprecation warning will be emitted in the following cases. + - First, conditionals (for example, ``failed_when``, ``until``, ``assert.that``) fully enclosed in template delimiters. + - "Second, string constants in conditionals (for example, ``when: some_var == '{{ some_other_var }}'``)." + - Finally, positional arguments to lookups (for example, ``lookup('pipe', 'echo {{ some_var }}')``). + - This feature is deprecated, since embedded templates are unnecessary in these cases. + - When disabled, use of embedded templates will result in an error. + - A future release will disable this feature by default. + env: [{name: ANSIBLE_ALLOW_EMBEDDED_TEMPLATES}] + ini: + - {key: allow_embedded_templates, section: defaults} + type: boolean + version_added: "2.19" ANSIBLE_HOME: name: The Ansible home path description: @@ -160,38 +192,50 @@ AGNOSTIC_BECOME_PROMPT: yaml: {key: privilege_escalation.agnostic_become_prompt} version_added: "2.5" CACHE_PLUGIN: - name: Persistent Cache plugin + name: Persistent Fact Cache plugin default: memory - description: Chooses which cache plugin to use, the default 'memory' is ephemeral. + description: Chooses which fact cache plugin to use. By default, no cache is used and facts do not persist between runs. env: [{name: ANSIBLE_CACHE_PLUGIN}] ini: - {key: fact_caching, section: defaults} yaml: {key: facts.cache.plugin} CACHE_PLUGIN_CONNECTION: - name: Cache Plugin URI + name: Fact Cache Plugin URI default: ~ - description: Defines connection or path information for the cache plugin. + description: Defines connection or path information for the fact cache plugin.
env: [{name: ANSIBLE_CACHE_PLUGIN_CONNECTION}] ini: - {key: fact_caching_connection, section: defaults} yaml: {key: facts.cache.uri} CACHE_PLUGIN_PREFIX: - name: Cache Plugin table prefix + name: Fact Cache Plugin table prefix default: ansible_facts - description: Prefix to use for cache plugin files/tables. + description: Prefix to use for fact cache plugin files/tables. env: [{name: ANSIBLE_CACHE_PLUGIN_PREFIX}] ini: - {key: fact_caching_prefix, section: defaults} yaml: {key: facts.cache.prefix} CACHE_PLUGIN_TIMEOUT: - name: Cache Plugin expiration timeout + name: Fact Cache Plugin expiration timeout default: 86400 - description: Expiration timeout for the cache plugin data. + description: Expiration timeout for the fact cache plugin data. env: [{name: ANSIBLE_CACHE_PLUGIN_TIMEOUT}] ini: - {key: fact_caching_timeout, section: defaults} type: integer yaml: {key: facts.cache.timeout} +_CALLBACK_DISPATCH_ERROR_BEHAVIOR: + name: Callback dispatch error behavior + default: warn + description: + - Action to take when a callback dispatch results in an error. + type: choices + choices: &choices_ignore_warn_fail + - ignore + - warn + - fail + env: [ { name: _ANSIBLE_CALLBACK_DISPATCH_ERROR_BEHAVIOR } ] + version_added: '2.19' COLLECTIONS_SCAN_SYS_PATH: name: Scan PYTHONPATH for installed collections description: A boolean to enable or disable scanning the sys.path for installed collections. @@ -496,6 +540,10 @@ DEFAULT_ALLOW_UNSAFE_LOOKUPS: - {key: allow_unsafe_lookups, section: defaults} type: boolean version_added: "2.2.3" + deprecated: + why: This option is no longer used in the Ansible Core code base. + version: "2.23" + alternatives: Lookup plugins are responsible for tagging strings containing templates to allow evaluation as a template. DEFAULT_ASK_PASS: name: Ask for the login password default: False @@ -755,15 +803,20 @@ DEFAULT_INVENTORY_PLUGIN_PATH: DEFAULT_JINJA2_EXTENSIONS: name: Enabled Jinja2 extensions default: [] + type: list description: - This is a developer-specific feature that allows enabling additional Jinja2 extensions. - "See the Jinja2 documentation for details. If you do not know what these do, you probably don't need to change this setting :)" env: [{name: ANSIBLE_JINJA2_EXTENSIONS}] ini: - {key: jinja2_extensions, section: defaults} + deprecated: + why: Jinja2 extensions have been deprecated + version: "2.23" + alternatives: Ansible-supported Jinja plugins (tests, filters, lookups) DEFAULT_JINJA2_NATIVE: name: Use Jinja2's NativeEnvironment for templating - default: False + default: True description: This option preserves variable types during template operations. env: [{name: ANSIBLE_JINJA2_NATIVE}] ini: @@ -771,6 +824,10 @@ DEFAULT_JINJA2_NATIVE: type: boolean yaml: {key: jinja2_native} version_added: 2.7 + deprecated: + why: This option is no longer used in the Ansible Core code base. + version: "2.23" + alternatives: Jinja2 native mode is now the default and only option. DEFAULT_KEEP_REMOTE_FILES: name: Keep remote files default: False @@ -930,6 +987,10 @@ DEFAULT_NULL_REPRESENTATION: ini: - {key: null_representation, section: defaults} type: raw + deprecated: + why: This option is no longer used in the Ansible Core code base. + version: "2.23" + alternatives: There is no alternative at the moment. A different mechanism would have to be implemented in the current code base. 
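The DEFAULT_ALLOW_UNSAFE_LOOKUPS deprecation above points lookup authors at explicit trust tagging instead. A minimal sketch of a lookup plugin opting its returned strings into template evaluation, reusing the internal TrustedAsTemplate tag this diff applies elsewhere (illustrative only, not part of this diff):

from ansible._internal._datatag._tags import TrustedAsTemplate
from ansible.plugins.lookup import LookupBase


class LookupModule(LookupBase):
    def run(self, terms, variables=None, **kwargs):
        # strings a lookup wants evaluated as templates must now be explicitly trusted
        return [TrustedAsTemplate().tag(term) for term in terms]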
DEFAULT_POLL_INTERVAL: name: Async poll interval default: 15 @@ -1129,6 +1190,10 @@ DEFAULT_UNDEFINED_VAR_BEHAVIOR: ini: - {key: error_on_undefined_vars, section: defaults} type: boolean + deprecated: + why: This option is no longer used in the Ansible Core code base. + version: "2.23" + alternatives: There is no alternative at the moment. A different mechanism would have to be implemented in the current code base. DEFAULT_VARS_PLUGIN_PATH: name: Vars Plugins Path default: '{{ ANSIBLE_HOME ~ "/plugins/vars:/usr/share/ansible/plugins/vars" }}' @@ -1213,6 +1278,9 @@ DEPRECATION_WARNINGS: ini: - {key: deprecation_warnings, section: defaults} type: boolean + vars: + - name: ansible_deprecation_warnings + version_added: '2.19' DEVEL_WARNING: name: Running devel warning default: True @@ -1266,6 +1334,22 @@ DISPLAY_SKIPPED_HOSTS: ini: - {key: display_skipped_hosts, section: defaults} type: boolean +DISPLAY_TRACEBACK: + name: Control traceback display + default: never + description: When to include tracebacks in extended error messages + env: + - name: ANSIBLE_DISPLAY_TRACEBACK + ini: + - {key: display_traceback, section: defaults} + type: list + choices: + - error + - warning + - deprecated + - always + - never + version_added: "2.19" DOCSITE_ROOT_URL: name: Root docsite URL default: https://docs.ansible.com/ansible-core/ @@ -1916,6 +2000,10 @@ STRING_TYPE_FILTERS: ini: - {key: dont_type_filters, section: jinja2} type: list + deprecated: + why: This option has no effect. + version: "2.23" + alternatives: None; native types returned from filters are always preserved. SYSTEM_WARNINGS: name: System warnings default: True @@ -1968,6 +2056,39 @@ TASK_TIMEOUT: - {key: task_timeout, section: defaults} type: integer version_added: '2.10' +_TEMPLAR_UNKNOWN_TYPE_CONVERSION: + name: Templar unknown type conversion behavior + default: warn + description: + - Action to take when an unknown type is converted for variable storage during template finalization. + - This setting has no effect on the inability to store unsupported variable types as the result of templating. + - Experimental diagnostic feature, subject to change. + type: choices + choices: *choices_ignore_warn_fail + env: [{name: _ANSIBLE_TEMPLAR_UNKNOWN_TYPE_CONVERSION}] + version_added: '2.19' +_TEMPLAR_UNKNOWN_TYPE_ENCOUNTERED: + name: Templar unknown type encountered behavior + default: ignore + description: + - Action to take when an unknown type is encountered inside a template pipeline. + - Experimental diagnostic feature, subject to change. + type: choices + choices: *choices_ignore_warn_fail + env: [{name: _ANSIBLE_TEMPLAR_UNKNOWN_TYPE_ENCOUNTERED}] + version_added: '2.19' +_TEMPLAR_UNTRUSTED_TEMPLATE_BEHAVIOR: + name: Templar untrusted template behavior + default: ignore + description: + - Action to take when processing of an untrusted template is skipped. + - For `ignore` or `warn`, the input template string is returned as-is. + - This setting has no effect on expressions. + - Experimental diagnostic feature, subject to change. 
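+ # For example, an untrusted string containing '{{ secret }}' is not rendered; with 'warn', each skipped template is reported and the input is still returned as-is.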
+ type: choices + choices: *choices_ignore_warn_fail + env: [{name: _ANSIBLE_TEMPLAR_UNTRUSTED_TEMPLATE_BEHAVIOR}] + version_added: '2.19' WORKER_SHUTDOWN_POLL_COUNT: name: Worker Shutdown Poll Count default: 0 @@ -2030,6 +2151,12 @@ WIN_ASYNC_STARTUP_TIMEOUT: vars: - {name: ansible_win_async_startup_timeout} version_added: '2.10' +WRAP_STDERR: + description: Control line-wrapping behavior on console warnings and errors from default output callbacks (eases pattern-based output testing) + env: [{name: ANSIBLE_WRAP_STDERR}] + default: false + type: bool + version_added: "2.19" YAML_FILENAME_EXTENSIONS: name: Valid YAML extensions default: [".yml", ".yaml", ".json"] diff --git a/lib/ansible/config/manager.py b/lib/ansible/config/manager.py index 52bd6547b33..f4a308d58e4 100644 --- a/lib/ansible/config/manager.py +++ b/lib/ansible/config/manager.py @@ -11,18 +11,18 @@ import os.path import sys import stat import tempfile +import typing as t from collections.abc import Mapping, Sequence from jinja2.nativetypes import NativeEnvironment -from ansible.errors import AnsibleOptionsError, AnsibleError, AnsibleRequiredOptionError +from ansible.errors import AnsibleOptionsError, AnsibleError, AnsibleUndefinedConfigEntry, AnsibleRequiredOptionError from ansible.module_utils.common.sentinel import Sentinel from ansible.module_utils.common.text.converters import to_text, to_bytes, to_native from ansible.module_utils.common.yaml import yaml_load from ansible.module_utils.six import string_types from ansible.module_utils.parsing.convert_bool import boolean from ansible.parsing.quoting import unquote -from ansible.parsing.yaml.objects import AnsibleVaultEncryptedUnicode from ansible.utils.path import cleanup_tmp_file, makedirs_safe, unfrackpath @@ -50,14 +50,18 @@ GALAXY_SERVER_ADDITIONAL = { } -def _get_entry(plugin_type, plugin_name, config): - """ construct entry for requested config """ - entry = '' +def _get_config_label(plugin_type: str, plugin_name: str, config: str) -> str: + """Return a label for the given config.""" + entry = f'{config!r}' + if plugin_type: - entry += 'plugin_type: %s ' % plugin_type + entry += ' for' + if plugin_name: - entry += 'plugin: %s ' % plugin_name - entry += 'setting: %s ' % config + entry += f' {plugin_name!r}' + + entry += f' {plugin_type} plugin' + return entry @@ -107,8 +111,8 @@ def ensure_type(value, value_type, origin=None, origin_ftype=None): value = int_part else: errmsg = 'int' - except decimal.DecimalException as e: - raise ValueError from e + except decimal.DecimalException: + errmsg = 'int' elif value_type == 'float': if not isinstance(value, float): @@ -167,7 +171,7 @@ def ensure_type(value, value_type, origin=None, origin_ftype=None): errmsg = 'dictionary' elif value_type in ('str', 'string'): - if isinstance(value, (string_types, AnsibleVaultEncryptedUnicode, bool, int, float, complex)): + if isinstance(value, (string_types, bool, int, float, complex)): value = to_text(value, errors='surrogate_or_strict') if origin_ftype and origin_ftype == 'ini': value = unquote(value) @@ -175,13 +179,13 @@ def ensure_type(value, value_type, origin=None, origin_ftype=None): errmsg = 'string' # defaults to string type - elif isinstance(value, (string_types, AnsibleVaultEncryptedUnicode)): + elif isinstance(value, (string_types)): value = to_text(value, errors='surrogate_or_strict') if origin_ftype and origin_ftype == 'ini': value = unquote(value) if errmsg: - raise ValueError(f'Invalid type provided for "{errmsg}": {value!r}') + raise ValueError(f'Invalid type provided for 
{errmsg!r}: {value!r}') return to_text(value, errors='surrogate_or_strict', nonstring='passthru') @@ -369,6 +373,7 @@ class ConfigManager(object): # template default values if possible # NOTE: cannot use is_template due to circular dep try: + # FIXME: This really should be using an immutable sandboxed native environment, not just native environment t = NativeEnvironment().from_string(value) value = t.render(variables) except Exception: @@ -494,10 +499,6 @@ class ConfigManager(object): self.WARNINGS.add(u'value for config entry {0} contains invalid characters, ignoring...'.format(to_text(name))) continue if temp_value is not None: # only set if entry is defined in container - # inline vault variables should be converted to a text string - if isinstance(temp_value, AnsibleVaultEncryptedUnicode): - temp_value = to_text(temp_value, errors='surrogate_or_strict') - value = temp_value origin = name @@ -515,10 +516,14 @@ class ConfigManager(object): keys=keys, variables=variables, direct=direct) except AnsibleError: raise - except Exception as e: - raise AnsibleError("Unhandled exception when retrieving %s:\n%s" % (config, to_native(e)), orig_exc=e) + except Exception as ex: + raise AnsibleError(f"Unhandled exception when retrieving {config!r}.") from ex return value + def get_config_default(self, config: str, plugin_type: str | None = None, plugin_name: str | None = None) -> t.Any: + """Return the default value for the specified configuration.""" + return self.get_configuration_definitions(plugin_type, plugin_name)[config]['default'] + def get_config_value_and_origin(self, config, cfile=None, plugin_type=None, plugin_name=None, keys=None, variables=None, direct=None): """ Given a config key figure out the actual value and report on the origin of the settings """ if cfile is None: @@ -623,22 +628,21 @@ class ConfigManager(object): if value is None: if defs[config].get('required', False): if not plugin_type or config not in INTERNAL_DEFS.get(plugin_type, {}): - raise AnsibleRequiredOptionError("No setting was provided for required configuration %s" % - to_native(_get_entry(plugin_type, plugin_name, config))) + raise AnsibleRequiredOptionError(f"Required config {_get_config_label(plugin_type, plugin_name, config)} not provided.") else: origin = 'default' value = self.template_default(defs[config].get('default'), variables) + try: # ensure correct type, can raise exceptions on mismatched types value = ensure_type(value, defs[config].get('type'), origin=origin, origin_ftype=origin_ftype) - except ValueError as e: + except ValueError as ex: if origin.startswith('env:') and value == '': # this is empty env var for non string so we can set to default origin = 'default' value = ensure_type(defs[config].get('default'), defs[config].get('type'), origin=origin, origin_ftype=origin_ftype) else: - raise AnsibleOptionsError('Invalid type for configuration option %s (from %s): %s' % - (to_native(_get_entry(plugin_type, plugin_name, config)).strip(), origin, to_native(e))) + raise AnsibleOptionsError(f'Config {_get_config_label(plugin_type, plugin_name, config)} from {origin!r} has an invalid value.') from ex # deal with restricted values if value is not None and 'choices' in defs[config] and defs[config]['choices'] is not None: @@ -661,14 +665,14 @@ class ConfigManager(object): else: valid = defs[config]['choices'] - raise AnsibleOptionsError('Invalid value "%s" for configuration option "%s", valid values are: %s' % - (value, to_native(_get_entry(plugin_type, plugin_name, config)), valid)) + raise 
AnsibleOptionsError(f'Invalid value {value!r} for config {_get_config_label(plugin_type, plugin_name, config)}.', + help_text=f'Valid values are: {valid}') # deal with deprecation of the setting if 'deprecated' in defs[config] and origin != 'default': self.DEPRECATED.append((config, defs[config].get('deprecated'))) else: - raise AnsibleError('Requested entry (%s) was not defined in configuration.' % to_native(_get_entry(plugin_type, plugin_name, config))) + raise AnsibleUndefinedConfigEntry(f'No config definition exists for {_get_config_label(plugin_type, plugin_name, config)}.') return value, origin diff --git a/lib/ansible/constants.py b/lib/ansible/constants.py index af60053a3dd..baa6bf6f8d6 100644 --- a/lib/ansible/constants.py +++ b/lib/ansible/constants.py @@ -166,7 +166,6 @@ INTERNAL_STATIC_VARS = frozenset( "inventory_hostname_short", "groups", "group_names", - "omit", "hostvars", "playbook_dir", "play_hosts", diff --git a/lib/ansible/errors/__init__.py b/lib/ansible/errors/__init__.py index 31ee4bdf1da..d3536459cfb 100644 --- a/lib/ansible/errors/__init__.py +++ b/lib/ansible/errors/__init__.py @@ -1,38 +1,34 @@ # (c) 2012-2014, Michael DeHaan -# -# This file is part of Ansible -# -# Ansible is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# Ansible is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with Ansible. If not, see . +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import annotations -import re +import enum import traceback +import sys +import types +import typing as t from collections.abc import Sequence -from ansible.errors.yaml_strings import ( - YAML_COMMON_DICT_ERROR, - YAML_COMMON_LEADING_TAB_ERROR, - YAML_COMMON_PARTIALLY_QUOTED_LINE_ERROR, - YAML_COMMON_UNBALANCED_QUOTES_ERROR, - YAML_COMMON_UNQUOTED_COLON_ERROR, - YAML_COMMON_UNQUOTED_VARIABLE_ERROR, - YAML_POSITION_DETAILS, - YAML_AND_SHORTHAND_ERROR, -) -from ansible.module_utils.common.text.converters import to_native, to_text +from json import JSONDecodeError + +from ansible.module_utils.common.text.converters import to_text +from ..module_utils.datatag import native_type_name +from ansible._internal._datatag import _tags +from .._internal._errors import _utils + + +class ExitCode(enum.IntEnum): + SUCCESS = 0 # used by TQM, must be bit-flag safe + GENERIC_ERROR = 1 # used by TQM, must be bit-flag safe + HOST_FAILED = 2 # TQM-sourced, must be bit-flag safe + HOST_UNREACHABLE = 4 # TQM-sourced, must be bit-flag safe + PARSER_ERROR = 4 # FIXME: CLI-sourced, conflicts with HOST_UNREACHABLE + INVALID_CLI_OPTION = 5 + UNICODE_ERROR = 6 # obsolete, no longer used + KEYBOARD_INTERRUPT = 99 + UNKNOWN_ERROR = 250 class AnsibleError(Exception): @@ -44,257 +40,271 @@ class AnsibleError(Exception): Usage: - raise AnsibleError('some message here', obj=obj, show_content=True) + raise AnsibleError('some message here', obj=obj) - Where "obj" is some subclass of ansible.parsing.yaml.objects.AnsibleBaseYAMLObject, - which should be returned by the DataLoader() class. 
+ Where "obj" may be tagged with Origin to provide context for error messages. """ - def __init__(self, message="", obj=None, show_content=True, suppress_extended_error=False, orig_exc=None): - super(AnsibleError, self).__init__(message) + _exit_code = ExitCode.GENERIC_ERROR + _default_message = '' + _default_help_text: str | None = None + _include_cause_message = True + """ + When `True`, the exception message will be augmented with cause message(s). + Subclasses doing complex error analysis can disable this to take responsibility for reporting cause messages as needed. + """ + + def __init__( + self, + message: str = "", + obj: t.Any = None, + show_content: bool = True, + suppress_extended_error: bool | types.EllipsisType = ..., + orig_exc: BaseException | None = None, + help_text: str | None = None, + ) -> None: + # DTFIX-FUTURE: these fallback cases mask incorrect use of AnsibleError.message, what should we do? + if message is None: + message = '' + elif not isinstance(message, str): + message = str(message) + + if self._default_message and message: + message = _utils.concat_message(self._default_message, message) + elif self._default_message: + message = self._default_message + elif not message: + message = f'Unexpected {type(self).__name__} error.' + + super().__init__(message) self._show_content = show_content - self._suppress_extended_error = suppress_extended_error - self._message = to_native(message) + self._message = message + self._help_text_value = help_text or self._default_help_text self.obj = obj + + # deprecated: description='deprecate support for orig_exc, callers should use `raise ... from` only' core_version='2.23' + # deprecated: description='remove support for orig_exc' core_version='2.27' self.orig_exc = orig_exc - @property - def message(self): - # we import this here to prevent an import loop problem, - # since the objects code also imports ansible.errors - from ansible.parsing.yaml.objects import AnsibleBaseYAMLObject + if suppress_extended_error is not ...: + from ..utils.display import Display - message = [self._message] + if suppress_extended_error: + self._show_content = False - # Add from previous exceptions - if self.orig_exc: - message.append('. %s' % to_native(self.orig_exc)) + Display().deprecated( + msg=f"The `suppress_extended_error` argument to `{type(self).__name__}` is deprecated. Use `show_content=False` instead.", + version="2.23", + ) - # Add from yaml to give specific file/line no - if isinstance(self.obj, AnsibleBaseYAMLObject): - extended_error = self._get_extended_error() - if extended_error and not self._suppress_extended_error: - message.append( - '\n\n%s' % to_native(extended_error) - ) + @property + def _original_message(self) -> str: + return self._message - return ''.join(message) + @property + def message(self) -> str: + """ + If `include_cause_message` is False, return the original message. + Otherwise, return the original message with cause message(s) appended, stopping on (and including) the first non-AnsibleError. + The recursion is due to `AnsibleError.__str__` calling this method, which uses `str` on child exceptions to create the cause message. + Recursion stops on the first non-AnsibleError since those exceptions do not implement the custom `__str__` behavior. 
+ """ + return _utils.get_chained_message(self) @message.setter - def message(self, val): + def message(self, val) -> None: self._message = val - def __str__(self): - return self.message + @property + def _formatted_source_context(self) -> str | None: + with _utils.RedactAnnotatedSourceContext.when(not self._show_content): + if source_context := _utils.SourceContext.from_value(self.obj): + return str(source_context) - def __repr__(self): - return self.message + return None - def _get_error_lines_from_file(self, file_name, line_number): - """ - Returns the line in the file which corresponds to the reported error - location, as well as the line preceding it (if the error did not - occur on the first line), to provide context to the error. - """ + @property + def _help_text(self) -> str | None: + return self._help_text_value - target_line = '' - prev_line = '' + @_help_text.setter + def _help_text(self, value: str | None) -> None: + self._help_text_value = value - with open(file_name, 'r') as f: - lines = f.readlines() + def __str__(self) -> str: + return self.message - # In case of a YAML loading error, PyYAML will report the very last line - # as the location of the error. Avoid an index error here in order to - # return a helpful message. - file_length = len(lines) - if line_number >= file_length: - line_number = file_length - 1 + def __getstate__(self) -> dict[str, t.Any]: + """Augment object.__getstate__ to preserve additional values not represented in BaseException.__dict__.""" + state = t.cast(dict[str, t.Any], super().__getstate__()) + state.update( + args=self.args, + __cause__=self.__cause__, + __context__=self.__context__, + __suppress_context__=self.__suppress_context__, + ) - # If target_line contains only whitespace, move backwards until - # actual code is found. If there are several empty lines after target_line, - # the error lines would just be blank, which is not very helpful. - target_line = lines[line_number] - while not target_line.strip(): - line_number -= 1 - target_line = lines[line_number] + return state - if line_number > 0: - prev_line = lines[line_number - 1] + def __reduce__(self) -> tuple[t.Callable, tuple[type], dict[str, t.Any]]: + """ + Enable copy/pickle of AnsibleError derived types by correcting for BaseException's ancient C __reduce__ impl that: - return (target_line, prev_line) + * requires use of a type constructor with positional args + * assumes positional args are passed through from the derived type __init__ to BaseException.__init__ unmodified + * does not propagate args/__cause__/__context__/__suppress_context__ - def _get_extended_error(self): + NOTE: This does not preserve the dunder attributes on non-AnsibleError derived cause/context exceptions. + As a result, copy/pickle will discard chained exceptions after the first non-AnsibleError cause/context. """ - Given an object reporting the location of the exception in a file, return - detailed information regarding it including: + return type(self).__new__, (type(self),), self.__getstate__() - * the line which caused the error as well as the one preceding it - * causes and suggested remedies for common syntax errors - If this error was created with show_content=False, the reporting of content - is suppressed, as the file contents may be sensitive (ie. vault data). 
- """ +class AnsibleUndefinedConfigEntry(AnsibleError): + """The requested config entry is not defined.""" + - error_message = '' - - try: - (src_file, line_number, col_number) = self.obj.ansible_pos - error_message += YAML_POSITION_DETAILS % (src_file, line_number, col_number) - if src_file not in ('', '') and self._show_content: - (target_line, prev_line) = self._get_error_lines_from_file(src_file, line_number - 1) - target_line = to_text(target_line) - prev_line = to_text(prev_line) - if target_line: - stripped_line = target_line.replace(" ", "") - - # Check for k=v syntax in addition to YAML syntax and set the appropriate error position, - # arrow index - if re.search(r'\w+(\s+)?=(\s+)?[\w/-]+', prev_line): - error_position = prev_line.rstrip().find('=') - arrow_line = (" " * error_position) + "^ here" - error_message = YAML_POSITION_DETAILS % (src_file, line_number - 1, error_position + 1) - error_message += "\nThe offending line appears to be:\n\n%s\n%s\n\n" % (prev_line.rstrip(), arrow_line) - error_message += YAML_AND_SHORTHAND_ERROR - else: - arrow_line = (" " * (col_number - 1)) + "^ here" - error_message += "\nThe offending line appears to be:\n\n%s\n%s\n%s\n" % (prev_line.rstrip(), target_line.rstrip(), arrow_line) - - # TODO: There may be cases where there is a valid tab in a line that has other errors. - if '\t' in target_line: - error_message += YAML_COMMON_LEADING_TAB_ERROR - # common error/remediation checking here: - # check for unquoted vars starting lines - if ('{{' in target_line and '}}' in target_line) and ('"{{' not in target_line or "'{{" not in target_line): - error_message += YAML_COMMON_UNQUOTED_VARIABLE_ERROR - # check for common dictionary mistakes - elif ":{{" in stripped_line and "}}" in stripped_line: - error_message += YAML_COMMON_DICT_ERROR - # check for common unquoted colon mistakes - elif (len(target_line) and - len(target_line) > 1 and - len(target_line) > col_number and - target_line[col_number] == ":" and - target_line.count(':') > 1): - error_message += YAML_COMMON_UNQUOTED_COLON_ERROR - # otherwise, check for some common quoting mistakes - else: - # FIXME: This needs to split on the first ':' to account for modules like lineinfile - # that may have lines that contain legitimate colons, e.g., line: 'i ALL= (ALL) NOPASSWD: ALL' - # and throw off the quote matching logic. - parts = target_line.split(":") - if len(parts) > 1: - middle = parts[1].strip() - match = False - unbalanced = False - - if middle.startswith("'") and not middle.endswith("'"): - match = True - elif middle.startswith('"') and not middle.endswith('"'): - match = True - - if (len(middle) > 0 and - middle[0] in ['"', "'"] and - middle[-1] in ['"', "'"] and - target_line.count("'") > 2 or - target_line.count('"') > 2): - unbalanced = True - - if match: - error_message += YAML_COMMON_PARTIALLY_QUOTED_LINE_ERROR - if unbalanced: - error_message += YAML_COMMON_UNBALANCED_QUOTES_ERROR - - except (IOError, TypeError): - error_message += '\n(could not open file to display line)' - except IndexError: - error_message += '\n(specified line no longer in file, maybe it changed?)' - - return error_message +class AnsibleTaskError(AnsibleError): + """Task execution failed; provides contextual information about the task.""" + + _default_message = 'Task failed.' 
class AnsiblePromptInterrupt(AnsibleError): - """User interrupt""" + """User interrupt.""" class AnsiblePromptNoninteractive(AnsibleError): - """Unable to get user input""" + """Unable to get user input.""" class AnsibleAssertionError(AnsibleError, AssertionError): - """Invalid assertion""" - pass + """Invalid assertion.""" class AnsibleOptionsError(AnsibleError): - """ bad or incomplete options passed """ - pass + """Invalid options were passed.""" + + # FIXME: This exception is used for many non-CLI related errors. + # The few cases which are CLI related should really be handled by argparse instead, at which point the exit code here can be removed. + _exit_code = ExitCode.INVALID_CLI_OPTION class AnsibleRequiredOptionError(AnsibleOptionsError): - """ bad or incomplete options passed """ - pass + """Bad or incomplete options passed.""" class AnsibleParserError(AnsibleError): - """ something was detected early that is wrong about a playbook or data file """ - pass + """A playbook or data file could not be parsed.""" + + _exit_code = ExitCode.PARSER_ERROR + + +class AnsibleFieldAttributeError(AnsibleParserError): + """Errors caused during field attribute processing.""" + + +class AnsibleJSONParserError(AnsibleParserError): + """JSON-specific parsing failure wrapping an exception raised by the JSON parser.""" + + _default_message = 'JSON parsing failed.' + _include_cause_message = False # hide the underlying cause message, it's included by `handle_exception` as needed + + @classmethod + def handle_exception(cls, exception: Exception, origin: _tags.Origin) -> t.NoReturn: + if isinstance(exception, JSONDecodeError): + origin = origin.replace(line_num=exception.lineno, col_num=exception.colno) + + message = str(exception) + + error = cls(message, obj=origin) + + raise error from exception class AnsibleInternalError(AnsibleError): - """ internal safeguards tripped, something happened in the code that should never happen """ - pass + """Internal safeguards tripped, something happened in the code that should never happen.""" class AnsibleRuntimeError(AnsibleError): - """ ansible had a problem while running a playbook """ - pass + """Ansible had a problem while running a playbook.""" class AnsibleModuleError(AnsibleRuntimeError): - """ a module failed somehow """ - pass + """A module failed somehow.""" class AnsibleConnectionFailure(AnsibleRuntimeError): - """ the transport / connection_plugin had a fatal error """ - pass + """The transport / connection_plugin had a fatal error.""" class AnsibleAuthenticationFailure(AnsibleConnectionFailure): - """invalid username/password/key""" - pass + """Invalid username/password/key.""" + + _default_message = "Failed to authenticate." class AnsibleCallbackError(AnsibleRuntimeError): - """ a callback failure """ - pass + """A callback failure.""" class AnsibleTemplateError(AnsibleRuntimeError): - """A template related error""" - pass + """A template related error.""" + + +class TemplateTrustCheckFailedError(AnsibleTemplateError): + """Raised when processing was requested on an untrusted template or expression.""" + + _default_message = 'Encountered untrusted template or expression.' 
+ _default_help_text = ('Templates and expressions must be defined by trusted sources such as playbooks or roles, ' + 'not untrusted sources such as module results.') + +class AnsibleTemplateTransformLimitError(AnsibleTemplateError): + """The internal template transform limit was exceeded.""" -class AnsibleFilterError(AnsibleTemplateError): - """ a templating failure """ - pass + _default_message = "Template transform limit exceeded." -class AnsibleLookupError(AnsibleTemplateError): - """ a lookup failure """ - pass +class AnsibleTemplateSyntaxError(AnsibleTemplateError): + """A syntax error was encountered while parsing a Jinja template or expression.""" + + +class AnsibleBrokenConditionalError(AnsibleTemplateError): + """A broken conditional with non-boolean result was used.""" + + _default_help_text = 'Broken conditionals can be temporarily allowed with the `ALLOW_BROKEN_CONDITIONALS` configuration option.' class AnsibleUndefinedVariable(AnsibleTemplateError): - """ a templating failure """ - pass + """An undefined variable was encountered while processing a template or expression.""" + + +class AnsibleValueOmittedError(AnsibleTemplateError): + """ + Raised when the result of a template operation was the Omit singleton. This exception purposely does + not derive from AnsibleError to avoid elision of the traceback, since uncaught errors of this type always + indicate a bug. + """ + + _default_message = "A template was resolved to an Omit scalar." + _default_help_text = "Callers must be prepared to handle this value. This is most likely a bug in the code requesting templating." + + +class AnsibleTemplatePluginError(AnsibleTemplateError): + """An error sourced by a template plugin (lookup/filter/test).""" + + +# deprecated: description='add deprecation warnings for these aliases' core_version='2.23' +AnsibleFilterError = AnsibleTemplatePluginError +AnsibleLookupError = AnsibleTemplatePluginError class AnsibleFileNotFound(AnsibleRuntimeError): - """ a file missing failure """ + """A file missing failure.""" - def __init__(self, message="", obj=None, show_content=True, suppress_extended_error=False, orig_exc=None, paths=None, file_name=None): + def __init__(self, message="", obj=None, show_content=True, suppress_extended_error=..., orig_exc=None, paths=None, file_name=None): self.file_name = file_name self.paths = paths @@ -322,10 +332,9 @@ class AnsibleFileNotFound(AnsibleRuntimeError): # DO NOT USE as they will probably be removed soon. # We will port the action modules in our tree to use a context manager instead. 
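Since `AnsibleValueOmittedError` requires callers of template rendering to handle whole-template omission explicitly, a plugin-side caller might look like the sketch below. Only the exception and its "callers must be prepared" contract come from the code above; the fallback-to-default policy is an assumption for illustration.

```python
# Sketch: handling whole-template omission per the AnsibleValueOmittedError
# contract above. Falling back to a caller-provided default is one possible
# policy; dropping the key from a parent container is another.
from ansible.errors import AnsibleValueOmittedError

def render_option(templar, raw_value, default=None):
    try:
        return templar.template(raw_value)
    except AnsibleValueOmittedError:
        # The entire template resolved to the Omit scalar; callers must
        # handle this case themselves rather than letting it propagate.
        return default
```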
class AnsibleAction(AnsibleRuntimeError): - """ Base Exception for Action plugin flow control """ - - def __init__(self, message="", obj=None, show_content=True, suppress_extended_error=False, orig_exc=None, result=None): + """Base Exception for Action plugin flow control.""" + def __init__(self, message="", obj=None, show_content=True, suppress_extended_error=..., orig_exc=None, result=None): super(AnsibleAction, self).__init__(message=message, obj=obj, show_content=show_content, suppress_extended_error=suppress_extended_error, orig_exc=orig_exc) if result is None: @@ -335,54 +344,87 @@ class AnsibleAction(AnsibleRuntimeError): class AnsibleActionSkip(AnsibleAction): - """ an action runtime skip""" + """An action runtime skip.""" - def __init__(self, message="", obj=None, show_content=True, suppress_extended_error=False, orig_exc=None, result=None): + def __init__(self, message="", obj=None, show_content=True, suppress_extended_error=..., orig_exc=None, result=None): super(AnsibleActionSkip, self).__init__(message=message, obj=obj, show_content=show_content, suppress_extended_error=suppress_extended_error, orig_exc=orig_exc, result=result) self.result.update({'skipped': True, 'msg': message}) class AnsibleActionFail(AnsibleAction): - """ an action runtime failure""" - def __init__(self, message="", obj=None, show_content=True, suppress_extended_error=False, orig_exc=None, result=None): + """An action runtime failure.""" + + def __init__(self, message="", obj=None, show_content=True, suppress_extended_error=..., orig_exc=None, result=None): super(AnsibleActionFail, self).__init__(message=message, obj=obj, show_content=show_content, suppress_extended_error=suppress_extended_error, orig_exc=orig_exc, result=result) - self.result.update({'failed': True, 'msg': message, 'exception': traceback.format_exc()}) + + result_overrides = {'failed': True, 'msg': message} + # deprecated: description='use sys.exception()' python_version='3.11' + if sys.exc_info()[1]: # DTFIX-RELEASE: remove this hack once TaskExecutor is no longer shucking AnsibleActionFail and returning its result + result_overrides['exception'] = traceback.format_exc() + + self.result.update(result_overrides) class _AnsibleActionDone(AnsibleAction): - """ an action runtime early exit""" - pass + """An action runtime early exit.""" class AnsiblePluginError(AnsibleError): - """ base class for Ansible plugin-related errors that do not need AnsibleError contextual data """ + """Base class for Ansible plugin-related errors that do not need AnsibleError contextual data.""" + def __init__(self, message=None, plugin_load_context=None): super(AnsiblePluginError, self).__init__(message) self.plugin_load_context = plugin_load_context class AnsiblePluginRemovedError(AnsiblePluginError): - """ a requested plugin has been removed """ - pass + """A requested plugin has been removed.""" class AnsiblePluginCircularRedirect(AnsiblePluginError): - """a cycle was detected in plugin redirection""" - pass + """A cycle was detected in plugin redirection.""" class AnsibleCollectionUnsupportedVersionError(AnsiblePluginError): - """a collection is not supported by this version of Ansible""" - pass + """A collection is not supported by this version of Ansible.""" -class AnsibleFilterTypeError(AnsibleTemplateError, TypeError): - """ a Jinja filter templating failure due to bad type""" - pass +class AnsibleTypeError(AnsibleRuntimeError, TypeError): + """Ansible-augmented TypeError subclass.""" class AnsiblePluginNotFound(AnsiblePluginError): - """ Indicates we did 
not find an Ansible plugin """ - pass + """Indicates we did not find an Ansible plugin.""" + + +class AnsibleConditionalError(AnsibleRuntimeError): + """Errors related to failed conditional expression evaluation.""" + + +class AnsibleVariableTypeError(AnsibleRuntimeError): + """An error due to attempted storage of an unsupported variable type.""" + + @classmethod + def from_value(cls, *, obj: t.Any) -> t.Self: + # avoid an incorrect error message when `obj` is a type + type_name = type(obj).__name__ if isinstance(obj, type) else native_type_name(obj) + + return cls(message=f'Type {type_name!r} is unsupported for variable storage.', obj=obj) + + +def __getattr__(name: str) -> t.Any: + """Inject import-time deprecation warnings.""" + from ..utils.display import Display + + if name == 'AnsibleFilterTypeError': + Display().deprecated( + msg="Importing 'AnsibleFilterTypeError' is deprecated.", + help_text=f"Import {AnsibleTypeError.__name__!r} instead.", + version="2.23", + ) + + return AnsibleTypeError + + raise AttributeError(f'module {__name__!r} has no attribute {name!r}') diff --git a/lib/ansible/errors/yaml_strings.py b/lib/ansible/errors/yaml_strings.py deleted file mode 100644 index cc5cfb6c45a..00000000000 --- a/lib/ansible/errors/yaml_strings.py +++ /dev/null @@ -1,138 +0,0 @@ -# (c) 2012-2014, Michael DeHaan -# -# This file is part of Ansible -# -# Ansible is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# Ansible is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with Ansible. If not, see . - -from __future__ import annotations - -__all__ = [ - 'YAML_SYNTAX_ERROR', - 'YAML_POSITION_DETAILS', - 'YAML_COMMON_DICT_ERROR', - 'YAML_COMMON_UNQUOTED_VARIABLE_ERROR', - 'YAML_COMMON_UNQUOTED_COLON_ERROR', - 'YAML_COMMON_PARTIALLY_QUOTED_LINE_ERROR', - 'YAML_COMMON_UNBALANCED_QUOTES_ERROR', -] - -YAML_SYNTAX_ERROR = """\ -Syntax Error while loading YAML. - %s""" - -YAML_POSITION_DETAILS = """\ -The error appears to be in '%s': line %s, column %s, but may -be elsewhere in the file depending on the exact syntax problem. -""" - -YAML_COMMON_DICT_ERROR = """\ -This one looks easy to fix. YAML thought it was looking for the start of a -hash/dictionary and was confused to see a second "{". Most likely this was -meant to be an ansible template evaluation instead, so we have to give the -parser a small hint that we wanted a string instead. The solution here is to -just quote the entire value. - -For instance, if the original line was: - - app_path: {{ base_path }}/foo - -It should be written as: - - app_path: "{{ base_path }}/foo" -""" - -YAML_COMMON_UNQUOTED_VARIABLE_ERROR = """\ -We could be wrong, but this one looks like it might be an issue with -missing quotes. Always quote template expression brackets when they -start a value. For instance: - - with_items: - - {{ foo }} - -Should be written as: - - with_items: - - "{{ foo }}" -""" - -YAML_COMMON_UNQUOTED_COLON_ERROR = """\ -This one looks easy to fix. There seems to be an extra unquoted colon in the line -and this is confusing the parser. It was only expecting to find one free -colon. 
The solution is just add some quotes around the colon, or quote the -entire line after the first colon. - -For instance, if the original line was: - - copy: src=file.txt dest=/path/filename:with_colon.txt - -It can be written as: - - copy: src=file.txt dest='/path/filename:with_colon.txt' - -Or: - - copy: 'src=file.txt dest=/path/filename:with_colon.txt' -""" - -YAML_COMMON_PARTIALLY_QUOTED_LINE_ERROR = """\ -This one looks easy to fix. It seems that there is a value started -with a quote, and the YAML parser is expecting to see the line ended -with the same kind of quote. For instance: - - when: "ok" in result.stdout - -Could be written as: - - when: '"ok" in result.stdout' - -Or equivalently: - - when: "'ok' in result.stdout" -""" - -YAML_COMMON_UNBALANCED_QUOTES_ERROR = """\ -We could be wrong, but this one looks like it might be an issue with -unbalanced quotes. If starting a value with a quote, make sure the -line ends with the same set of quotes. For instance this arbitrary -example: - - foo: "bad" "wolf" - -Could be written as: - - foo: '"bad" "wolf"' -""" - -YAML_COMMON_LEADING_TAB_ERROR = """\ -There appears to be a tab character at the start of the line. - -YAML does not use tabs for formatting. Tabs should be replaced with spaces. - -For example: - - name: update tooling - vars: - version: 1.2.3 -# ^--- there is a tab there. - -Should be written as: - - name: update tooling - vars: - version: 1.2.3 -# ^--- all spaces here. -""" - -YAML_AND_SHORTHAND_ERROR = """\ -There appears to be both 'k=v' shorthand syntax and YAML in this task. \ -Only one syntax may be used. -""" diff --git a/lib/ansible/executor/action_write_locks.py b/lib/ansible/executor/action_write_locks.py deleted file mode 100644 index 2934615c508..00000000000 --- a/lib/ansible/executor/action_write_locks.py +++ /dev/null @@ -1,44 +0,0 @@ -# (c) 2016 - Red Hat, Inc. -# -# This file is part of Ansible -# -# Ansible is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# Ansible is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with Ansible. If not, see . - -from __future__ import annotations - -import multiprocessing.synchronize - -from ansible.utils.multiprocessing import context as multiprocessing_context - -from ansible.module_utils.facts.system.pkg_mgr import PKG_MGRS - -if 'action_write_locks' not in globals(): - # Do not initialize this more than once because it seems to bash - # the existing one. multiprocessing must be reloading the module - # when it forks? - action_write_locks: dict[str | None, multiprocessing.synchronize.Lock] = dict() - - # Below is a Lock for use when we weren't expecting a named module. It gets used when an action - # plugin invokes a module whose name does not match with the action's name. Slightly less - # efficient as all processes with unexpected module names will wait on this lock - action_write_locks[None] = multiprocessing_context.Lock() - - # These plugins are known to be called directly by action plugins with names differing from the - # action plugin name. We precreate them here as an optimization. 
- # If a list of service managers is created in the future we can do the same for them. - mods = set(p['name'] for p in PKG_MGRS) - - mods.update(('copy', 'file', 'setup', 'slurp', 'stat')) - for mod_name in mods: - action_write_locks[mod_name] = multiprocessing_context.Lock() diff --git a/lib/ansible/executor/interpreter_discovery.py b/lib/ansible/executor/interpreter_discovery.py index f83f1c47d0a..bf168f922e2 100644 --- a/lib/ansible/executor/interpreter_discovery.py +++ b/lib/ansible/executor/interpreter_discovery.py @@ -9,7 +9,8 @@ from ansible import constants as C from ansible.errors import AnsibleError from ansible.utils.display import Display from ansible.utils.plugin_docs import get_versioned_doclink -from traceback import format_exc + +_FALLBACK_INTERPRETER = '/usr/bin/python3' display = Display() foundre = re.compile(r'FOUND(.*)ENDFOUND', flags=re.DOTALL) @@ -26,14 +27,14 @@ def discover_interpreter(action, interpreter_name, discovery_mode, task_vars): """Probe the target host for a Python interpreter from the `INTERPRETER_PYTHON_FALLBACK` list, returning the first found or `/usr/bin/python3` if none.""" host = task_vars.get('inventory_hostname', 'unknown') res = None - found_interpreters = [u'/usr/bin/python3'] # fallback value + found_interpreters = [_FALLBACK_INTERPRETER] # fallback value is_silent = discovery_mode.endswith('_silent') if discovery_mode.startswith('auto_legacy'): - action._discovery_deprecation_warnings.append(dict( + display.deprecated( msg=f"The '{discovery_mode}' option for 'INTERPRETER_PYTHON' now has the same effect as 'auto'.", version='2.21', - )) + ) try: bootstrap_python_list = C.config.get_config_value('INTERPRETER_PYTHON_FALLBACK', variables=task_vars) @@ -61,24 +62,26 @@ def discover_interpreter(action, interpreter_name, discovery_mode, task_vars): if not found_interpreters: if not is_silent: - action._discovery_warnings.append(u'No python interpreters found for ' - u'host {0} (tried {1})'.format(host, bootstrap_python_list)) + display.warning(msg=f'No python interpreters found for host {host!r} (tried {bootstrap_python_list!r}).') + # this is lame, but returning None or throwing an exception is uglier - return u'/usr/bin/python3' + return _FALLBACK_INTERPRETER except AnsibleError: raise except Exception as ex: if not is_silent: - action._discovery_warnings.append(f'Unhandled error in Python interpreter discovery for host {host}: {ex}') - display.debug(msg=f'Interpreter discovery traceback:\n{format_exc()}', host=host) + display.error_as_warning(msg=f'Unhandled error in Python interpreter discovery for host {host!r}.', exception=ex) + if res and res.get('stderr'): # the current ssh plugin implementation always has stderr, making coverage of the false case difficult display.vvv(msg=f"Interpreter discovery remote stderr:\n{res.get('stderr')}", host=host) if not is_silent: - action._discovery_warnings.append( - f"Host {host} is using the discovered Python interpreter at {found_interpreters[0]}, " - "but future installation of another Python interpreter could change the meaning of that path. " - f"See {get_versioned_doclink('reference_appendices/interpreter_discovery.html')} for more information." + display.warning( + msg=( + f"Host {host!r} is using the discovered Python interpreter at {found_interpreters[0]!r}, " + "but future installation of another Python interpreter could cause a different interpreter to be discovered." 
+ ), + help_text=f"See {get_versioned_doclink('reference_appendices/interpreter_discovery.html')} for more information.", ) return found_interpreters[0] diff --git a/lib/ansible/executor/module_common.py b/lib/ansible/executor/module_common.py index 1a79c1a29bd..d98c70ee598 100644 --- a/lib/ansible/executor/module_common.py +++ b/lib/ansible/executor/module_common.py @@ -20,45 +20,76 @@ from __future__ import annotations import ast import base64 +import dataclasses import datetime import json import os +import pathlib +import pickle import shlex -import time import zipfile import re import pkgutil +import types import typing as t from ast import AST, Import, ImportFrom from io import BytesIO +from ansible._internal import _locking +from ansible._internal._datatag import _utils +from ansible.module_utils._internal import _dataclass_validation +from ansible.module_utils.common.messages import PluginInfo +from ansible.module_utils.common.yaml import yaml_load +from ansible._internal._datatag._tags import Origin +from ansible.module_utils.common.json import Direction, get_module_encoder from ansible.release import __version__, __author__ from ansible import constants as C from ansible.errors import AnsibleError from ansible.executor.interpreter_discovery import InterpreterDiscoveryRequiredError from ansible.executor.powershell import module_manifest as ps_manifest -from ansible.module_utils.common.json import AnsibleJSONEncoder from ansible.module_utils.common.text.converters import to_bytes, to_text, to_native from ansible.plugins.become import BecomeBase from ansible.plugins.loader import module_utils_loader +from ansible._internal._templating._engine import TemplateOptions, TemplateEngine from ansible.template import Templar from ansible.utils.collection_loader._collection_finder import _get_collection_metadata, _nested_dict_get +from ansible.module_utils._internal import _json, _ansiballz +from ansible.module_utils import basic as _basic -# Must import strategy and use write_locks from there -# If we import write_locks directly then we end up binding a -# variable to the object and then it never gets updated. -from ansible.executor import action_write_locks +if t.TYPE_CHECKING: + from ansible import template as _template + from ansible.playbook.task import Task from ansible.utils.display import Display -from collections import namedtuple import importlib.util import importlib.machinery display = Display() -ModuleUtilsProcessEntry = namedtuple('ModuleUtilsProcessEntry', ['name_parts', 'is_ambiguous', 'has_redirected_child', 'is_optional']) + +@dataclasses.dataclass(frozen=True, order=True) +class _ModuleUtilsProcessEntry: + """Represents a module/module_utils item awaiting import analysis.""" + name_parts: tuple[str, ...] + is_ambiguous: bool = False + child_is_redirected: bool = False + is_optional: bool = False + + @classmethod + def from_module(cls, module: types.ModuleType, append: str | None = None) -> t.Self: + name = module.__name__ + + if append: + name += '.' 
+ append + + return cls.from_module_name(name) + + @classmethod + def from_module_name(cls, module_name: str) -> t.Self: + return cls(tuple(module_name.split('.'))) + REPLACER = b"#<>" REPLACER_VERSION = b"\"<>\"" @@ -67,348 +98,45 @@ REPLACER_WINDOWS = b"# POWERSHELL_COMMON" REPLACER_JSONARGS = b"<>" REPLACER_SELINUX = b"<>" -# We could end up writing out parameters with unicode characters so we need to -# specify an encoding for the python source file -ENCODING_STRING = u'# -*- coding: utf-8 -*-' -b_ENCODING_STRING = b'# -*- coding: utf-8 -*-' - # module_common is relative to module_utils, so fix the path _MODULE_UTILS_PATH = os.path.join(os.path.dirname(__file__), '..', 'module_utils') +_SHEBANG_PLACEHOLDER = '# shebang placeholder' # ****************************************************************************** -ANSIBALLZ_TEMPLATE = u"""%(shebang)s -%(coding)s -_ANSIBALLZ_WRAPPER = True # For test-module.py script to tell this is a ANSIBALLZ_WRAPPER -# This code is part of Ansible, but is an independent component. -# The code in this particular templatable string, and this templatable string -# only, is BSD licensed. Modules which end up using this snippet, which is -# dynamically combined together by Ansible still belong to the author of the -# module, and they may assign their own license to the complete work. -# -# Copyright (c), James Cammarata, 2016 -# Copyright (c), Toshio Kuratomi, 2016 -# -# Redistribution and use in source and binary forms, with or without modification, -# are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. -# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, -# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -def _ansiballz_main(): - import os - import os.path - - # Access to the working directory is required by Python when using pipelining, as well as for the coverage module. - # Some platforms, such as macOS, may not allow querying the working directory when using become to drop privileges. - try: - os.getcwd() - except OSError: - try: - os.chdir(os.path.expanduser('~')) - except OSError: - os.chdir('/') - -%(rlimit)s - - import sys - import __main__ - - # For some distros and python versions we pick up this script in the temporary - # directory. This leads to problems when the ansible module masks a python - # library that another import needs. We have not figured out what about the - # specific distros and python versions causes this to behave differently. 
- # - # Tested distros: - # Fedora23 with python3.4 Works - # Ubuntu15.10 with python2.7 Works - # Ubuntu15.10 with python3.4 Fails without this - # Ubuntu16.04.1 with python3.5 Fails without this - # To test on another platform: - # * use the copy module (since this shadows the stdlib copy module) - # * Turn off pipelining - # * Make sure that the destination file does not exist - # * ansible ubuntu16-test -m copy -a 'src=/etc/motd dest=/var/tmp/m' - # This will traceback in shutil. Looking at the complete traceback will show - # that shutil is importing copy which finds the ansible module instead of the - # stdlib module - scriptdir = None - try: - scriptdir = os.path.dirname(os.path.realpath(__main__.__file__)) - except (AttributeError, OSError): - # Some platforms don't set __file__ when reading from stdin - # OSX raises OSError if using abspath() in a directory we don't have - # permission to read (realpath calls abspath) - pass - - # Strip cwd from sys.path to avoid potential permissions issues - excludes = set(('', '.', scriptdir)) - sys.path = [p for p in sys.path if p not in excludes] - - import base64 - import runpy - import shutil - import tempfile - import zipfile - - if sys.version_info < (3,): - PY3 = False - else: - PY3 = True - - ZIPDATA = %(zipdata)r - - # Note: temp_path isn't needed once we switch to zipimport - def invoke_module(modlib_path, temp_path, json_params): - # When installed via setuptools (including python setup.py install), - # ansible may be installed with an easy-install.pth file. That file - # may load the system-wide install of ansible rather than the one in - # the module. sitecustomize is the only way to override that setting. - z = zipfile.ZipFile(modlib_path, mode='a') - - # py3: modlib_path will be text, py2: it's bytes. Need bytes at the end - sitecustomize = u'import sys\\nsys.path.insert(0,"%%s")\\n' %% modlib_path - sitecustomize = sitecustomize.encode('utf-8') - # Use a ZipInfo to work around zipfile limitation on hosts with - # clocks set to a pre-1980 year (for instance, Raspberry Pi) - zinfo = zipfile.ZipInfo() - zinfo.filename = 'sitecustomize.py' - zinfo.date_time = %(date_time)s - z.writestr(zinfo, sitecustomize) - z.close() - - # Put the zipped up module_utils we got from the controller first in the python path so that we - # can monkeypatch the right basic - sys.path.insert(0, modlib_path) - - # Monkeypatch the parameters into basic - from ansible.module_utils import basic - basic._ANSIBLE_ARGS = json_params -%(coverage)s - # Run the module! By importing it as '__main__', it thinks it is executing as a script - runpy.run_module(mod_name=%(module_fqn)r, init_globals=dict(_module_fqn=%(module_fqn)r, _modlib_path=modlib_path), - run_name='__main__', alter_sys=True) - - # Ansible modules must exit themselves - print('{"msg": "New-style module did not handle its own exit", "failed": true}') - sys.exit(1) - - def debug(command, zipped_mod, json_params): - # The code here normally doesn't run. It's only used for debugging on the - # remote machine. - # - # The subcommands in this function make it easier to debug ansiballz - # modules. Here's the basic steps: - # - # Run ansible with the environment variable: ANSIBLE_KEEP_REMOTE_FILES=1 and -vvv - # to save the module file remotely:: - # $ ANSIBLE_KEEP_REMOTE_FILES=1 ansible host1 -m ping -a 'data=october' -vvv - # - # Part of the verbose output will tell you where on the remote machine the - # module was written to:: - # [...] 
- # SSH: EXEC ssh -C -q -o ControlMaster=auto -o ControlPersist=60s -o KbdInteractiveAuthentication=no -o - # PreferredAuthentications=gssapi-with-mic,gssapi-keyex,hostbased,publickey -o PasswordAuthentication=no -o ConnectTimeout=10 -o - # ControlPath=/home/badger/.ansible/cp/ansible-ssh-%%h-%%p-%%r -tt rhel7 '/bin/sh -c '"'"'LANG=en_US.UTF-8 LC_ALL=en_US.UTF-8 - # LC_MESSAGES=en_US.UTF-8 /usr/bin/python /home/badger/.ansible/tmp/ansible-tmp-1461173013.93-9076457629738/ping'"'"'' - # [...] - # - # Login to the remote machine and run the module file via from the previous - # step with the explode subcommand to extract the module payload into - # source files:: - # $ ssh host1 - # $ /usr/bin/python /home/badger/.ansible/tmp/ansible-tmp-1461173013.93-9076457629738/ping explode - # Module expanded into: - # /home/badger/.ansible/tmp/ansible-tmp-1461173408.08-279692652635227/ansible - # - # You can now edit the source files to instrument the code or experiment with - # different parameter values. When you're ready to run the code you've modified - # (instead of the code from the actual zipped module), use the execute subcommand like this:: - # $ /usr/bin/python /home/badger/.ansible/tmp/ansible-tmp-1461173013.93-9076457629738/ping execute - - # Okay to use __file__ here because we're running from a kept file - basedir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'debug_dir') - args_path = os.path.join(basedir, 'args') - - if command == 'explode': - # transform the ZIPDATA into an exploded directory of code and then - # print the path to the code. This is an easy way for people to look - # at the code on the remote machine for debugging it in that - # environment - z = zipfile.ZipFile(zipped_mod) - for filename in z.namelist(): - if filename.startswith('/'): - raise Exception('Something wrong with this module zip file: should not contain absolute paths') - - dest_filename = os.path.join(basedir, filename) - if dest_filename.endswith(os.path.sep) and not os.path.exists(dest_filename): - os.makedirs(dest_filename) - else: - directory = os.path.dirname(dest_filename) - if not os.path.exists(directory): - os.makedirs(directory) - f = open(dest_filename, 'wb') - f.write(z.read(filename)) - f.close() - - # write the args file - f = open(args_path, 'wb') - f.write(json_params) - f.close() - - print('Module expanded into:') - print('%%s' %% basedir) - exitcode = 0 - - elif command == 'execute': - # Execute the exploded code instead of executing the module from the - # embedded ZIPDATA. This allows people to easily run their modified - # code on the remote machine to see how changes will affect it. - - # Set pythonpath to the debug dir - sys.path.insert(0, basedir) - - # read in the args file which the user may have modified - with open(args_path, 'rb') as f: - json_params = f.read() - - # Monkeypatch the parameters into basic - from ansible.module_utils import basic - basic._ANSIBLE_ARGS = json_params - - # Run the module! By importing it as '__main__', it thinks it is executing as a script - runpy.run_module(mod_name=%(module_fqn)r, init_globals=None, run_name='__main__', alter_sys=True) - - # Ansible modules must exit themselves - print('{"msg": "New-style module did not handle its own exit", "failed": true}') - sys.exit(1) - - else: - print('WARNING: Unknown debug command. 
Doing nothing.') - exitcode = 0 - - return exitcode - - # - # See comments in the debug() method for information on debugging - # - - ANSIBALLZ_PARAMS = %(params)s - if PY3: - ANSIBALLZ_PARAMS = ANSIBALLZ_PARAMS.encode('utf-8') - try: - # There's a race condition with the controller removing the - # remote_tmpdir and this module executing under async. So we cannot - # store this in remote_tmpdir (use system tempdir instead) - # Only need to use [ansible_module]_payload_ in the temp_path until we move to zipimport - # (this helps ansible-test produce coverage stats) - temp_path = tempfile.mkdtemp(prefix='ansible_' + %(ansible_module)r + '_payload_') - - zipped_mod = os.path.join(temp_path, 'ansible_' + %(ansible_module)r + '_payload.zip') - - with open(zipped_mod, 'wb') as modlib: - modlib.write(base64.b64decode(ZIPDATA)) - - if len(sys.argv) == 2: - exitcode = debug(sys.argv[1], zipped_mod, ANSIBALLZ_PARAMS) - else: - # Note: temp_path isn't needed once we switch to zipimport - invoke_module(zipped_mod, temp_path, ANSIBALLZ_PARAMS) - finally: - try: - shutil.rmtree(temp_path) - except (NameError, OSError): - # tempdir creation probably failed - pass - sys.exit(exitcode) - -if __name__ == '__main__': - _ansiballz_main() -""" - -ANSIBALLZ_COVERAGE_TEMPLATE = """ - os.environ['COVERAGE_FILE'] = %(coverage_output)r + '=python-%%s=coverage' %% '.'.join(str(v) for v in sys.version_info[:2]) - - import atexit - - try: - import coverage - except ImportError: - print('{"msg": "Could not import `coverage` module.", "failed": true}') - sys.exit(1) - - cov = coverage.Coverage(config_file=%(coverage_config)r) - - def atexit_coverage(): - cov.stop() - cov.save() - atexit.register(atexit_coverage) +def _strip_comments(source: str) -> str: + # Strip comments and blank lines from the wrapper + buf = [] + for line in source.splitlines(): + l = line.strip() + if (not l or l.startswith('#')) and l != _SHEBANG_PLACEHOLDER: + line = '' + buf.append(line) + return '\n'.join(buf) - cov.start() -""" -ANSIBALLZ_COVERAGE_CHECK_TEMPLATE = """ - try: - if PY3: - import importlib.util - if importlib.util.find_spec('coverage') is None: - raise ImportError - else: - import imp - imp.find_module('coverage') - except ImportError: - print('{"msg": "Could not find `coverage` module.", "failed": true}') - sys.exit(1) -""" +def _read_ansiballz_code() -> str: + code = (pathlib.Path(__file__).parent.parent / '_internal/_ansiballz.py').read_text() -ANSIBALLZ_RLIMIT_TEMPLATE = """ - import resource + if not C.DEFAULT_KEEP_REMOTE_FILES: + # Keep comments when KEEP_REMOTE_FILES is set. That way users will see + # the comments with some nice usage instructions. + # Otherwise, strip comments for smaller over the wire size. 
+ code = _strip_comments(code) - existing_soft, existing_hard = resource.getrlimit(resource.RLIMIT_NOFILE) + return code - # adjust soft limit subject to existing hard limit - requested_soft = min(existing_hard, %(rlimit_nofile)d) - if requested_soft != existing_soft: - try: - resource.setrlimit(resource.RLIMIT_NOFILE, (requested_soft, existing_hard)) - except ValueError: - # some platforms (eg macOS) lie about their hard limit - pass -""" +_ANSIBALLZ_CODE = _read_ansiballz_code() # read during startup to prevent individual workers from doing so -def _strip_comments(source): - # Strip comments and blank lines from the wrapper - buf = [] - for line in source.splitlines(): - l = line.strip() - if not l or l.startswith(u'#'): - continue - buf.append(line) - return u'\n'.join(buf) +def _get_ansiballz_code(shebang: str) -> str: + code = _ANSIBALLZ_CODE + code = code.replace(_SHEBANG_PLACEHOLDER, shebang) + return code -if C.DEFAULT_KEEP_REMOTE_FILES: - # Keep comments when KEEP_REMOTE_FILES is set. That way users will see - # the comments with some nice usage instructions - ACTIVE_ANSIBALLZ_TEMPLATE = ANSIBALLZ_TEMPLATE -else: - # ANSIBALLZ_TEMPLATE stripped of comments for smaller over the wire size - ACTIVE_ANSIBALLZ_TEMPLATE = _strip_comments(ANSIBALLZ_TEMPLATE) # dirname(dirname(dirname(site-packages/ansible/executor/module_common.py) == site-packages # Do this instead of getting site-packages from distutils.sysconfig so we work when we @@ -438,6 +166,7 @@ NEW_STYLE_PYTHON_MODULE_RE = re.compile( class ModuleDepFinder(ast.NodeVisitor): + # DTFIX-RELEASE: add support for ignoring imports with a "controller only" comment, this will allow replacing import_controller_module with standard imports def __init__(self, module_fqn, tree, is_pkg_init=False, *args, **kwargs): """ Walk the ast tree for the python module. @@ -584,7 +313,7 @@ def _slurp(path): return data -def _get_shebang(interpreter, task_vars, templar, args=tuple(), remote_is_local=False): +def _get_shebang(interpreter, task_vars, templar: _template.Templar, args=tuple(), remote_is_local=False): """ Handles the different ways ansible allows overriding the shebang target for a module. 
""" @@ -609,7 +338,8 @@ def _get_shebang(interpreter, task_vars, templar, args=tuple(), remote_is_local= elif C.config.get_configuration_definition(interpreter_config_key): interpreter_from_config = C.config.get_config_value(interpreter_config_key, variables=task_vars) - interpreter_out = templar.template(interpreter_from_config.strip()) + interpreter_out = templar._engine.template(_utils.str_problematic_strip(interpreter_from_config), + options=TemplateOptions(value_for_omit=C.config.get_config_default(interpreter_config_key))) # handle interpreter discovery if requested or empty interpreter was provided if not interpreter_out or interpreter_out in ['auto', 'auto_legacy', 'auto_silent', 'auto_legacy_silent']: @@ -627,7 +357,8 @@ def _get_shebang(interpreter, task_vars, templar, args=tuple(), remote_is_local= elif interpreter_config in task_vars: # for non python we consult vars for a possible direct override - interpreter_out = templar.template(task_vars.get(interpreter_config).strip()) + interpreter_out = templar._engine.template(_utils.str_problematic_strip(task_vars.get(interpreter_config)), + options=TemplateOptions(value_for_omit=None)) if not interpreter_out: # nothing matched(None) or in case someone configures empty string or empty intepreter @@ -806,12 +537,12 @@ class LegacyModuleUtilLocator(ModuleUtilLocatorBase): # find_spec needs the full module name self._info = info = importlib.machinery.PathFinder.find_spec('.'.join(name_parts), paths) - if info is not None and os.path.splitext(info.origin)[1] in importlib.machinery.SOURCE_SUFFIXES: + if info is not None and info.origin is not None and os.path.splitext(info.origin)[1] in importlib.machinery.SOURCE_SUFFIXES: self.is_package = info.origin.endswith('/__init__.py') path = info.origin else: return False - self.source_code = _slurp(path) + self.source_code = Origin(path=path).tag(_slurp(path)) return True @@ -846,9 +577,18 @@ class CollectionModuleUtilLocator(ModuleUtilLocatorBase): resource_base_path = os.path.join(*name_parts[3:]) src = None + # look for package_dir first, then module + src_path = to_native(os.path.join(resource_base_path, '__init__.py')) + + try: + collection_pkg = importlib.import_module(collection_pkg_name) + pkg_path = os.path.dirname(collection_pkg.__file__) + except (ImportError, AttributeError): + pkg_path = None + try: - src = pkgutil.get_data(collection_pkg_name, to_native(os.path.join(resource_base_path, '__init__.py'))) + src = pkgutil.get_data(collection_pkg_name, src_path) except ImportError: pass @@ -857,32 +597,113 @@ class CollectionModuleUtilLocator(ModuleUtilLocatorBase): if src is not None: # empty string is OK self.is_package = True else: + src_path = to_native(resource_base_path + '.py') + try: - src = pkgutil.get_data(collection_pkg_name, to_native(resource_base_path + '.py')) + src = pkgutil.get_data(collection_pkg_name, src_path) except ImportError: pass if src is None: # empty string is OK return False - self.source_code = src + # TODO: this feels brittle and funky; we should be able to more definitively assure the source path + + if pkg_path: + origin = Origin(path=os.path.join(pkg_path, src_path)) + else: + # DTFIX-RELEASE: not sure if this case is even reachable + origin = Origin(description=f'') + + self.source_code = origin.tag(src) return True def _get_module_utils_remainder_parts(self, name_parts): return name_parts[5:] # eg, foo.bar for ansible_collections.ns.coll.plugins.module_utils.foo.bar -def _make_zinfo(filename, date_time, zf=None): +def _make_zinfo(filename: str, 
 date_time: datetime.datetime, zf: zipfile.ZipFile | None = None) -> zipfile.ZipInfo: zinfo = zipfile.ZipInfo( filename=filename, - date_time=date_time + date_time=date_time.utctimetuple()[:6], ) + if zf: zinfo.compress_type = zf.compression + return zinfo -def recursive_finder(name, module_fqn, module_data, zf, date_time=None): +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class ModuleMetadata: + @classmethod + def __post_init__(cls): + _dataclass_validation.inject_post_init_validation(cls) + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class ModuleMetadataV1(ModuleMetadata): + serialization_profile: str + + +metadata_versions: dict[t.Any, type[ModuleMetadata]] = { + 1: ModuleMetadataV1, +} + + +def _get_module_metadata(module: ast.Module) -> ModuleMetadata: + # DTFIX-RELEASE: while module metadata works, this feature isn't fully baked and should be turned off before release + metadata_nodes: list[ast.Assign] = [] + + for node in module.body: + if isinstance(node, ast.Assign): + if len(node.targets) == 1: + target = node.targets[0] + + if isinstance(target, ast.Name): + if target.id == 'METADATA': + metadata_nodes.append(node) + + if not metadata_nodes: + return ModuleMetadataV1( + serialization_profile='legacy', + ) + + if len(metadata_nodes) > 1: + raise ValueError('Module METADATA must be defined only once.') + + metadata_node = metadata_nodes[0] + + if not isinstance(metadata_node.value, ast.Constant): + raise TypeError(f'Module METADATA node must be {ast.Constant} not {type(metadata_node.value)}.') + + unparsed_metadata = metadata_node.value.value + + if not isinstance(unparsed_metadata, str): + raise TypeError(f'Module METADATA must be {str} not {type(unparsed_metadata)}.') + + try: + parsed_metadata = yaml_load(unparsed_metadata) + except Exception as ex: + raise ValueError('Module METADATA must be valid YAML.') from ex + + if not isinstance(parsed_metadata, dict): + raise TypeError(f'Module METADATA must parse to {dict} not {type(parsed_metadata)}.') + + schema_version = parsed_metadata.pop('schema_version', None) + + if not (metadata_type := metadata_versions.get(schema_version)): + raise ValueError(f'Module METADATA schema_version {schema_version} is unknown.') + + try: + metadata = metadata_type(**parsed_metadata) # type: ignore + except Exception as ex: + raise ValueError('Module METADATA is invalid.') from ex + + return metadata + + +def recursive_finder(name: str, module_fqn: str, module_data: str | bytes, zf: zipfile.ZipFile, date_time: datetime.datetime) -> ModuleMetadata: """ Using ModuleDepFinder, make sure we have all of the module_utils files that the module and its module_utils files needs. (no longer actually recursive) @@ -892,9 +713,6 @@ def recursive_finder(name, module_fqn, module_data, zf, date_time=None): :arg zf: An open :python:class:`zipfile.ZipFile` object that holds the Ansible module payload which we're assembling """ - if date_time is None: - date_time = time.gmtime()[:6] - # py_module_cache maps python module names to a tuple of the code in the module # and the pathname to the module.
# Here we pre-load it with modules which we create without bothering to @@ -916,49 +734,57 @@ def recursive_finder(name, module_fqn, module_data, zf, date_time=None): module_utils_paths = [p for p in module_utils_loader._get_paths(subdirs=False) if os.path.isdir(p)] module_utils_paths.append(_MODULE_UTILS_PATH) - # Parse the module code and find the imports of ansible.module_utils - try: - tree = compile(module_data, '', 'exec', ast.PyCF_ONLY_AST) - except (SyntaxError, IndentationError) as e: - raise AnsibleError("Unable to import %s due to %s" % (name, e.msg)) - + tree = _compile_module_ast(name, module_data) + module_metadata = _get_module_metadata(tree) finder = ModuleDepFinder(module_fqn, tree) - # the format of this set is a tuple of the module name and whether or not the import is ambiguous as a module name - # or an attribute of a module (eg from x.y import z <-- is z a module or an attribute of x.y?) - modules_to_process = [ModuleUtilsProcessEntry(m, True, False, is_optional=m in finder.optional_imports) for m in finder.submodules] + if not isinstance(module_metadata, ModuleMetadataV1): + raise NotImplementedError() + + profile = module_metadata.serialization_profile - # HACK: basic is currently always required since module global init is currently tied up with AnsiballZ arg input - modules_to_process.append(ModuleUtilsProcessEntry(('ansible', 'module_utils', 'basic'), False, False, is_optional=False)) + # the format of this set is a tuple of the module name and whether the import is ambiguous as a module name + # or an attribute of a module (e.g. from x.y import z <-- is z a module or an attribute of x.y?) + modules_to_process = [_ModuleUtilsProcessEntry(m, True, False, is_optional=m in finder.optional_imports) for m in finder.submodules] + + # include module_utils that are always required + modules_to_process.extend(( + _ModuleUtilsProcessEntry.from_module(_ansiballz), + _ModuleUtilsProcessEntry.from_module(_basic), + _ModuleUtilsProcessEntry.from_module_name(_json.get_module_serialization_profile_module_name(profile, True)), + _ModuleUtilsProcessEntry.from_module_name(_json.get_module_serialization_profile_module_name(profile, False)), + )) + + module_info: ModuleUtilLocatorBase # we'll be adding new modules inline as we discover them, so just keep going til we've processed them all while modules_to_process: modules_to_process.sort() # not strictly necessary, but nice to process things in predictable and repeatable order - py_module_name, is_ambiguous, child_is_redirected, is_optional = modules_to_process.pop(0) + entry = modules_to_process.pop(0) - if py_module_name in py_module_cache: + if entry.name_parts in py_module_cache: # this is normal; we'll often see the same module imported many times, but we only need to process it once continue - if py_module_name[0:2] == ('ansible', 'module_utils'): - module_info = LegacyModuleUtilLocator(py_module_name, is_ambiguous=is_ambiguous, - mu_paths=module_utils_paths, child_is_redirected=child_is_redirected) - elif py_module_name[0] == 'ansible_collections': - module_info = CollectionModuleUtilLocator(py_module_name, is_ambiguous=is_ambiguous, - child_is_redirected=child_is_redirected, is_optional=is_optional) + if entry.name_parts[0:2] == ('ansible', 'module_utils'): + module_info = LegacyModuleUtilLocator(entry.name_parts, is_ambiguous=entry.is_ambiguous, + mu_paths=module_utils_paths, child_is_redirected=entry.child_is_redirected) + elif entry.name_parts[0] == 'ansible_collections': + module_info = 
CollectionModuleUtilLocator(entry.name_parts, is_ambiguous=entry.is_ambiguous, + child_is_redirected=entry.child_is_redirected, is_optional=entry.is_optional) else: # FIXME: dot-joined result display.warning('ModuleDepFinder improperly found a non-module_utils import %s' - % [py_module_name]) + % [entry.name_parts]) continue # Could not find the module. Construct a helpful error message. if not module_info.found: - if is_optional: + if entry.is_optional: # this was a best-effort optional import that we couldn't find, oh well, move along... continue # FIXME: use dot-joined candidate names - msg = 'Could not find imported module support code for {0}. Looked for ({1})'.format(module_fqn, module_info.candidate_names_joined) + msg = 'Could not find imported module support code for {0}. Looked for ({1})'.format(module_fqn, module_info.candidate_names_joined) raise AnsibleError(msg) # check the cache one more time with the module we actually found, since the name could be different than the input @@ -966,14 +792,9 @@ def recursive_finder(name, module_fqn, module_data, zf, date_time=None): if module_info.fq_name_parts in py_module_cache: continue - # compile the source, process all relevant imported modules - try: - tree = compile(module_info.source_code, '', 'exec', ast.PyCF_ONLY_AST) - except (SyntaxError, IndentationError) as e: - raise AnsibleError("Unable to import %s due to %s" % (module_info.fq_name_parts, e.msg)) - + tree = _compile_module_ast('.'.join(module_info.fq_name_parts), module_info.source_code) finder = ModuleDepFinder('.'.join(module_info.fq_name_parts), tree, module_info.is_package) - modules_to_process.extend(ModuleUtilsProcessEntry(m, True, False, is_optional=m in finder.optional_imports) + modules_to_process.extend(_ModuleUtilsProcessEntry(m, True, False, is_optional=m in finder.optional_imports) for m in finder.submodules if m not in py_module_cache) # we've processed this item, add it to the output list @@ -985,7 +806,7 @@ def recursive_finder(name, module_fqn, module_data, zf, date_time=None): accumulated_pkg_name.append(pkg) # we're accumulating this across iterations normalized_name = tuple(accumulated_pkg_name) # extra machinations to get a hashable type (list is not) if normalized_name not in py_module_cache: - modules_to_process.append(ModuleUtilsProcessEntry(normalized_name, False, module_info.redirected, is_optional=is_optional)) + modules_to_process.append(_ModuleUtilsProcessEntry(normalized_name, False, module_info.redirected, is_optional=entry.is_optional)) for py_module_name in py_module_cache: py_module_file_name = py_module_cache[py_module_name][1] @@ -997,8 +818,23 @@ def recursive_finder(name, module_fqn, module_data, zf, date_time=None): mu_file = to_text(py_module_file_name, errors='surrogate_or_strict') display.vvvvv("Including module_utils file %s" % mu_file) + return module_metadata + + +def _compile_module_ast(module_name: str, source_code: str | bytes) -> ast.Module: + origin = Origin.get_tag(source_code) or Origin.UNKNOWN + + # compile the source, process all relevant imported modules + try: + tree = t.cast(ast.Module, compile(source_code, str(origin), 'exec', ast.PyCF_ONLY_AST)) + except SyntaxError as ex: + raise AnsibleError(f"Unable to compile {module_name!r}.", obj=origin.replace(line_num=ex.lineno, col_num=ex.offset)) from ex + + return tree + def _is_binary(b_module_data): + """Heuristic to classify a file as binary by sniffing a 1k header; see https://stackoverflow.com/a/7392391""" textchars = bytearray(set([7, 8, 9, 10, 12, 13, 27]) | 
set(range(0x20, 0x100)) - set([0x7f])) start = b_module_data[:1024] return bool(start.translate(None, textchars)) @@ -1037,7 +873,7 @@ def _get_ansible_module_fqn(module_path): return remote_module_fqn -def _add_module_to_zip(zf, date_time, remote_module_fqn, b_module_data): +def _add_module_to_zip(zf: zipfile.ZipFile, date_time: datetime.datetime, remote_module_fqn: str, b_module_data: bytes) -> None: """Add a module from ansible or from an ansible collection into the module zip""" module_path_parts = remote_module_fqn.split('.') @@ -1048,6 +884,8 @@ def _add_module_to_zip(zf, date_time, remote_module_fqn, b_module_data): b_module_data ) + existing_paths: frozenset[str] + # Write the __init__.py's necessary to get there if module_path_parts[0] == 'ansible': # The ansible namespace is setup as part of the module_utils setup... @@ -1071,19 +909,53 @@ def _add_module_to_zip(zf, date_time, remote_module_fqn, b_module_data): ) +@dataclasses.dataclass(kw_only=True, slots=True, frozen=True) +class _BuiltModule: + """Payload required to execute an Ansible module, along with information required to do so.""" + b_module_data: bytes + module_style: t.Literal['binary', 'new', 'non_native_want_json', 'old'] + shebang: str | None + serialization_profile: str + + +@dataclasses.dataclass(kw_only=True, slots=True, frozen=True) +class _CachedModule: + """Cached Python module created by AnsiballZ.""" + + # DTFIX-RELEASE: secure this (locked down pickle, don't use pickle, etc.) + + zip_data: bytes + metadata: ModuleMetadata + + def dump(self, path: str) -> None: + temp_path = pathlib.Path(path + '-part') + + with temp_path.open('wb') as cache_file: + pickle.dump(self, cache_file) + + temp_path.rename(path) + + @classmethod + def load(cls, path: str) -> t.Self: + with pathlib.Path(path).open('rb') as cache_file: + return pickle.load(cache_file) + + def _find_module_utils( - module_name: str, - b_module_data: bytes, - module_path: str, - module_args: dict[object, object], - task_vars: dict[str, object], - templar: Templar, - module_compression: str, - async_timeout: int, - become_plugin: BecomeBase | None, - environment: dict[str, str], - remote_is_local: bool = False, -) -> tuple[bytes, t.Literal['binary', 'new', 'non_native_want_json', 'old'], str | None]: + *, + module_name: str, + plugin: PluginInfo, + b_module_data: bytes, + module_path: str, + module_args: dict[object, object], + task_vars: dict[str, object], + templar: Templar, + module_compression: str, + async_timeout: int, + become_plugin: BecomeBase | None, + environment: dict[str, str], + remote_is_local: bool = False +) -> _BuiltModule: """ Given the source of the module, convert it to a Jinja2 template to insert module code and return whether it's a new or old style module. 
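The `_CachedModule` helper above persists the assembled AnsiballZ payload with the classic write-to-a-temp-file-then-rename pattern, so a process that loses the creation race never observes a partially written cache entry. A minimal standalone sketch of that pattern follows; the names (`CacheEntry`, `dump_entry`, `load_entry`, the demo path) are illustrative, not the actual ansible-core API:

```python
import dataclasses
import pathlib
import pickle


@dataclasses.dataclass(frozen=True)
class CacheEntry:
    """Illustrative stand-in for _CachedModule: payload bytes plus metadata."""
    zip_data: bytes
    serialization_profile: str


def dump_entry(entry: CacheEntry, path: str) -> None:
    # Write to a '-part' sibling first, then rename into place. The rename is
    # atomic on POSIX filesystems, so readers see the old file or the new one,
    # never a half-written file.
    temp_path = pathlib.Path(path + '-part')
    with temp_path.open('wb') as cache_file:
        pickle.dump(entry, cache_file)
    temp_path.rename(path)


def load_entry(path: str) -> CacheEntry:
    with pathlib.Path(path).open('rb') as cache_file:
        return pickle.load(cache_file)


dump_entry(CacheEntry(zip_data=b'PK...', serialization_profile='legacy'), '/tmp/ansiballz-demo.cache')
print(load_entry('/tmp/ansiballz-demo.cache').serialization_profile)  # legacy
```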
@@ -1130,7 +1002,12 @@ def _find_module_utils( # Neither old-style, non_native_want_json nor binary modules should be modified # except for the shebang line (Done by modify_module) if module_style in ('old', 'non_native_want_json', 'binary'): - return b_module_data, module_style, shebang + return _BuiltModule( + b_module_data=b_module_data, + module_style=module_style, + shebang=shebang, + serialization_profile='legacy', + ) output = BytesIO() @@ -1146,15 +1023,9 @@ def _find_module_utils( remote_module_fqn = 'ansible.modules.%s' % module_name if module_substyle == 'python': - date_time = time.gmtime()[:6] - if date_time[0] < 1980: - date_string = datetime.datetime(*date_time, tzinfo=datetime.timezone.utc).strftime('%c') - raise AnsibleError(f'Cannot create zipfile due to pre-1980 configured date: {date_string}') - params = dict(ANSIBLE_MODULE_ARGS=module_args,) - try: - python_repred_params = repr(json.dumps(params, cls=AnsibleJSONEncoder, vault_to_text=True)) - except TypeError as e: - raise AnsibleError("Unable to pass options to module, they must be JSON serializable: %s" % to_native(e)) + date_time = datetime.datetime.now(datetime.timezone.utc) + if date_time.year < 1980: + raise AnsibleError(f'Cannot create zipfile due to pre-1980 configured date: {date_time}') try: compression_method = getattr(zipfile, module_compression) @@ -1165,27 +1036,21 @@ def _find_module_utils( lookup_path = os.path.join(C.DEFAULT_LOCAL_TMP, 'ansiballz_cache') # type: ignore[attr-defined] cached_module_filename = os.path.join(lookup_path, "%s-%s" % (remote_module_fqn, module_compression)) - zipdata = None + os.makedirs(os.path.dirname(cached_module_filename), exist_ok=True) + + zipdata: bytes | None = None + module_metadata: ModuleMetadata | None = None + # Optimization -- don't lock if the module has already been cached if os.path.exists(cached_module_filename): display.debug('ANSIBALLZ: using cached module: %s' % cached_module_filename) - with open(cached_module_filename, 'rb') as module_data: - zipdata = module_data.read() + cached_module = _CachedModule.load(cached_module_filename) + zipdata, module_metadata = cached_module.zip_data, cached_module.metadata else: - if module_name in action_write_locks.action_write_locks: - display.debug('ANSIBALLZ: Using lock for %s' % module_name) - lock = action_write_locks.action_write_locks[module_name] - else: - # If the action plugin directly invokes the module (instead of - # going through a strategy) then we don't have a cross-process - # Lock specifically for this module. 
Use the "unexpected - # module" lock instead - display.debug('ANSIBALLZ: Using generic lock for %s' % module_name) - lock = action_write_locks.action_write_locks[None] - display.debug('ANSIBALLZ: Acquiring lock') - with lock: - display.debug('ANSIBALLZ: Lock acquired: %s' % id(lock)) + lock_path = f'{cached_module_filename}.lock' + with _locking.named_mutex(lock_path): + display.debug(f'ANSIBALLZ: Lock acquired: {lock_path}') # Check that no other process has created this while we were # waiting for the lock if not os.path.exists(cached_module_filename): @@ -1195,7 +1060,7 @@ def _find_module_utils( zf = zipfile.ZipFile(zipoutput, mode='w', compression=compression_method) # walk the module imports, looking for module_utils to send- they'll be added to the zipfile - recursive_finder(module_name, remote_module_fqn, b_module_data, zf, date_time) + module_metadata = recursive_finder(module_name, remote_module_fqn, Origin(path=module_path).tag(b_module_data), zf, date_time) display.debug('ANSIBALLZ: Writing module into payload') _add_module_to_zip(zf, date_time, remote_module_fqn, b_module_data) @@ -1206,42 +1071,24 @@ def _find_module_utils( # Write the assembled module to a temp file (write to temp # so that no one looking for the file reads a partially # written file) - # - # FIXME: Once split controller/remote is merged, this can be simplified to - # os.makedirs(lookup_path, exist_ok=True) - if not os.path.exists(lookup_path): - try: - # Note -- if we have a global function to setup, that would - # be a better place to run this - os.makedirs(lookup_path) - except OSError: - # Multiple processes tried to create the directory. If it still does not - # exist, raise the original exception. - if not os.path.exists(lookup_path): - raise + os.makedirs(lookup_path, exist_ok=True) display.debug('ANSIBALLZ: Writing module') - with open(cached_module_filename + '-part', 'wb') as f: - f.write(zipdata) - - # Rename the file into its final position in the cache so - # future users of this module can read it off the - # filesystem instead of constructing from scratch. - display.debug('ANSIBALLZ: Renaming module') - os.rename(cached_module_filename + '-part', cached_module_filename) + cached_module = _CachedModule(zip_data=zipdata, metadata=module_metadata) + cached_module.dump(cached_module_filename) display.debug('ANSIBALLZ: Done creating module') - if zipdata is None: + if not zipdata: display.debug('ANSIBALLZ: Reading module after lock') # Another process wrote the file while we were waiting for # the write lock. Go ahead and read the data from disk # instead of re-creating it. try: - with open(cached_module_filename, 'rb') as f: - zipdata = f.read() + cached_module = _CachedModule.load(cached_module_filename) except IOError: raise AnsibleError('A different worker process failed to create module file. 
' 'Look at traceback for that process for debugging information.') - zipdata = to_text(zipdata, errors='surrogate_or_strict') + + zipdata, module_metadata = cached_module.zip_data, cached_module.metadata o_interpreter, o_args = _extract_interpreter(b_module_data) if o_interpreter is None: @@ -1253,48 +1100,56 @@ def _find_module_utils( rlimit_nofile = C.config.get_config_value('PYTHON_MODULE_RLIMIT_NOFILE', variables=task_vars) if not isinstance(rlimit_nofile, int): - rlimit_nofile = int(templar.template(rlimit_nofile)) - - if rlimit_nofile: - rlimit = ANSIBALLZ_RLIMIT_TEMPLATE % dict( - rlimit_nofile=rlimit_nofile, - ) - else: - rlimit = '' + rlimit_nofile = int(templar._engine.template(rlimit_nofile, options=TemplateOptions(value_for_omit=0))) coverage_config = os.environ.get('_ANSIBLE_COVERAGE_CONFIG') if coverage_config: coverage_output = os.environ['_ANSIBLE_COVERAGE_OUTPUT'] - - if coverage_output: - # Enable code coverage analysis of the module. - # This feature is for internal testing and may change without notice. - coverage = ANSIBALLZ_COVERAGE_TEMPLATE % dict( - coverage_config=coverage_config, - coverage_output=coverage_output, - ) - else: - # Verify coverage is available without importing it. - # This will detect when a module would fail with coverage enabled with minimal overhead. - coverage = ANSIBALLZ_COVERAGE_CHECK_TEMPLATE else: - coverage = '' + coverage_output = None + + if not isinstance(module_metadata, ModuleMetadataV1): + raise NotImplementedError() - output.write(to_bytes(ACTIVE_ANSIBALLZ_TEMPLATE % dict( - zipdata=zipdata, + params = dict(ANSIBLE_MODULE_ARGS=module_args,) + encoder = get_module_encoder(module_metadata.serialization_profile, Direction.CONTROLLER_TO_MODULE) + try: + encoded_params = json.dumps(params, cls=encoder) + except TypeError as ex: + raise AnsibleError(f'Failed to serialize arguments for the {module_name!r} module.') from ex + + code = _get_ansiballz_code(shebang) + args = dict( + zipdata=to_text(zipdata), ansible_module=module_name, module_fqn=remote_module_fqn, - params=python_repred_params, - shebang=shebang, - coding=ENCODING_STRING, + params=encoded_params, + profile=module_metadata.serialization_profile, + plugin_info_dict=dataclasses.asdict(plugin), date_time=date_time, - coverage=coverage, - rlimit=rlimit, - ))) + coverage_config=coverage_config, + coverage_output=coverage_output, + rlimit_nofile=rlimit_nofile, + ) + + args_string = '\n'.join(f'{key}={value!r},' for key, value in args.items()) + + wrapper = f"""{code} + +if __name__ == "__main__": + _ansiballz_main( +{args_string} +) +""" + + output.write(to_bytes(wrapper)) + b_module_data = output.getvalue() elif module_substyle == 'powershell': + module_metadata = ModuleMetadataV1(serialization_profile='legacy') # DTFIX-FUTURE: support serialization profiles for PowerShell modules + # Powershell/winrm don't actually make use of shebang so we can # safely set this here. 
If we let the fallback code handle this # it can fail in the presence of the UTF8 BOM commonly added by @@ -1312,10 +1167,12 @@ def _find_module_utils( become_plugin=become_plugin, substyle=module_substyle, task_vars=task_vars, + profile=module_metadata.serialization_profile, ) elif module_substyle == 'jsonargs': - module_args_json = to_bytes(json.dumps(module_args, cls=AnsibleJSONEncoder, vault_to_text=True)) + encoder = get_module_encoder('legacy', Direction.CONTROLLER_TO_MODULE) + module_args_json = to_bytes(json.dumps(module_args, cls=encoder)) # these strings could be included in a third-party module but # officially they were included in the 'basic' snippet for new-style @@ -1338,7 +1195,19 @@ def _find_module_utils( facility = b'syslog.' + to_bytes(syslog_facility, errors='surrogate_or_strict') b_module_data = b_module_data.replace(b'syslog.LOG_USER', facility) - return (b_module_data, module_style, shebang) + module_metadata = ModuleMetadataV1(serialization_profile='legacy') + else: + module_metadata = ModuleMetadataV1(serialization_profile='legacy') + + if not isinstance(module_metadata, ModuleMetadataV1): + raise NotImplementedError(type(module_metadata)) + + return _BuiltModule( + b_module_data=b_module_data, + module_style=module_style, + shebang=shebang, + serialization_profile=module_metadata.serialization_profile, + ) def _extract_interpreter(b_module_data): @@ -1364,8 +1233,20 @@ def _extract_interpreter(b_module_data): return interpreter, args -def modify_module(module_name, module_path, module_args, templar, task_vars=None, module_compression='ZIP_STORED', async_timeout=0, - become_plugin=None, environment=None, remote_is_local=False): +def modify_module( + *, + module_name: str, + plugin: PluginInfo, + module_path, + module_args, + templar, + task_vars=None, + module_compression='ZIP_STORED', + async_timeout=0, + become_plugin=None, + environment=None, + remote_is_local=False, +) -> _BuiltModule: """ Used to insert chunks of code into modules before transfer rather than doing regular python imports. This allows for more efficient transfer in @@ -1394,22 +1275,31 @@ def modify_module(module_name, module_path, module_args, templar, task_vars=None # read in the module source b_module_data = f.read() - (b_module_data, module_style, shebang) = _find_module_utils( - module_name, - b_module_data, - module_path, - module_args, - task_vars, - templar, - module_compression, + module_bits = _find_module_utils( + module_name=module_name, + plugin=plugin, + b_module_data=b_module_data, + module_path=module_path, + module_args=module_args, + task_vars=task_vars, + templar=templar, + module_compression=module_compression, async_timeout=async_timeout, become_plugin=become_plugin, environment=environment, remote_is_local=remote_is_local, ) - if module_style == 'binary': - return (b_module_data, module_style, to_text(shebang, nonstring='passthru')) + b_module_data = module_bits.b_module_data + shebang = module_bits.shebang + + if module_bits.module_style == 'binary': + return _BuiltModule( + b_module_data=module_bits.b_module_data, + module_style=module_bits.module_style, + shebang=to_text(module_bits.shebang, nonstring='passthru'), + serialization_profile=module_bits.serialization_profile, + ) elif shebang is None: interpreter, args = _extract_interpreter(b_module_data) # No interpreter/shebang, assume a binary module? 
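Both the AnsiballZ and `jsonargs` paths above serialize module args through a profile-specific `json.JSONEncoder` subclass obtained from `get_module_encoder()`. The sketch below illustrates the general mechanism with a hypothetical `ExampleEncoder`; the real profile encoders live in `ansible.module_utils.common.json` and are not reproduced here:

```python
import datetime
import json


class ExampleEncoder(json.JSONEncoder):
    """Hypothetical profile encoder: dates and datetimes become ISO 8601 strings."""

    def default(self, o):
        if isinstance(o, (datetime.date, datetime.datetime, datetime.time)):
            return o.isoformat()
        return super().default(o)  # raises TypeError for unsupported types


module_args = {'path': '/tmp/demo', 'mtime': datetime.datetime(2025, 1, 1, 12, 0)}

# Mirrors the params = dict(ANSIBLE_MODULE_ARGS=...) shape used above.
try:
    encoded = json.dumps({'ANSIBLE_MODULE_ARGS': module_args}, cls=ExampleEncoder)
except TypeError as ex:
    raise RuntimeError('module args must be serializable by the selected profile') from ex

print(encoded)  # {"ANSIBLE_MODULE_ARGS": {"path": "/tmp/demo", "mtime": "2025-01-01T12:00:00"}}
```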
@@ -1423,15 +1313,20 @@ def modify_module(module_name, module_path, module_args, templar, task_vars=None if interpreter != new_interpreter: b_lines[0] = to_bytes(shebang, errors='surrogate_or_strict', nonstring='passthru') - if os.path.basename(interpreter).startswith(u'python'): - b_lines.insert(1, b_ENCODING_STRING) - b_module_data = b"\n".join(b_lines) - return (b_module_data, module_style, shebang) + return _BuiltModule( + b_module_data=b_module_data, + module_style=module_bits.module_style, + shebang=shebang, + serialization_profile=module_bits.serialization_profile, + ) + +def _get_action_arg_defaults(action: str, task: Task, templar: TemplateEngine) -> dict[str, t.Any]: + action_groups = task._parent._play._action_groups + defaults = task.module_defaults -def get_action_args_with_defaults(action, args, defaults, templar, action_groups=None): # Get the list of groups that contain this action if action_groups is None: msg = ( @@ -1444,7 +1339,7 @@ def get_action_args_with_defaults(action, args, defaults, templar, action_groups else: group_names = action_groups.get(action, []) - tmp_args = {} + tmp_args: dict[str, t.Any] = {} module_defaults = {} # Merge latest defaults into dict, since they are a list of dicts @@ -1452,18 +1347,20 @@ def get_action_args_with_defaults(action, args, defaults, templar, action_groups for default in defaults: module_defaults.update(default) - # module_defaults keys are static, but the values may be templated - module_defaults = templar.template(module_defaults) for default in module_defaults: if default.startswith('group/'): group_name = default.split('group/')[-1] if group_name in group_names: - tmp_args.update((module_defaults.get('group/%s' % group_name) or {}).copy()) + tmp_args.update(templar.resolve_to_container(module_defaults.get(f'group/{group_name}', {}))) # handle specific action defaults - tmp_args.update(module_defaults.get(action, {}).copy()) - - # direct args override all - tmp_args.update(args) + tmp_args.update(templar.resolve_to_container(module_defaults.get(action, {}))) return tmp_args + + +def _apply_action_arg_defaults(action: str, task: Task, action_args: dict[str, t.Any], templar: Templar) -> dict[str, t.Any]: + args = _get_action_arg_defaults(action, task, templar._engine) + args.update(action_args) + + return args diff --git a/lib/ansible/executor/play_iterator.py b/lib/ansible/executor/play_iterator.py index 54ed6ca3b1f..69d0b00b0e7 100644 --- a/lib/ansible/executor/play_iterator.py +++ b/lib/ansible/executor/play_iterator.py @@ -155,9 +155,6 @@ class PlayIterator: setup_block.run_once = False setup_task = Task(block=setup_block) setup_task.action = 'gather_facts' - # TODO: hardcoded resolution here, but should use actual resolution code in the end, - # in case of 'legacy' mismatch - setup_task.resolved_action = 'ansible.builtin.gather_facts' setup_task.name = 'Gathering Facts' setup_task.args = {} @@ -255,7 +252,6 @@ class PlayIterator: self.set_state_for_host(host.name, s) display.debug("done getting next task for host %s" % host.name) - display.debug(" ^ task is: %s" % task) display.debug(" ^ state is: %s" % s) return (s, task) @@ -292,7 +288,7 @@ class PlayIterator: if (gathering == 'implicit' and implied) or \ (gathering == 'explicit' and boolean(self._play.gather_facts, strict=False)) or \ - (gathering == 'smart' and implied and not (self._variable_manager._fact_cache.get(host.name, {}).get('_ansible_facts_gathered', False))): + (gathering == 'smart' and implied and not 
self._variable_manager._facts_gathered_for_host(host.name)): # The setup block is always self._blocks[0], as we inject it # during the play compilation in __init__ above. setup_block = self._blocks[0] @@ -450,8 +446,7 @@ class PlayIterator: # skip implicit flush_handlers if there are no handlers notified if ( task.implicit - and task.action in C._ACTION_META - and task.args.get('_raw_params', None) == 'flush_handlers' + and task._get_meta() == 'flush_handlers' and ( # the state store in the `state` variable could be a nested state, # notifications are always stored in the top level state, get it here diff --git a/lib/ansible/executor/playbook_executor.py b/lib/ansible/executor/playbook_executor.py index 468c4bdc709..78329df342f 100644 --- a/lib/ansible/executor/playbook_executor.py +++ b/lib/ansible/executor/playbook_executor.py @@ -26,7 +26,7 @@ from ansible.module_utils.common.text.converters import to_text from ansible.module_utils.parsing.convert_bool import boolean from ansible.plugins.loader import become_loader, connection_loader, shell_loader from ansible.playbook import Playbook -from ansible.template import Templar +from ansible._internal._templating._engine import TemplateEngine from ansible.utils.helpers import pct_to_int from ansible.utils.collection_loader import AnsibleCollectionConfig from ansible.utils.collection_loader._collection_finder import _get_collection_name_from_path, _get_collection_playbook_path @@ -132,7 +132,7 @@ class PlaybookExecutor: # Allow variables to be used in vars_prompt fields. all_vars = self._variable_manager.get_vars(play=play) - templar = Templar(loader=self._loader, variables=all_vars) + templar = TemplateEngine(loader=self._loader, variables=all_vars) setattr(play, 'vars_prompt', templar.template(play.vars_prompt)) # FIXME: this should be a play 'sub object' like loop_control @@ -158,7 +158,7 @@ class PlaybookExecutor: # Post validate so any play level variables are templated all_vars = self._variable_manager.get_vars(play=play) - templar = Templar(loader=self._loader, variables=all_vars) + templar = TemplateEngine(loader=self._loader, variables=all_vars) play.post_validate(templar) if context.CLIARGS['syntax']: diff --git a/lib/ansible/executor/powershell/module_manifest.py b/lib/ansible/executor/powershell/module_manifest.py index 716ea122624..490fd3b6c2b 100644 --- a/lib/ansible/executor/powershell/module_manifest.py +++ b/lib/ansible/executor/powershell/module_manifest.py @@ -12,11 +12,13 @@ import pkgutil import secrets import re import typing as t + from importlib import import_module from ansible.module_utils.compat.version import LooseVersion from ansible import constants as C +from ansible.module_utils.common.json import Direction, get_module_encoder from ansible.errors import AnsibleError, AnsibleFileNotFound from ansible.module_utils.common.text.converters import to_bytes, to_text from ansible.plugins.become import BecomeBase @@ -351,6 +353,7 @@ def _create_powershell_wrapper( become_plugin: BecomeBase | None, substyle: t.Literal["powershell", "script"], task_vars: dict[str, t.Any], + profile: str, ) -> bytes: """Creates module or script wrapper for PowerShell. @@ -369,8 +372,6 @@ def _create_powershell_wrapper( :return: The input data for bootstrap_wrapper.ps1 as a byte string. 
""" - # creates the manifest/wrapper used in PowerShell/C# modules to enable - # things like become and async - this is also called in action/script.py actions: list[_ManifestAction] = [] finder = PSModuleDepFinder() @@ -405,7 +406,7 @@ def _create_powershell_wrapper( 'Variables': [ { 'Name': 'complex_args', - 'Value': module_args, + 'Value': _prepare_module_args(module_args, profile), 'Scope': 'Global', }, ], @@ -540,3 +541,13 @@ def _get_bootstrap_input( bootstrap_input = json.dumps(bootstrap_manifest, ensure_ascii=True) exec_input = json.dumps(dataclasses.asdict(manifest)) return f"{bootstrap_input}\n\0\0\0\0\n{exec_input}".encode() + + +def _prepare_module_args(module_args: dict[str, t.Any], profile: str) -> dict[str, t.Any]: + """ + Serialize the module args with the specified profile and deserialize them with the Python built-in JSON decoder. + This is used to facilitate serializing module args with a different encoder (profile) than is used for the manifest. + """ + encoder = get_module_encoder(profile, Direction.CONTROLLER_TO_MODULE) + + return json.loads(json.dumps(module_args, cls=encoder)) diff --git a/lib/ansible/executor/process/worker.py b/lib/ansible/executor/process/worker.py index 55eda53c855..96fb2c687cf 100644 --- a/lib/ansible/executor/process/worker.py +++ b/lib/ansible/executor/process/worker.py @@ -28,6 +28,7 @@ import typing as t from multiprocessing.queues import Queue from ansible import context +from ansible._internal import _task from ansible.errors import AnsibleConnectionFailure, AnsibleError from ansible.executor.task_executor import TaskExecutor from ansible.executor.task_queue_manager import FinalQueue, STDIN_FILENO, STDOUT_FILENO, STDERR_FILENO @@ -39,6 +40,7 @@ from ansible.playbook.task import Task from ansible.playbook.play_context import PlayContext from ansible.plugins.loader import init_plugin_loader from ansible.utils.context_objects import CLIArgs +from ansible.plugins.action import ActionBase from ansible.utils.display import Display from ansible.utils.multiprocessing import context as multiprocessing_context from ansible.vars.manager import VariableManager @@ -189,7 +191,8 @@ class WorkerProcess(multiprocessing_context.Process): # type: ignore[name-defin display.set_queue(self._final_q) self._detach() try: - return self._run() + with _task.TaskContext(self._task): + return self._run() except BaseException: self._hard_exit(traceback.format_exc()) @@ -259,20 +262,17 @@ class WorkerProcess(multiprocessing_context.Process): # type: ignore[name-defin executor_result, task_fields=self._task.dump_attrs(), ) - except Exception as e: - display.debug(f'failed to send task result ({e}), sending surrogate result') - self._final_q.send_task_result( - self._host.name, - self._task._uuid, - # Overriding the task result, to represent the failure - { - 'failed': True, - 'msg': f'{e}', - 'exception': traceback.format_exc(), - }, - # The failure pickling may have been caused by the task attrs, omit for safety - {}, - ) + except Exception as ex: + try: + raise AnsibleError("Task result omitted due to queue send failure.") from ex + except Exception as ex_wrapper: + self._final_q.send_task_result( + self._host.name, + self._task._uuid, + ActionBase.result_dict_from_exception(ex_wrapper), # Overriding the task result, to represent the failure + {}, # The failure pickling may have been caused by the task attrs, omit for safety + ) + display.debug("done sending task result for task %s" % self._task._uuid) except AnsibleConnectionFailure: diff --git 
a/lib/ansible/executor/task_executor.py b/lib/ansible/executor/task_executor.py index d7b64edb232..ef292dac9f7 100644 --- a/lib/ansible/executor/task_executor.py +++ b/lib/ansible/executor/task_executor.py @@ -10,29 +10,39 @@ import pathlib import signal import subprocess import sys + import traceback +import typing as t from ansible import constants as C from ansible.cli import scripts -from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleConnectionFailure, AnsibleActionFail, AnsibleActionSkip +from ansible.errors import ( + AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleConnectionFailure, AnsibleActionFail, AnsibleActionSkip, AnsibleTaskError, + AnsibleValueOmittedError, +) from ansible.executor.task_result import TaskResult -from ansible.executor.module_common import get_action_args_with_defaults +from ansible._internal._datatag import _utils +from ansible.module_utils._internal._plugin_exec_context import PluginExecContext +from ansible.module_utils.common.messages import Detail, WarningSummary, DeprecationSummary +from ansible.module_utils.datatag import native_type_name +from ansible._internal._datatag._tags import TrustedAsTemplate from ansible.module_utils.parsing.convert_bool import boolean -from ansible.module_utils.six import binary_type from ansible.module_utils.common.text.converters import to_text, to_native from ansible.module_utils.connection import write_to_stream from ansible.module_utils.six import string_types -from ansible.playbook.conditional import Conditional from ansible.playbook.task import Task from ansible.plugins import get_plugin_class +from ansible.plugins.action import ActionBase from ansible.plugins.loader import become_loader, cliconf_loader, connection_loader, httpapi_loader, netconf_loader, terminal_loader +from ansible._internal._templating._jinja_plugins import _invoke_lookup, _DirectCall +from ansible._internal._templating._engine import TemplateEngine from ansible.template import Templar from ansible.utils.collection_loader import AnsibleCollectionConfig -from ansible.utils.listify import listify_lookup_plugin_terms -from ansible.utils.unsafe_proxy import to_unsafe_text, wrap_var -from ansible.vars.clean import namespace_facts, clean_facts -from ansible.utils.display import Display +from ansible.utils.display import Display, _DeferredWarningContext from ansible.utils.vars import combine_vars +from ansible.vars.clean import namespace_facts, clean_facts +from ansible.vars.manager import _deprecate_top_level_fact +from ansible._internal._errors import _captured display = Display() @@ -60,29 +70,6 @@ def task_timeout(signum, frame): raise TaskTimeoutError(frame=frame) -def remove_omit(task_args, omit_token): - """ - Remove args with a value equal to the ``omit_token`` recursively - to align with now having suboptions in the argument_spec - """ - - if not isinstance(task_args, dict): - return task_args - - new_args = {} - for i in task_args.items(): - if i[1] == omit_token: - continue - elif isinstance(i[1], dict): - new_args[i[0]] = remove_omit(i[1], omit_token) - elif isinstance(i[1], list): - new_args[i[0]] = [remove_omit(v, omit_token) for v in i[1]] - else: - new_args[i[0]] = i[1] - - return new_args - - class TaskExecutor: """ @@ -92,7 +79,7 @@ class TaskExecutor: class. 
""" - def __init__(self, host, task, job_vars, play_context, loader, shared_loader_obj, final_q, variable_manager): + def __init__(self, host, task: Task, job_vars, play_context, loader, shared_loader_obj, final_q, variable_manager): self._host = host self._task = task self._job_vars = job_vars @@ -103,6 +90,7 @@ class TaskExecutor: self._final_q = final_q self._variable_manager = variable_manager self._loop_eval_error = None + self._task_templar = TemplateEngine(loader=self._loader, variables=self._job_vars) self._task.squash() @@ -134,10 +122,14 @@ class TaskExecutor: # loop through the item results and set the global changed/failed/skipped result flags based on any item. res['skipped'] = True for item in item_results: + if item.get('_ansible_no_log'): + res.update(_ansible_no_log=True) # ensure no_log processing recognizes at least one item needs to be censored + if 'changed' in item and item['changed'] and not res.get('changed'): res['changed'] = True if res['skipped'] and ('skipped' not in item or ('skipped' in item and not item['skipped'])): res['skipped'] = False + # FIXME: normalize `failed` to a bool, warn if the action/module used non-bool if 'failed' in item and item['failed']: item_ignore = item.pop('_ansible_ignore_errors') if not res.get('failed'): @@ -164,6 +156,7 @@ class TaskExecutor: res[array] = res[array] + item[array] del item[array] + # FIXME: normalize `failed` to a bool, warn if the action/module used non-bool if not res.get('failed', False): res['msg'] = 'All items completed' if res['skipped']: @@ -172,43 +165,23 @@ class TaskExecutor: res = dict(changed=False, skipped=True, skipped_reason='No items in the list', results=[]) else: display.debug("calling self._execute()") - res = self._execute() + res = self._execute(self._task_templar, self._job_vars) display.debug("_execute() done") # make sure changed is set in the result, if it's not present if 'changed' not in res: res['changed'] = False - def _clean_res(res, errors='surrogate_or_strict'): - if isinstance(res, binary_type): - return to_unsafe_text(res, errors=errors) - elif isinstance(res, dict): - for k in res: - try: - res[k] = _clean_res(res[k], errors=errors) - except UnicodeError: - if k == 'diff': - # If this is a diff, substitute a replacement character if the value - # is undecodable as utf8. (Fix #21804) - display.warning("We were unable to decode all characters in the module return data." 
- " Replaced some in an effort to return as much as possible") - res[k] = _clean_res(res[k], errors='surrogate_then_replace') - else: - raise - elif isinstance(res, list): - for idx, item in enumerate(res): - res[idx] = _clean_res(item, errors=errors) - return res - - display.debug("dumping result to json") - res = _clean_res(res) - display.debug("done dumping result, returning") return res - except AnsibleError as e: - return dict(failed=True, msg=wrap_var(to_text(e, nonstring='simplerepr')), _ansible_no_log=self._play_context.no_log) - except Exception as e: - return dict(failed=True, msg=wrap_var('Unexpected failure during module execution: %s' % (to_native(e, nonstring='simplerepr'))), - exception=to_text(traceback.format_exc()), stdout='', _ansible_no_log=self._play_context.no_log) + except Exception as ex: + result = ActionBase.result_dict_from_exception(ex) + + self._task.update_result_no_log(self._task_templar, result) + + if not isinstance(ex, AnsibleError): + result.update(msg=f'Unexpected failure during task execution: {result["msg"]}') + + return result finally: try: self._connection.close() @@ -217,7 +190,7 @@ class TaskExecutor: except Exception as e: display.debug(u"error closing connection: %s" % to_text(e)) - def _get_loop_items(self): + def _get_loop_items(self) -> list[t.Any] | None: """ Loads a lookup plugin to handle the with_* portion of a task (if specified), and returns the items result. @@ -230,49 +203,51 @@ class TaskExecutor: if self._loader.get_basedir() not in self._job_vars['ansible_search_path']: self._job_vars['ansible_search_path'].append(self._loader.get_basedir()) - templar = Templar(loader=self._loader, variables=self._job_vars) items = None if self._task.loop_with: - if self._task.loop_with in self._shared_loader_obj.lookup_loader: - - # TODO: hardcoded so it fails for non first_found lookups, but this should be generalized for those that don't do their own templating - # lookup prop/attribute? - fail = bool(self._task.loop_with != 'first_found') - loop_terms = listify_lookup_plugin_terms(terms=self._task.loop, templar=templar, fail_on_undefined=fail, convert_bare=False) - - # get lookup - mylookup = self._shared_loader_obj.lookup_loader.get(self._task.loop_with, loader=self._loader, templar=templar) - - # give lookup task 'context' for subdir (mostly needed for first_found) - for subdir in ['template', 'var', 'file']: # TODO: move this to constants? - if subdir in self._task.action: - break - setattr(mylookup, '_subdir', subdir + 's') + templar = self._task_templar + terms = self._task.loop + + if isinstance(terms, str): + terms = templar.resolve_to_container(_utils.str_problematic_strip(terms)) + + if not isinstance(terms, list): + terms = [terms] + + @_DirectCall.mark + def invoke_lookup() -> t.Any: + """Scope-capturing wrapper around _invoke_lookup to avoid functools.partial obscuring its usage from type-checking tools.""" + return _invoke_lookup( + plugin_name=self._task.loop_with, + lookup_terms=terms, + lookup_kwargs=dict(wantlist=True), + invoked_as_with=True, + ) - # run lookup - items = wrap_var(mylookup.run(terms=loop_terms, variables=self._job_vars, wantlist=True)) - else: - raise AnsibleError("Unexpected failure in finding the lookup named '%s' in the available lookup plugins" % self._task.loop_with) + # Smuggle a special wrapped lookup invocation in as a local variable for its exclusive use when being evaluated as `with_(lookup)`. + # This value will not be visible to other users of this templar or its `available_variables`. 
+ items = templar.evaluate_expression(expression=TrustedAsTemplate().tag("invoke_lookup()"), local_variables=dict(invoke_lookup=invoke_lookup)) elif self._task.loop is not None: - items = templar.template(self._task.loop) + items = self._task_templar.template(self._task.loop) + if not isinstance(items, list): raise AnsibleError( - "Invalid data passed to 'loop', it requires a list, got this instead: %s." - " Hint: If you passed a list/dict of just one element," - " try adding wantlist=True to your lookup invocation or use q/query instead of lookup." % items + f"The `loop` value must resolve to a 'list', not {native_type_name(items)!r}.", + help_text="Provide a list of items/templates, or a template resolving to a list.", + obj=self._task.loop, ) return items - def _run_loop(self, items): + def _run_loop(self, items: list[t.Any]) -> list[dict[str, t.Any]]: """ Runs the task with the loop items specified and collates the result into an array named 'results' which is inserted into the final result along with the item for which the loop ran. """ task_vars = self._job_vars - templar = Templar(loader=self._loader, variables=task_vars) + templar = TemplateEngine(loader=self._loader, variables=task_vars) self._task.loop_control.post_validate(templar=templar) @@ -281,17 +256,20 @@ class TaskExecutor: loop_pause = self._task.loop_control.pause extended = self._task.loop_control.extended extended_allitems = self._task.loop_control.extended_allitems + # ensure we always have a label - label = self._task.loop_control.label or '{{' + loop_var + '}}' + label = self._task.loop_control.label or templar.variable_name_as_template(loop_var) if loop_var in task_vars: - display.warning(u"%s: The loop variable '%s' is already in use. " - u"You should set the `loop_var` value in the `loop_control` option for the task" - u" to something else to avoid variable collisions and unexpected behavior." 
% (self._task, loop_var)) + display.warning( + msg=f"The loop variable {loop_var!r} is already in use.", + help_text="You should set the `loop_var` value in the `loop_control` option for the task " + "to something else to avoid variable collisions and unexpected behavior.", + obj=loop_var, + ) ran_once = False task_fields = None - no_log = False items_len = len(items) results = [] for item_index, item in enumerate(items): @@ -331,7 +309,7 @@ class TaskExecutor: ran_once = True try: - tmp_task = self._task.copy(exclude_parent=True, exclude_tasks=True) + tmp_task: Task = self._task.copy(exclude_parent=True, exclude_tasks=True) tmp_task._parent = self._task._parent tmp_play_context = self._play_context.copy() except AnsibleParserError as e: @@ -340,9 +318,11 @@ class TaskExecutor: # now we swap the internal task and play context with their copies, # execute, and swap them back so we can do the next iteration cleanly + # NB: this swap-a-dee-doo confuses some type checkers about the type of tmp_task/self._task (self._task, tmp_task) = (tmp_task, self._task) (self._play_context, tmp_play_context) = (tmp_play_context, self._play_context) - res = self._execute(variables=task_vars) + + res = self._execute(templar=templar, variables=task_vars) if self._task.register: # Ensure per loop iteration results are registered in case `_execute()` @@ -354,9 +334,6 @@ class TaskExecutor: (self._task, tmp_task) = (tmp_task, self._task) (self._play_context, tmp_play_context) = (tmp_play_context, self._play_context) - # update 'general no_log' based on specific no_log - no_log = no_log or tmp_task.no_log - # now update the result with the item info, and append the result # to the list of results res[loop_var] = item @@ -391,6 +368,7 @@ class TaskExecutor: task_fields=task_fields, ) + # FIXME: normalize `failed` to a bool, warn if the action/module used non-bool if tr.is_failed() or tr.is_unreachable(): self._final_q.send_callback('v2_runner_item_on_failed', tr) elif tr.is_skipped(): @@ -405,11 +383,14 @@ class TaskExecutor: # break loop if break_when conditions are met if self._task.loop_control and self._task.loop_control.break_when: - cond = Conditional(loader=self._loader) - cond.when = self._task.loop_control.get_validated_value( - 'break_when', self._task.loop_control.fattributes.get('break_when'), self._task.loop_control.break_when, templar + break_when = self._task.loop_control.get_validated_value( + 'break_when', + self._task.loop_control.fattributes.get('break_when'), + self._task.loop_control.break_when, + templar, ) - if cond.evaluate_conditional(templar, task_vars): + + if self._task._resolve_conditional(break_when, task_vars): # delete loop vars before exiting loop del task_vars[loop_var] break @@ -431,7 +412,6 @@ class TaskExecutor: if var in task_vars and var not in self._job_vars: del task_vars[var] - self._task.no_log = no_log # NOTE: run_once cannot contain loop vars because it's templated earlier also # This is saving the post-validated field from the last loop so the strategy can use the templated value post task execution self._task.run_once = task_fields.get('run_once') @@ -447,22 +427,50 @@ class TaskExecutor: # At the point this is executed it is safe to mutate self._task, # since `self._task` is either a copy referred to by `tmp_task` in `_run_loop` # or just a singular non-looped task - if delegated_host_name: - self._task.delegate_to = delegated_host_name - variables.update(delegated_vars) - def _execute(self, variables=None): + self._task.delegate_to = delegated_host_name # always 
override, since a templated result could be an omit (-> None) + variables.update(delegated_vars) + + def _execute(self, templar: TemplateEngine, variables: dict[str, t.Any]) -> dict[str, t.Any]: + result: dict[str, t.Any] + + with _DeferredWarningContext(variables=variables) as warning_ctx: + try: + # DTFIX-FUTURE: improve error handling to prioritize the earliest exception, turning the remaining ones into warnings + result = self._execute_internal(templar, variables) + self._apply_task_result_compat(result, warning_ctx) + _captured.AnsibleActionCapturedError.maybe_raise_on_result(result) + except Exception as ex: + try: + raise AnsibleTaskError(obj=self._task.get_ds()) from ex + except AnsibleTaskError as atex: + result = ActionBase.result_dict_from_exception(atex) + result.setdefault('changed', False) + + self._task.update_result_no_log(templar, result) + + # The warnings/deprecations in the result have already been captured in the _DeferredWarningContext by _apply_task_result_compat. + # The captured warnings/deprecations are a superset of the ones from the result, and may have been converted from a dict to a dataclass. + # These are then used to supersede the entries in the result. + + result.pop('warnings', None) + result.pop('deprecations', None) + + if warnings := warning_ctx.get_warnings(): + result.update(warnings=warnings) + + if deprecation_warnings := warning_ctx.get_deprecation_warnings(): + result.update(deprecations=deprecation_warnings) + + return result + + def _execute_internal(self, templar: TemplateEngine, variables: dict[str, t.Any]) -> dict[str, t.Any]: """ The primary workhorse of the executor system, this runs the task on the specified host (which may be the delegated_to host) and handles the retry/until and block rescue/always execution """ - if variables is None: - variables = self._job_vars - - templar = Templar(loader=self._loader, variables=variables) - self._calculate_delegate_to(templar, variables) context_validation_error = None @@ -497,18 +505,13 @@ class TaskExecutor: # skipping this task during the conditional evaluation step context_validation_error = e - no_log = self._play_context.no_log - # Evaluate the conditional (if any) for this task, which we do before running # the final task post-validation. 
We do this before the post validation due to # the fact that the conditional may specify that the task be skipped due to a # variable not being present which would otherwise cause validation to fail try: - conditional_result, false_condition = self._task.evaluate_conditional_with_result(templar, tempvars) - if not conditional_result: - display.debug("when evaluation is False, skipping this task") - return dict(changed=False, skipped=True, skip_reason='Conditional result was False', - false_condition=false_condition, _ansible_no_log=no_log) + if not self._task._resolve_conditional(self._task.when, tempvars, result_context=(rc := t.cast(dict[str, t.Any], {}))): + return dict(changed=False, skipped=True, skip_reason='Conditional result was False') | rc except AnsibleError as e: # loop error takes precedence if self._loop_eval_error is not None: @@ -524,22 +527,27 @@ class TaskExecutor: # if we ran into an error while setting up the PlayContext, raise it now, unless is known issue with delegation # and undefined vars (correct values are in cvars later on and connection plugins, if still error, blows up there) + + # DTFIX-RELEASE: this should probably be declaratively handled in post_validate (or better, get rid of play_context) if context_validation_error is not None: raiseit = True if self._task.delegate_to: - if isinstance(context_validation_error, AnsibleUndefinedVariable): - raiseit = False - elif isinstance(context_validation_error, AnsibleParserError): + if isinstance(context_validation_error, AnsibleParserError): # parser error, might be cause by undef too - orig_exc = getattr(context_validation_error, 'orig_exc', None) - if isinstance(orig_exc, AnsibleUndefinedVariable): + if isinstance(context_validation_error.__cause__, AnsibleUndefinedVariable): raiseit = False + elif isinstance(context_validation_error, AnsibleUndefinedVariable): + # DTFIX-RELEASE: should not be possible to hit this now (all are AnsibleFieldAttributeError)? + raiseit = False if raiseit: raise context_validation_error # pylint: disable=raising-bad-type # set templar to use temp variables until loop is evaluated templar.available_variables = tempvars + # Now we do final validation on the task, which sets all fields to their final values. + self._task.post_validate(templar=templar) + # if this task is a TaskInclude, we just return now with a success code so the # main thread can expand the task list for the given host if self._task.action in C._ACTION_INCLUDE_TASKS: @@ -548,7 +556,6 @@ class TaskExecutor: if not include_file: return dict(failed=True, msg="No include file was specified to the include") - include_file = templar.template(include_file) return dict(include=include_file, include_args=include_args) # if this task is a IncludeRole, we just return now with a success code so the main thread can expand the task list for the given host @@ -556,32 +563,9 @@ class TaskExecutor: include_args = self._task.args.copy() return dict(include_args=include_args) - # Now we do final validation on the task, which sets all fields to their final values. 
- try: - self._task.post_validate(templar=templar) - except AnsibleError: - raise - except Exception: - return dict(changed=False, failed=True, _ansible_no_log=no_log, exception=to_text(traceback.format_exc())) - if '_variable_params' in self._task.args: - variable_params = self._task.args.pop('_variable_params') - if isinstance(variable_params, dict): - if C.INJECT_FACTS_AS_VARS: - display.warning("Using a variable for a task's 'args' is unsafe in some situations " - "(see https://docs.ansible.com/ansible/devel/reference_appendices/faq.html#argsplat-unsafe)") - variable_params.update(self._task.args) - self._task.args = variable_params - else: - # if we didn't get a dict, it means there's garbage remaining after k=v parsing, just give up - # see https://github.com/ansible/ansible/issues/79862 - raise AnsibleError(f"invalid or malformed argument: '{variable_params}'") - - # update no_log to task value, now that we have it templated - no_log = self._task.no_log - # free tempvars up, not used anymore, cvars and vars_copy should be mainly used after this point # updating the original 'variables' at the end - tempvars = {} + del tempvars # setup cvars copy, used for all connection related templating if self._task.delegate_to: @@ -633,23 +617,7 @@ class TaskExecutor: cvars['ansible_python_interpreter'] = sys.executable # get handler - self._handler, module_context = self._get_action_handler_with_module_context(templar=templar) - - if module_context is not None: - module_defaults_fqcn = module_context.resolved_fqcn - else: - module_defaults_fqcn = self._task.resolved_action - - # Apply default params for action/module, if present - self._task.args = get_action_args_with_defaults( - module_defaults_fqcn, self._task.args, self._task.module_defaults, templar, - action_groups=self._task._parent._play._action_groups - ) - - # And filter out any fields which were set to default(omit), and got the omit token value - omit_token = variables.get('omit') - if omit_token is not None: - self._task.args = remove_omit(self._task.args, omit_token) + self._handler, _module_context = self._get_action_handler_with_module_context(templar=templar) retries = 1 # includes the default actual run + retries set by user/default if self._task.retries is not None: @@ -669,7 +637,10 @@ class TaskExecutor: if self._task.timeout: old_sig = signal.signal(signal.SIGALRM, task_timeout) signal.alarm(self._task.timeout) - result = self._handler.run(task_vars=vars_copy) + with PluginExecContext(self._handler): + result = self._handler.run(task_vars=vars_copy) + + # DTFIX-RELEASE: nuke this, it hides a lot of error detail- remove the active exception propagation hack from AnsibleActionFail at the same time except (AnsibleActionFail, AnsibleActionSkip) as e: return e.result except AnsibleConnectionFailure as e: @@ -684,12 +655,6 @@ class TaskExecutor: self._handler.cleanup() display.debug("handler run complete") - # propagate no log to result- the action can set this, so only overwrite it with the task's value if missing or falsey - result["_ansible_no_log"] = bool(no_log or result.get('_ansible_no_log', False)) - - if self._task.action not in C._ACTION_WITH_CLEAN_FACTS: - result = wrap_var(result) - # update the local copy of vars with the registered value, if specified, # or any facts which may have been generated by the module execution if self._task.register: @@ -713,26 +678,6 @@ class TaskExecutor: result, task_fields=self._task.dump_attrs())) - # ensure no log is preserved - result["_ansible_no_log"] = no_log - - # helper 
methods for use below in evaluating changed/failed_when - def _evaluate_changed_when_result(result): - if self._task.changed_when is not None and self._task.changed_when: - cond = Conditional(loader=self._loader) - cond.when = self._task.changed_when - result['changed'] = cond.evaluate_conditional(templar, vars_copy) - - def _evaluate_failed_when_result(result): - if self._task.failed_when: - cond = Conditional(loader=self._loader) - cond.when = self._task.failed_when - failed_when_result = cond.evaluate_conditional(templar, vars_copy) - result['failed_when_result'] = result['failed'] = failed_when_result - else: - failed_when_result = False - return failed_when_result - if 'ansible_facts' in result and self._task.action not in C._ACTION_DEBUG: if self._task.action in C._ACTION_WITH_CLEAN_FACTS: if self._task.delegate_to and self._task.delegate_facts: @@ -744,10 +689,11 @@ class TaskExecutor: vars_copy.update(result['ansible_facts']) else: # TODO: cleaning of facts should eventually become part of taskresults instead of vars - af = wrap_var(result['ansible_facts']) + af = result['ansible_facts'] vars_copy['ansible_facts'] = combine_vars(vars_copy.get('ansible_facts', {}), namespace_facts(af)) if C.INJECT_FACTS_AS_VARS: - vars_copy.update(clean_facts(af)) + cleaned_toplevel = {k: _deprecate_top_level_fact(v) for k, v in clean_facts(af).items()} + vars_copy.update(cleaned_toplevel) # set the failed property if it was missing. if 'failed' not in result: @@ -765,9 +711,6 @@ class TaskExecutor: if 'changed' not in result: result['changed'] = False - if self._task.action not in C._ACTION_WITH_CLEAN_FACTS: - result = wrap_var(result) - # re-update the local copy of vars with the registered value, if specified, # or any facts which may have been generated by the module execution # This gives changed/failed_when access to additional recently modified @@ -780,18 +723,30 @@ class TaskExecutor: if 'skipped' not in result: condname = 'changed' + # DTFIX-RELEASE: error normalization has not yet occurred; this means that the expressions used for until/failed_when/changed_when/break_when + # and when (for loops on the second and later iterations) cannot see the normalized error shapes. This, combined with the current impl of the expression + # handling here, causes a number of problems: + # * any error in one of the post-task exec expressions is silently ignored and detail lost (eg: `failed_when: syntax ERROR @$123`) + # * they cannot reliably access error/warning details, since many of those details are inaccessible until the error normalization occurs + # * error normalization includes `msg` if present, and supplies `unknown error` if not; this leads to screwy results on True failed_when if + # `msg` is present, eg: `{debug: {}, failed_when: True}` -> "Task failed: Action failed: Hello world!"
+ # * detail about failed_when is lost; any error details from the task could potentially be grafted in/preserved if error normalization was done + try: - _evaluate_changed_when_result(result) + if self._task.changed_when is not None and self._task.changed_when: + result['changed'] = self._task._resolve_conditional(self._task.changed_when, vars_copy) + condname = 'failed' - _evaluate_failed_when_result(result) + + if self._task.failed_when: + result['failed_when_result'] = result['failed'] = self._task._resolve_conditional(self._task.failed_when, vars_copy) + except AnsibleError as e: result['failed'] = True result['%s_when_result' % condname] = to_text(e) if retries > 1: - cond = Conditional(loader=self._loader) - cond.when = self._task.until or [not result['failed']] - if cond.evaluate_conditional(templar, vars_copy): + if self._task._resolve_conditional(self._task.until or [not result['failed']], vars_copy): break else: # no conditional check, or it failed, so sleep for the specified time @@ -816,9 +771,6 @@ class TaskExecutor: result['attempts'] = retries - 1 result['failed'] = True - if self._task.action not in C._ACTION_WITH_CLEAN_FACTS: - result = wrap_var(result) - # do the final update of the local variables here, for both registered # values and any facts which may have been created if self._task.register: @@ -829,10 +781,12 @@ class TaskExecutor: variables.update(result['ansible_facts']) else: # TODO: cleaning of facts should eventually become part of taskresults instead of vars - af = wrap_var(result['ansible_facts']) + af = result['ansible_facts'] variables['ansible_facts'] = combine_vars(variables.get('ansible_facts', {}), namespace_facts(af)) if C.INJECT_FACTS_AS_VARS: - variables.update(clean_facts(af)) + # DTFIX-FUTURE: why is this happening twice, esp since we're post-fork and these will be discarded? + cleaned_toplevel = {k: _deprecate_top_level_fact(v) for k, v in clean_facts(af).items()} + variables.update(cleaned_toplevel) # save the notification target in the result, if it was specified, as # this task may be running in a loop in which case the notification @@ -857,6 +811,50 @@ class TaskExecutor: display.debug("attempt loop complete, returning result") return result + @staticmethod + def _apply_task_result_compat(result: dict[str, t.Any], warning_ctx: _DeferredWarningContext) -> None: + """Apply backward-compatibility mutations to the supplied task result.""" + if warnings := result.get('warnings'): + if isinstance(warnings, list): + for warning in warnings: + if not isinstance(warning, WarningSummary): + # translate non-WarningMessageDetail messages + warning = WarningSummary( + details=( + Detail(msg=str(warning)), + ), + ) + + warning_ctx.capture(warning) + else: + display.warning(f"Task result `warnings` was {type(warnings)} instead of {list}.") + + if deprecations := result.get('deprecations'): + if isinstance(deprecations, list): + for deprecation in deprecations: + if not isinstance(deprecation, DeprecationSummary): + # translate non-DeprecationMessageDetail message dicts + try: + if deprecation.pop('collection_name', ...) 
is not ...: + # deprecated: description='enable the deprecation message for collection_name' core_version='2.23' + # self.deprecated('The `collection_name` key in the `deprecations` dictionary is deprecated.', version='2.27') + pass + + # DTFIX-RELEASE: when plugin isn't set, do it at the boundary where we receive the module/action results + # that may even allow us to never set it in modules/actions directly and to populate it at the boundary + deprecation = DeprecationSummary( + details=( + Detail(msg=deprecation.pop('msg')), + ), + **deprecation, + ) + except Exception as ex: + display.error_as_warning("Task result `deprecations` contained an invalid item.", exception=ex) + + warning_ctx.capture(deprecation) + else: + display.warning(f"Task result `deprecations` was {type(deprecations)} instead of {list}.") + def _poll_async_result(self, result, templar, task_vars=None): """ Polls for the specified JID to be complete @@ -890,7 +888,7 @@ class TaskExecutor: connection=self._connection, play_context=self._play_context, loader=self._loader, - templar=templar, + templar=Templar._from_template_engine(templar), shared_loader_obj=self._shared_loader_obj, ) @@ -960,7 +958,7 @@ class TaskExecutor: connection=self._connection, play_context=self._play_context, loader=self._loader, - templar=templar, + templar=Templar._from_template_engine(templar), shared_loader_obj=self._shared_loader_obj, ) cleanup_handler.run(task_vars=task_vars) @@ -1057,7 +1055,11 @@ class TaskExecutor: options = {} for k in option_vars: if k in variables: - options[k] = templar.template(variables[k]) + try: + options[k] = templar.template(variables[k]) + except AnsibleValueOmittedError: + pass + # TODO move to task method? plugin.set_options(task_keys=task_keys, var_options=options) @@ -1128,7 +1130,7 @@ class TaskExecutor: """ return self._get_action_handler_with_module_context(templar)[0] - def _get_action_handler_with_module_context(self, templar): + def _get_action_handler_with_module_context(self, templar: TemplateEngine): """ Returns the correct action plugin to handle the requestion task action and the module context """ @@ -1190,7 +1192,7 @@ class TaskExecutor: connection=self._connection, play_context=self._play_context, loader=self._loader, - templar=templar, + templar=Templar._from_template_engine(templar), shared_loader_obj=self._shared_loader_obj, collection_list=collections ) diff --git a/lib/ansible/executor/task_queue_manager.py b/lib/ansible/executor/task_queue_manager.py index ce4a72952ec..3079d3ecc42 100644 --- a/lib/ansible/executor/task_queue_manager.py +++ b/lib/ansible/executor/task_queue_manager.py @@ -27,18 +27,22 @@ import multiprocessing.queues from ansible import constants as C from ansible import context -from ansible.errors import AnsibleError +from ansible.errors import AnsibleError, ExitCode, AnsibleCallbackError +from ansible._internal._errors._handler import ErrorHandler from ansible.executor.play_iterator import PlayIterator from ansible.executor.stats import AggregateStats from ansible.executor.task_result import TaskResult +from ansible.inventory.data import InventoryData from ansible.module_utils.six import string_types -from ansible.module_utils.common.text.converters import to_text, to_native +from ansible.module_utils.common.text.converters import to_native +from ansible.parsing.dataloader import DataLoader from ansible.playbook.play_context import PlayContext from ansible.playbook.task import Task from ansible.plugins.loader import callback_loader, strategy_loader, module_loader from 
ansible.plugins.callback import CallbackBase -from ansible.template import Templar +from ansible._internal._templating._engine import TemplateEngine from ansible.vars.hostvars import HostVars +from ansible.vars.manager import VariableManager from ansible.utils.display import Display from ansible.utils.lock import lock_decorator from ansible.utils.multiprocessing import context as multiprocessing_context @@ -125,27 +129,38 @@ class TaskQueueManager: which dispatches the Play's tasks to hosts. """ - RUN_OK = 0 - RUN_ERROR = 1 - RUN_FAILED_HOSTS = 2 - RUN_UNREACHABLE_HOSTS = 4 - RUN_FAILED_BREAK_PLAY = 8 - RUN_UNKNOWN_ERROR = 255 - - def __init__(self, inventory, variable_manager, loader, passwords, stdout_callback=None, run_additional_callbacks=True, run_tree=False, forks=None): - + RUN_OK = ExitCode.SUCCESS + RUN_ERROR = ExitCode.GENERIC_ERROR + RUN_FAILED_HOSTS = ExitCode.HOST_FAILED + RUN_UNREACHABLE_HOSTS = ExitCode.HOST_UNREACHABLE + RUN_FAILED_BREAK_PLAY = 8 # never leaves PlaybookExecutor.run + RUN_UNKNOWN_ERROR = 255 # never leaves PlaybookExecutor.run, intentionally includes the bit value for 8 + + _callback_dispatch_error_handler = ErrorHandler.from_config('_CALLBACK_DISPATCH_ERROR_BEHAVIOR') + + def __init__( + self, + inventory: InventoryData, + variable_manager: VariableManager, + loader: DataLoader, + passwords: dict[str, str | None], + stdout_callback: str | None = None, + run_additional_callbacks: bool = True, + run_tree: bool = False, + forks: int | None = None, + ) -> None: self._inventory = inventory self._variable_manager = variable_manager self._loader = loader self._stats = AggregateStats() self.passwords = passwords - self._stdout_callback = stdout_callback + self._stdout_callback: str | None | CallbackBase = stdout_callback self._run_additional_callbacks = run_additional_callbacks self._run_tree = run_tree self._forks = forks or 5 self._callbacks_loaded = False - self._callback_plugins = [] + self._callback_plugins: list[CallbackBase] = [] self._start_at_done = False # make sure any module paths (if specified) are added to the module_loader @@ -158,8 +173,8 @@ class TaskQueueManager: self._terminated = False # dictionaries to keep track of failed/unreachable hosts - self._failed_hosts = dict() - self._unreachable_hosts = dict() + self._failed_hosts: dict[str, t.Literal[True]] = dict() + self._unreachable_hosts: dict[str, t.Literal[True]] = dict() try: self._final_q = FinalQueue() @@ -291,7 +306,7 @@ class TaskQueueManager: self.load_callbacks() all_vars = self._variable_manager.get_vars(play=play) - templar = Templar(loader=self._loader, variables=all_vars) + templar = TemplateEngine(loader=self._loader, variables=all_vars) new_play = play.copy() new_play.post_validate(templar) @@ -394,25 +409,25 @@ class TaskQueueManager: except AttributeError: pass - def clear_failed_hosts(self): + def clear_failed_hosts(self) -> None: self._failed_hosts = dict() - def get_inventory(self): + def get_inventory(self) -> InventoryData: return self._inventory - def get_variable_manager(self): + def get_variable_manager(self) -> VariableManager: return self._variable_manager - def get_loader(self): + def get_loader(self) -> DataLoader: return self._loader def get_workers(self): return self._workers[:] - def terminate(self): + def terminate(self) -> None: self._terminated = True - def has_dead_workers(self): + def has_dead_workers(self) -> bool: # [, # @@ -469,11 +484,8 @@ class TaskQueueManager: continue for method in methods: - try: - method(*new_args, **kwargs) - except Exception as e: 
- # TODO: add config toggle to make this fatal or not? - display.warning(u"Failure using method (%s) in callback plugin (%s): %s" % (to_text(method_name), to_text(callback_plugin), to_text(e))) - from traceback import format_tb - from sys import exc_info - display.vvv('Callback Exception: \n' + ' '.join(format_tb(exc_info()[2]))) + with self._callback_dispatch_error_handler.handle(AnsibleCallbackError): + try: + method(*new_args, **kwargs) + except Exception as ex: + raise AnsibleCallbackError(f"Callback dispatch {method_name!r} failed for plugin {callback_plugin._load_name!r}.") from ex diff --git a/lib/ansible/executor/task_result.py b/lib/ansible/executor/task_result.py index 06e9af72e3c..986ffd2e494 100644 --- a/lib/ansible/executor/task_result.py +++ b/lib/ansible/executor/task_result.py @@ -4,12 +4,14 @@ from __future__ import annotations +import typing as t + from ansible import constants as C from ansible.parsing.dataloader import DataLoader from ansible.vars.clean import module_response_deepcopy, strip_internal_keys _IGNORE = ('failed', 'skipped') -_PRESERVE = ('attempts', 'changed', 'retries') +_PRESERVE = ('attempts', 'changed', 'retries', '_ansible_no_log') _SUB_PRESERVE = {'_ansible_delegated_vars': ('ansible_host', 'ansible_port', 'ansible_user', 'ansible_connection')} # stuff callbacks need @@ -127,15 +129,15 @@ class TaskResult: if key in self._result[sub]: subset[sub][key] = self._result[sub][key] - if isinstance(self._task.no_log, bool) and self._task.no_log or self._result.get('_ansible_no_log', False): - x = {"censored": "the output has been hidden due to the fact that 'no_log: true' was specified for this result"} + # DTFIX-FUTURE: is checking no_log here redundant now that we use _ansible_no_log everywhere? + if isinstance(self._task.no_log, bool) and self._task.no_log or self._result.get('_ansible_no_log'): + censored_result = censor_result(self._result) - # preserve full - for preserve in _PRESERVE: - if preserve in self._result: - x[preserve] = self._result[preserve] + if results := self._result.get('results'): + # maintain shape for loop results so callback behavior recognizes a loop was performed + censored_result.update(results=[censor_result(item) if item.get('_ansible_no_log') else item for item in results]) - result._result = x + result._result = censored_result elif self._result: result._result = module_response_deepcopy(self._result) @@ -151,3 +153,10 @@ class TaskResult: result._result.update(subset) return result + + +def censor_result(result: dict[str, t.Any]) -> dict[str, t.Any]: + censored_result = {key: value for key in _PRESERVE if (value := result.get(key, ...)) is not ...} + censored_result.update(censored="the output has been hidden due to the fact that 'no_log: true' was specified for this result") + + return censored_result diff --git a/lib/ansible/galaxy/api.py b/lib/ansible/galaxy/api.py index 97a5c218493..eb3ddb51663 100644 --- a/lib/ansible/galaxy/api.py +++ b/lib/ansible/galaxy/api.py @@ -57,13 +57,13 @@ def should_retry_error(exception): if isinstance(exception, GalaxyError) and exception.http_code in RETRY_HTTP_ERROR_CODES: return True - if isinstance(exception, AnsibleError) and (orig_exc := getattr(exception, 'orig_exc', None)): + if isinstance(exception, AnsibleError) and (cause := exception.__cause__): # URLError is often a proxy for an underlying error, handle wrapped exceptions - if isinstance(orig_exc, URLError): - orig_exc = orig_exc.reason + if isinstance(cause, URLError): + cause = cause.reason # Handle common URL related errors 
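# --- Editorial sketch (not part of the patch): the exception-chaining convention
# `should_retry_error()` now relies on. Instead of reading the legacy `orig_exc`
# attribute, callers inspect `__cause__`, which `raise ... from` populates.
# The error class and helper below are made up for illustration.
from http.client import BadStatusLine

class SketchGalaxyError(Exception):
    """Stand-in for AnsibleError in this sketch."""

def sketch_call_galaxy() -> None:
    try:
        raise BadStatusLine('')  # simulate a flaky HTTP response
    except Exception as ex:
        # `raise ... from` records the low-level error as __cause__
        raise SketchGalaxyError('Unknown error when attempting to call Galaxy.') from ex

try:
    sketch_call_galaxy()
except SketchGalaxyError as err:
    assert isinstance(err.__cause__, BadStatusLine)  # retryable per the checks below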
- if isinstance(orig_exc, (TimeoutError, BadStatusLine, IncompleteRead)): + if isinstance(cause, (TimeoutError, BadStatusLine, IncompleteRead)): return True return False @@ -408,11 +408,8 @@ class GalaxyAPI: method=method, timeout=self._server_timeout, http_agent=user_agent(), follow_redirects='safe') except HTTPError as e: raise GalaxyError(e, error_context_msg) - except Exception as e: - raise AnsibleError( - "Unknown error when attempting to call Galaxy at '%s': %s" % (url, to_native(e)), - orig_exc=e - ) + except Exception as ex: + raise AnsibleError(f"Unknown error when attempting to call Galaxy at {url!r}.") from ex resp_data = to_text(resp.read(), errors='surrogate_or_strict') try: @@ -471,8 +468,8 @@ class GalaxyAPI: resp = open_url(url, data=args, validate_certs=self.validate_certs, method="POST", http_agent=user_agent(), timeout=self._server_timeout) except HTTPError as e: raise GalaxyError(e, 'Attempting to authenticate to galaxy') - except Exception as e: - raise AnsibleError('Unable to authenticate to galaxy: %s' % to_native(e), orig_exc=e) + except Exception as ex: + raise AnsibleError('Unable to authenticate to galaxy.') from ex data = json.loads(to_text(resp.read(), errors='surrogate_or_strict')) return data diff --git a/lib/ansible/galaxy/collection/concrete_artifact_manager.py b/lib/ansible/galaxy/collection/concrete_artifact_manager.py index fb807766f5c..983e674e3d2 100644 --- a/lib/ansible/galaxy/collection/concrete_artifact_manager.py +++ b/lib/ansible/galaxy/collection/concrete_artifact_manager.py @@ -485,16 +485,13 @@ def _download_file(url, b_path, expected_hash, validate_certs, token=None, timeo display.display("Downloading %s to %s" % (url, to_text(b_tarball_dir))) # NOTE: Galaxy redirects downloads to S3 which rejects the request # NOTE: if an Authorization header is attached so don't redirect it - try: - resp = open_url( - to_native(url, errors='surrogate_or_strict'), - validate_certs=validate_certs, - headers=None if token is None else token.headers(), - unredirected_headers=['Authorization'], http_agent=user_agent(), - timeout=timeout - ) - except Exception as err: - raise AnsibleError(to_native(err), orig_exc=err) + resp = open_url( + to_native(url, errors='surrogate_or_strict'), + validate_certs=validate_certs, + headers=None if token is None else token.headers(), + unredirected_headers=['Authorization'], http_agent=user_agent(), + timeout=timeout + ) with open(b_file_path, 'wb') as download_file: # type: t.BinaryIO actual_hash = _consume_file(resp, write_to=download_file) diff --git a/lib/ansible/galaxy/dependency_resolution/dataclasses.py b/lib/ansible/galaxy/dependency_resolution/dataclasses.py index 6796ad132e4..9877efdfc38 100644 --- a/lib/ansible/galaxy/dependency_resolution/dataclasses.py +++ b/lib/ansible/galaxy/dependency_resolution/dataclasses.py @@ -7,6 +7,7 @@ from __future__ import annotations import os +import pathlib import typing as t from collections import namedtuple @@ -25,6 +26,8 @@ if t.TYPE_CHECKING: '_ComputedReqKindsMixin', ) +import ansible +import ansible.release from ansible.errors import AnsibleError, AnsibleAssertionError from ansible.galaxy.api import GalaxyAPI @@ -39,6 +42,7 @@ _ALLOW_CONCRETE_POINTER_IN_SOURCE = False # NOTE: This is a feature flag _GALAXY_YAML = b'galaxy.yml' _MANIFEST_JSON = b'MANIFEST.json' _SOURCE_METADATA_FILE = b'GALAXY.yml' +_ANSIBLE_PACKAGE_PATH = pathlib.Path(ansible.__file__).parent display = Display() @@ -224,6 +228,13 @@ class _ComputedReqKindsMixin: if dir_path.endswith(to_bytes(os.path.sep)): 
dir_path = dir_path.rstrip(to_bytes(os.path.sep)) if not _is_collection_dir(dir_path): + dir_pathlib = pathlib.Path(to_text(dir_path)) + + # special handling for bundled collections without manifests, e.g., ansible._protomatter + if dir_pathlib.is_relative_to(_ANSIBLE_PACKAGE_PATH): + req_name = f'{dir_pathlib.parent.name}.{dir_pathlib.name}' + return cls(req_name, ansible.release.__version__, dir_path, 'dir', None) + display.warning( u"Collection at '{path!s}' does not have a {manifest_json!s} " u'file, nor has it {galaxy_yml!s}: cannot detect version.'. diff --git a/lib/ansible/inventory/data.py b/lib/ansible/inventory/data.py index 691ad5bed42..f879baa4016 100644 --- a/lib/ansible/inventory/data.py +++ b/lib/ansible/inventory/data.py @@ -19,64 +19,49 @@ from __future__ import annotations import sys +import typing as t from ansible import constants as C from ansible.errors import AnsibleError from ansible.inventory.group import Group from ansible.inventory.host import Host -from ansible.module_utils.six import string_types from ansible.utils.display import Display from ansible.utils.vars import combine_vars from ansible.utils.path import basedir +from . import helpers # this is left as a module import to facilitate easier unit test patching + + display = Display() -class InventoryData(object): +class InventoryData: """ Holds inventory data (host and group objects). - Using it's methods should guarantee expected relationships and data. + Using its methods should guarantee expected relationships and data. """ - def __init__(self): + def __init__(self) -> None: - self.groups = {} - self.hosts = {} + self.groups: dict[str, Group] = {} + self.hosts: dict[str, Host] = {} # provides 'groups' magic var, host object has group_names - self._groups_dict_cache = {} + self._groups_dict_cache: dict[str, list[str]] = {} # current localhost, implicit or explicit - self.localhost = None + self.localhost: Host | None = None - self.current_source = None - self.processed_sources = [] + self.current_source: str | None = None + self.processed_sources: list[str] = [] # Always create the 'all' and 'ungrouped' groups, for group in ('all', 'ungrouped'): self.add_group(group) - self.add_child('all', 'ungrouped') - def serialize(self): - self._groups_dict_cache = None - data = { - 'groups': self.groups, - 'hosts': self.hosts, - 'local': self.localhost, - 'source': self.current_source, - 'processed_sources': self.processed_sources - } - return data - - def deserialize(self, data): - self._groups_dict_cache = {} - self.hosts = data.get('hosts') - self.groups = data.get('groups') - self.localhost = data.get('local') - self.current_source = data.get('source') - self.processed_sources = data.get('processed_sources') + self.add_child('all', 'ungrouped') - def _create_implicit_localhost(self, pattern): + def _create_implicit_localhost(self, pattern: str) -> Host: if self.localhost: new_host = self.localhost @@ -100,8 +85,8 @@ class InventoryData(object): return new_host - def reconcile_inventory(self): - """ Ensure inventory basic rules, run after updates """ + def reconcile_inventory(self) -> None: + """Ensure inventory basic rules, run after updates.""" display.debug('Reconcile groups and hosts in inventory.') self.current_source = None @@ -125,7 +110,7 @@ class InventoryData(object): if self.groups['ungrouped'] in mygroups: # clear ungrouped of any incorrectly stored by parser - if set(mygroups).difference(set([self.groups['all'], self.groups['ungrouped']])): + if set(mygroups).difference({self.groups['all'], 
self.groups['ungrouped']}): self.groups['ungrouped'].remove_host(host) elif not host.implicit: @@ -144,8 +129,10 @@ class InventoryData: self._groups_dict_cache = {} - def get_host(self, hostname): - """ fetch host object using name deal with implicit localhost """ + def get_host(self, hostname: str) -> Host | None: + """Fetch host object using name; deal with implicit localhost.""" + + hostname = helpers.remove_trust(hostname) matching_host = self.hosts.get(hostname, None) @@ -156,19 +143,19 @@ class InventoryData: return matching_host - def add_group(self, group): - """ adds a group to inventory if not there already, returns named actually used """ + def add_group(self, group: str) -> str: + """Adds a group to inventory if not there already; returns the name actually used.""" if group: - if not isinstance(group, string_types): + if not isinstance(group, str): raise AnsibleError("Invalid group name supplied, expected a string but got %s for %s" % (type(group), group)) if group not in self.groups: g = Group(group) - if g.name not in self.groups: - self.groups[g.name] = g + group = g.name # the group object may have sanitized the group name; use whatever it has + if group not in self.groups: + self.groups[group] = g self._groups_dict_cache = {} display.debug("Added group %s to inventory" % group) - group = g.name else: display.debug("group %s already in inventory" % group) else: @@ -176,22 +163,24 @@ class InventoryData: return group - def remove_group(self, group): + def remove_group(self, group: Group) -> None: - if group in self.groups: - del self.groups[group] - display.debug("Removed group %s from inventory" % group) + if group.name in self.groups: + del self.groups[group.name] + display.debug("Removed group %s from inventory" % group.name) self._groups_dict_cache = {} for host in self.hosts: h = self.hosts[host] h.remove_group(group) - def add_host(self, host, group=None, port=None): - """ adds a host to inventory and possibly a group if not there already """ + def add_host(self, host: str, group: str | None = None, port: int | str | None = None) -> str: + """Adds a host to inventory and possibly a group if not there already.""" + + host = helpers.remove_trust(host) if host: - if not isinstance(host, string_types): + if not isinstance(host, str): raise AnsibleError("Invalid host name supplied, expected a string but got %s for %s" % (type(host), host)) # TODO: add to_safe_host_name @@ -211,7 +200,7 @@ class InventoryData: else: self.set_variable(host, 'inventory_file', None) self.set_variable(host, 'inventory_dir', None) - display.debug("Added host %s to inventory" % (host)) + display.debug("Added host %s to inventory" % host) # set default localhost from inventory to avoid creating an implicit one. Last localhost defined 'wins'.
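# --- Editorial sketch (not part of the patch): the calling contract of the
# methods above. add_group() and add_host() return the name actually stored,
# which may differ from the input (group names can be sanitized), so callers
# should keep the returned value. Names below are hypothetical.
from ansible.inventory.data import InventoryData

inv = InventoryData()
group = inv.add_group('webservers')  # keep the returned (possibly sanitized) name
host = inv.add_host('db01', group=group, port=2222)
inv.set_variable(host, 'ansible_user', 'deploy')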
if host in C.LOCALHOST: @@ -232,7 +221,7 @@ class InventoryData(object): return host - def remove_host(self, host): + def remove_host(self, host: Host) -> None: if host.name in self.hosts: del self.hosts[host.name] @@ -241,8 +230,10 @@ class InventoryData(object): g = self.groups[group] g.remove_host(host) - def set_variable(self, entity, varname, value): - """ sets a variable for an inventory object """ + def set_variable(self, entity: str, varname: str, value: t.Any) -> None: + """Sets a variable for an inventory object.""" + + inv_object: Host | Group if entity in self.groups: inv_object = self.groups[entity] @@ -254,9 +245,8 @@ class InventoryData(object): inv_object.set_variable(varname, value) display.debug('set %s for %s' % (varname, entity)) - def add_child(self, group, child): - """ Add host or group to group """ - added = False + def add_child(self, group: str, child: str) -> bool: + """Add host or group to group.""" if group in self.groups: g = self.groups[group] if child in self.groups: @@ -271,12 +261,12 @@ class InventoryData(object): raise AnsibleError("%s is not a known group" % group) return added - def get_groups_dict(self): + def get_groups_dict(self) -> dict[str, list[str]]: """ We merge a 'magic' var 'groups' with group name keys and hostname list values into every host variable set. Cache for speed. """ if not self._groups_dict_cache: - for (group_name, group) in self.groups.items(): + for group_name, group in self.groups.items(): self._groups_dict_cache[group_name] = [h.name for h in group.get_hosts()] return self._groups_dict_cache diff --git a/lib/ansible/inventory/group.py b/lib/ansible/inventory/group.py index 335f60127c3..c7b7a7af351 100644 --- a/lib/ansible/inventory/group.py +++ b/lib/ansible/inventory/group.py @@ -16,6 +16,8 @@ # along with Ansible. If not, see . from __future__ import annotations +import typing as t + from collections.abc import Mapping, MutableMapping from enum import Enum from itertools import chain @@ -26,8 +28,13 @@ from ansible.module_utils.common.text.converters import to_native, to_text from ansible.utils.display import Display from ansible.utils.vars import combine_vars +from . 
import helpers # this is left as a module import to facilitate easier unit test patching + display = Display() +if t.TYPE_CHECKING: + from .host import Host + def to_safe_group_name(name, replacer="_", force=False, silent=False): # Converts 'bad' characters in a string to underscores (or provided replacer) so they can be used as Ansible hosts or groups @@ -59,22 +66,23 @@ class InventoryObjectType(Enum): class Group: - """ a group of ansible hosts """ + """A group of ansible hosts.""" base_type = InventoryObjectType.GROUP # __slots__ = [ 'name', 'hosts', 'vars', 'child_groups', 'parent_groups', 'depth', '_hosts_cache' ] - def __init__(self, name=None): + def __init__(self, name: str) -> None: + name = helpers.remove_trust(name) - self.depth = 0 - self.name = to_safe_group_name(name) - self.hosts = [] - self._hosts = None - self.vars = {} - self.child_groups = [] - self.parent_groups = [] - self._hosts_cache = None - self.priority = 1 + self.depth: int = 0 + self.name: str = to_safe_group_name(name) + self.hosts: list[Host] = [] + self._hosts: set[str] | None = None + self.vars: dict[str, t.Any] = {} + self.child_groups: list[Group] = [] + self.parent_groups: list[Group] = [] + self._hosts_cache: list[Host] | None = None + self.priority: int = 1 def __repr__(self): return self.get_name() @@ -82,44 +90,7 @@ class Group: def __str__(self): return self.get_name() - def __getstate__(self): - return self.serialize() - - def __setstate__(self, data): - return self.deserialize(data) - - def serialize(self): - parent_groups = [] - for parent in self.parent_groups: - parent_groups.append(parent.serialize()) - - self._hosts = None - - result = dict( - name=self.name, - vars=self.vars.copy(), - parent_groups=parent_groups, - depth=self.depth, - hosts=self.hosts, - ) - - return result - - def deserialize(self, data): - self.__init__() # used by __setstate__ to deserialize in place # pylint: disable=unnecessary-dunder-call - self.name = data.get('name') - self.vars = data.get('vars', dict()) - self.depth = data.get('depth', 0) - self.hosts = data.get('hosts', []) - self._hosts = None - - parent_groups = data.get('parent_groups', []) - for parent_data in parent_groups: - g = Group() - g.deserialize(parent_data) - self.parent_groups.append(g) - - def _walk_relationship(self, rel, include_self=False, preserve_ordering=False): + def _walk_relationship(self, rel, include_self=False, preserve_ordering=False) -> set[Group] | list[Group]: """ Given `rel` that is an iterable property of Group, consitituting a directed acyclic graph among all groups, @@ -133,12 +104,12 @@ class Group: F Called on F, returns set of (A, B, C, D, E) """ - seen = set([]) + seen: set[Group] = set([]) unprocessed = set(getattr(self, rel)) if include_self: unprocessed.add(self) if preserve_ordering: - ordered = [self] if include_self else [] + ordered: list[Group] = [self] if include_self else [] ordered.extend(getattr(self, rel)) while unprocessed: @@ -158,22 +129,22 @@ class Group: return ordered return seen - def get_ancestors(self): - return self._walk_relationship('parent_groups') + def get_ancestors(self) -> set[Group]: + return t.cast(set, self._walk_relationship('parent_groups')) - def get_descendants(self, **kwargs): + def get_descendants(self, **kwargs) -> set[Group] | list[Group]: return self._walk_relationship('child_groups', **kwargs) @property - def host_names(self): + def host_names(self) -> set[str]: if self._hosts is None: - self._hosts = set(self.hosts) + self._hosts = {h.name for h in self.hosts} return self._hosts - 
def get_name(self): + def get_name(self) -> str: return self.name - def add_child_group(self, group): + def add_child_group(self, group: Group) -> bool: added = False if self == group: raise Exception("can't add group to itself") @@ -208,7 +179,7 @@ class Group: self.clear_hosts_cache() return added - def _check_children_depth(self): + def _check_children_depth(self) -> None: depth = self.depth start_depth = self.depth # self.depth could change over loop @@ -227,7 +198,7 @@ class Group: if depth - start_depth > len(seen): raise AnsibleError("The group named '%s' has a recursive dependency loop." % to_native(self.name)) - def add_host(self, host): + def add_host(self, host: Host) -> bool: added = False if host.name not in self.host_names: self.hosts.append(host) @@ -237,7 +208,7 @@ class Group: added = True return added - def remove_host(self, host): + def remove_host(self, host: Host) -> bool: removed = False if host.name in self.host_names: self.hosts.remove(host) @@ -247,7 +218,8 @@ class Group: removed = True return removed - def set_variable(self, key, value): + def set_variable(self, key: str, value: t.Any) -> None: + key = helpers.remove_trust(key) if key == 'ansible_group_priority': self.set_priority(int(value)) @@ -257,36 +229,36 @@ class Group: else: self.vars[key] = value - def clear_hosts_cache(self): + def clear_hosts_cache(self) -> None: self._hosts_cache = None for g in self.get_ancestors(): g._hosts_cache = None - def get_hosts(self): + def get_hosts(self) -> list[Host]: if self._hosts_cache is None: self._hosts_cache = self._get_hosts() return self._hosts_cache - def _get_hosts(self): + def _get_hosts(self) -> list[Host]: - hosts = [] - seen = {} + hosts: list[Host] = [] + seen: set[Host] = set() for kid in self.get_descendants(include_self=True, preserve_ordering=True): kid_hosts = kid.hosts for kk in kid_hosts: if kk not in seen: - seen[kk] = 1 + seen.add(kk) if self.name == 'all' and kk.implicit: continue hosts.append(kk) return hosts - def get_vars(self): + def get_vars(self) -> dict[str, t.Any]: return self.vars.copy() - def set_priority(self, priority): + def set_priority(self, priority: int | str) -> None: try: self.priority = int(priority) except TypeError: diff --git a/lib/ansible/inventory/helpers.py b/lib/ansible/inventory/helpers.py index 8293f905266..43baac96c9b 100644 --- a/lib/ansible/inventory/helpers.py +++ b/lib/ansible/inventory/helpers.py @@ -18,6 +18,7 @@ ############################################# from __future__ import annotations +from ansible._internal._datatag._tags import TrustedAsTemplate from ansible.utils.vars import combine_vars @@ -37,3 +38,11 @@ def get_group_vars(groups): results = combine_vars(results, group.get_vars()) return results + + +def remove_trust(value: str) -> str: + """ + Remove trust from strings which should not be trusted. + This exists to centralize the untagging call, which facilitates patching it out in unit tests.
+ """ + return TrustedAsTemplate.untag(value) diff --git a/lib/ansible/inventory/host.py b/lib/ansible/inventory/host.py index fafa9520928..f41cdd71fed 100644 --- a/lib/ansible/inventory/host.py +++ b/lib/ansible/inventory/host.py @@ -17,28 +17,26 @@ from __future__ import annotations +import collections.abc as c +import typing as t + from collections.abc import Mapping, MutableMapping from ansible.inventory.group import Group, InventoryObjectType from ansible.parsing.utils.addresses import patterns -from ansible.utils.vars import combine_vars, get_unique_id +from ansible.utils.vars import combine_vars, get_unique_id, validate_variable_name +from . import helpers # this is left as a module import to facilitate easier unit test patching __all__ = ['Host'] class Host: - """ a single ansible host """ + """A single ansible host.""" base_type = InventoryObjectType.HOST # __slots__ = [ 'name', 'vars', 'groups' ] - def __getstate__(self): - return self.serialize() - - def __setstate__(self, data): - return self.deserialize(data) - def __eq__(self, other): if not isinstance(other, Host): return False @@ -56,55 +54,28 @@ class Host: def __repr__(self): return self.get_name() - def serialize(self): - groups = [] - for group in self.groups: - groups.append(group.serialize()) - - return dict( - name=self.name, - vars=self.vars.copy(), - address=self.address, - uuid=self._uuid, - groups=groups, - implicit=self.implicit, - ) - - def deserialize(self, data): - self.__init__(gen_uuid=False) # used by __setstate__ to deserialize in place # pylint: disable=unnecessary-dunder-call - - self.name = data.get('name') - self.vars = data.get('vars', dict()) - self.address = data.get('address', '') - self._uuid = data.get('uuid', None) - self.implicit = data.get('implicit', False) - - groups = data.get('groups', []) - for group_data in groups: - g = Group() - g.deserialize(group_data) - self.groups.append(g) + def __init__(self, name: str, port: int | str | None = None, gen_uuid: bool = True) -> None: + name = helpers.remove_trust(name) - def __init__(self, name=None, port=None, gen_uuid=True): + self.vars: dict[str, t.Any] = {} + self.groups: list[Group] = [] + self._uuid: str | None = None - self.vars = {} - self.groups = [] - self._uuid = None - - self.name = name - self.address = name + self.name: str = name + self.address: str = name if port: self.set_variable('ansible_port', int(port)) if gen_uuid: self._uuid = get_unique_id() - self.implicit = False - def get_name(self): + self.implicit: bool = False + + def get_name(self) -> str: return self.name - def populate_ancestors(self, additions=None): + def populate_ancestors(self, additions: c.Iterable[Group] | None = None) -> None: # populate ancestors if additions is None: for group in self.groups: @@ -114,7 +85,7 @@ class Host: if group not in self.groups: self.groups.append(group) - def add_group(self, group): + def add_group(self, group: Group) -> bool: added = False # populate ancestors first for oldg in group.get_ancestors(): @@ -127,7 +98,7 @@ class Host: added = True return added - def remove_group(self, group): + def remove_group(self, group: Group) -> bool: removed = False if group in self.groups: self.groups.remove(group) @@ -143,18 +114,25 @@ class Host: self.remove_group(oldg) return removed - def set_variable(self, key, value): + def set_variable(self, key: str, value: t.Any) -> None: + key = helpers.remove_trust(key) + + validate_variable_name(key) + if key in self.vars and isinstance(self.vars[key], MutableMapping) and isinstance(value, Mapping): 
self.vars = combine_vars(self.vars, {key: value}) else: self.vars[key] = value - def get_groups(self): + def get_groups(self) -> list[Group]: return self.groups - def get_magic_vars(self): - results = {} - results['inventory_hostname'] = self.name + def get_magic_vars(self) -> dict[str, t.Any]: + results: dict[str, t.Any] = dict( + inventory_hostname=self.name, + ) + + # FUTURE: these values should be dynamically calculated on access ala the rest of magic vars if patterns['ipv4'].match(self.name) or patterns['ipv6'].match(self.name): results['inventory_hostname_short'] = self.name else: @@ -164,5 +142,5 @@ class Host: return results - def get_vars(self): + def get_vars(self) -> dict[str, t.Any]: return combine_vars(self.vars, self.get_magic_vars()) diff --git a/lib/ansible/inventory/manager.py b/lib/ansible/inventory/manager.py index ba6397f1787..914be9bd305 100644 --- a/lib/ansible/inventory/manager.py +++ b/lib/ansible/inventory/manager.py @@ -19,28 +19,33 @@ from __future__ import annotations import fnmatch +import functools import os -import sys import re import itertools -import traceback +import typing as t from operator import attrgetter from random import shuffle from ansible import constants as C -from ansible.errors import AnsibleError, AnsibleOptionsError, AnsibleParserError +from ansible._internal import _json, _wrapt +from ansible.errors import AnsibleError, AnsibleOptionsError from ansible.inventory.data import InventoryData from ansible.module_utils.six import string_types from ansible.module_utils.common.text.converters import to_bytes, to_text from ansible.parsing.utils.addresses import parse_address from ansible.plugins.loader import inventory_loader +from ansible._internal._datatag._tags import Origin from ansible.utils.helpers import deduplicate_list from ansible.utils.path import unfrackpath from ansible.utils.display import Display from ansible.utils.vars import combine_vars from ansible.vars.plugins import get_vars_from_inventory_sources +if t.TYPE_CHECKING: + from ansible.plugins.inventory import BaseInventoryPlugin + display = Display() IGNORED_ALWAYS = [br"^\.", b"^host_vars$", b"^group_vars$", b"^vars_plugins$"] @@ -196,12 +201,12 @@ class InventoryManager(object): def get_host(self, hostname): return self._inventory.get_host(hostname) - def _fetch_inventory_plugins(self): + def _fetch_inventory_plugins(self) -> list[BaseInventoryPlugin]: """ sets up loaded inventory plugins for usage """ display.vvvv('setting up inventory plugins') - plugins = [] + plugins: list[BaseInventoryPlugin] = [] for name in C.INVENTORY_ENABLED: plugin = inventory_loader.get(name) if plugin: @@ -276,7 +281,6 @@ class InventoryManager(object): # try source with each plugin for plugin in self._fetch_inventory_plugins(): - plugin_name = to_text(getattr(plugin, '_load_name', getattr(plugin, '_original_path', ''))) display.debug(u'Attempting to use plugin %s (%s)' % (plugin_name, plugin._original_path)) @@ -287,9 +291,14 @@ class InventoryManager(object): plugin_wants = False if plugin_wants: + # have this tag ready to apply to errors or output; str-ify source since it is often tagged by the CLI + origin = Origin(description=f'') try: - # FIXME in case plugin fails 1/2 way we have partial inventory - plugin.parse(self._inventory, self._loader, source, cache=cache) + inventory_wrapper = _InventoryDataWrapper(self._inventory, target_plugin=plugin, origin=origin) + + # FUTURE: now that we have a wrapper around inventory, we can have it use ChainMaps to preview the in-progress inventory, + # but 
be able to roll back partial inventory failures by discarding the outermost layer + plugin.parse(inventory_wrapper, self._loader, source, cache=cache) try: plugin.update_cache_if_changed() except AttributeError: @@ -298,14 +307,16 @@ class InventoryManager(object): parsed = True display.vvv('Parsed %s inventory source with %s plugin' % (source, plugin_name)) break - except AnsibleParserError as e: - display.debug('%s was not parsable by %s' % (source, plugin_name)) - tb = ''.join(traceback.format_tb(sys.exc_info()[2])) - failures.append({'src': source, 'plugin': plugin_name, 'exc': e, 'tb': tb}) - except Exception as e: - display.debug('%s failed while attempting to parse %s' % (plugin_name, source)) - tb = ''.join(traceback.format_tb(sys.exc_info()[2])) - failures.append({'src': source, 'plugin': plugin_name, 'exc': AnsibleError(e), 'tb': tb}) + except AnsibleError as ex: + if not ex.obj: + ex.obj = origin + failures.append({'src': source, 'plugin': plugin_name, 'exc': ex}) + except Exception as ex: + try: + # omit line number to prevent contextual display of script or possibly sensitive info + raise AnsibleError(str(ex), obj=origin) from ex + except AnsibleError as ex: + failures.append({'src': source, 'plugin': plugin_name, 'exc': ex}) else: display.vvv("%s declined parsing %s as it did not pass its verify_file() method" % (plugin_name, source)) @@ -319,9 +330,8 @@ class InventoryManager(object): if failures: # only if no plugin processed files should we show errors. for fail in failures: - display.warning(u'\n* Failed to parse %s with %s plugin: %s' % (to_text(fail['src']), fail['plugin'], to_text(fail['exc']))) - if 'tb' in fail: - display.vvv(to_text(fail['tb'])) + # `obj` should always be set + display.error_as_warning(msg=f'Failed to parse inventory with {fail["plugin"]!r} plugin.', exception=fail['exc']) # final error/warning on inventory source failure if C.INVENTORY_ANY_UNPARSED_IS_FAILED: @@ -749,3 +759,36 @@ class InventoryManager(object): self.reconcile_inventory() result_item['changed'] = changed + + +class _InventoryDataWrapper(_wrapt.ObjectProxy): + """ + Proxy wrapper around InventoryData. + Allows `set_variable` calls to automatically apply template trust for plugins that don't know how. + """ + + # declared as class attrs to signal to ObjectProxy that we want them stored on the proxy, not the wrapped value + _target_plugin = None + _default_origin = None + + def __init__(self, referent: InventoryData, target_plugin: BaseInventoryPlugin, origin: Origin) -> None: + super().__init__(referent) + self._target_plugin = target_plugin + # fallback origin to ensure that vars are tagged with at least the file they came from + self._default_origin = origin + + @functools.cached_property + def _inspector(self) -> _json.AnsibleVariableVisitor: + """ + Inventory plugins can delegate to other plugins (e.g. `auto`). + This hack defers sampling the target plugin's `trusted_by_default` attr until `set_variable` is called, typically inside `parse`. + Trust is then optionally applied based on the plugin's declared intent via `trusted_by_default`. 
+ """ + return _json.AnsibleVariableVisitor( + trusted_as_template=self._target_plugin.trusted_by_default, + origin=self._default_origin, + allow_encrypted_string=True, + ) + + def set_variable(self, entity: str, varname: str, value: t.Any) -> None: + self.__wrapped__.set_variable(entity, varname, self._inspector.visit(value)) diff --git a/lib/ansible/module_utils/_internal/__init__.py b/lib/ansible/module_utils/_internal/__init__.py index e69de29bb2d..c771f51dfce 100644 --- a/lib/ansible/module_utils/_internal/__init__.py +++ b/lib/ansible/module_utils/_internal/__init__.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import collections.abc as c + +import typing as t + + +# DTFIX-RELEASE: bikeshed "intermediate" +INTERMEDIATE_MAPPING_TYPES = (c.Mapping,) +""" +Mapping types which are supported for recursion and runtime usage, such as in serialization and templating. +These will be converted to a simple Python `dict` before serialization or storage as a variable. +""" + +INTERMEDIATE_ITERABLE_TYPES = (tuple, set, frozenset, c.Sequence) +""" +Iterable types which are supported for recursion and runtime usage, such as in serialization and templating. +These will be converted to a simple Python `list` before serialization or storage as a variable. +CAUTION: Scalar types which are sequences should be excluded when using this. +""" + +ITERABLE_SCALARS_NOT_TO_ITERATE_FIXME = (str, bytes) +"""Scalars which are also iterable, and should thus be excluded from iterable checks.""" + + +def is_intermediate_mapping(value: object) -> bool: + """Returns `True` if `value` is a type supported for projection to a Python `dict`, otherwise returns `False`.""" + # DTFIX-RELEASE: bikeshed name + return isinstance(value, INTERMEDIATE_MAPPING_TYPES) + + +def is_intermediate_iterable(value: object) -> bool: + """Returns `True` if `value` is a type supported for projection to a Python `list`, otherwise returns `False`.""" + # DTFIX-RELEASE: bikeshed name + return isinstance(value, INTERMEDIATE_ITERABLE_TYPES) and not isinstance(value, ITERABLE_SCALARS_NOT_TO_ITERATE_FIXME) + + +is_controller: bool = False +"""Set to True automatically when this module is imported into an Ansible controller context.""" + + +def get_controller_serialize_map() -> dict[type, t.Callable]: + """ + Called to augment serialization maps. + This implementation is replaced with the one from ansible._internal in controller contexts. + """ + return {} + + +def import_controller_module(_module_name: str, /) -> t.Any: + """ + Called to conditionally import the named module in a controller context, otherwise returns `None`. + This implementation is replaced with the one from ansible._internal in controller contexts. + """ + return None diff --git a/lib/ansible/module_utils/_internal/_ambient_context.py b/lib/ansible/module_utils/_internal/_ambient_context.py new file mode 100644 index 00000000000..96e098ce396 --- /dev/null +++ b/lib/ansible/module_utils/_internal/_ambient_context.py @@ -0,0 +1,58 @@ +# Copyright (c) 2024 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +from __future__ import annotations + +import contextlib +import contextvars + +# deprecated: description='typing.Self exists in Python 3.11+' python_version='3.10' +from ..compat import typing as t + + +class AmbientContextBase: + """ + An abstract base context manager that, once entered, will be accessible via its `current` classmethod to any code in the same + `contextvars` context (e.g. 
same thread/coroutine), until it is exited. + """ + + __slots__ = ('_contextvar_token',) + + # DTFIX-FUTURE: subclasses need to be able to opt-in to blocking nested contexts of the same type (basically optional per-callstack singleton behavior) + # DTFIX-RELEASE: this class should enforce strict nesting of contexts; overlapping context lifetimes leads to incredibly difficult to + # debug situations with undefined behavior, so it should fail fast. + # DTFIX-RELEASE: make frozen=True dataclass subclasses work (fix the mutability of the contextvar instance) + + _contextvar: t.ClassVar[contextvars.ContextVar] # pylint: disable=declare-non-slot # pylint bug, see https://github.com/pylint-dev/pylint/issues/9950 + _contextvar_token: contextvars.Token + + def __init_subclass__(cls, **kwargs) -> None: + cls._contextvar = contextvars.ContextVar(cls.__name__) + + @classmethod + def when(cls, condition: bool, /, *args, **kwargs) -> t.Self | contextlib.nullcontext: + """Return an instance of the context if `condition` is `True`, otherwise return a `nullcontext` instance.""" + return cls(*args, **kwargs) if condition else contextlib.nullcontext() + + @classmethod + def current(cls, optional: bool = False) -> t.Self | None: + """ + Return the currently active context value for the current thread or coroutine. + Raises ReferenceError if a context is not active, unless `optional` is `True`. + """ + try: + return cls._contextvar.get() + except LookupError: + if optional: + return None + + raise ReferenceError(f"A required {cls.__name__} context is not active.") from None + + def __enter__(self) -> t.Self: + # DTFIX-RELEASE: actively block multiple entry + self._contextvar_token = self.__class__._contextvar.set(self) + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.__class__._contextvar.reset(self._contextvar_token) + del self._contextvar_token diff --git a/lib/ansible/module_utils/_internal/_ansiballz.py b/lib/ansible/module_utils/_internal/_ansiballz.py new file mode 100644 index 00000000000..d728663409e --- /dev/null +++ b/lib/ansible/module_utils/_internal/_ansiballz.py @@ -0,0 +1,133 @@ +# Copyright (c) 2024 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +"""Support code for exclusive use by the AnsiballZ wrapper.""" + +from __future__ import annotations + +import atexit +import dataclasses +import importlib.util +import json +import os +import runpy +import sys +import typing as t + +from . import _errors +from ._plugin_exec_context import PluginExecContext, HasPluginInfo +from .. 
import basic +from ..common.json import get_module_encoder, Direction +from ..common.messages import PluginInfo + + +def run_module( + *, + json_params: bytes, + profile: str, + plugin_info_dict: dict[str, object], + module_fqn: str, + modlib_path: str, + init_globals: dict[str, t.Any] | None = None, + coverage_config: str | None = None, + coverage_output: str | None = None, +) -> None: # pragma: nocover + """Used internally by the AnsiballZ wrapper to run an Ansible module.""" + try: + _enable_coverage(coverage_config, coverage_output) + _run_module( + json_params=json_params, + profile=profile, + plugin_info_dict=plugin_info_dict, + module_fqn=module_fqn, + modlib_path=modlib_path, + init_globals=init_globals, + ) + except Exception as ex: # not BaseException, since modules are expected to raise SystemExit + _handle_exception(ex, profile) + + +def _enable_coverage(coverage_config: str | None, coverage_output: str | None) -> None: # pragma: nocover + """Bootstrap `coverage` for the current Ansible module invocation.""" + if not coverage_config: + return + + if coverage_output: + # Enable code coverage analysis of the module. + # This feature is for internal testing and may change without notice. + python_version_string = '.'.join(str(v) for v in sys.version_info[:2]) + os.environ['COVERAGE_FILE'] = f'{coverage_output}=python-{python_version_string}=coverage' + + import coverage + + cov = coverage.Coverage(config_file=coverage_config) + + def atexit_coverage(): + cov.stop() + cov.save() + + atexit.register(atexit_coverage) + + cov.start() + else: + # Verify coverage is available without importing it. + # This will detect when a module would fail with coverage enabled with minimal overhead. + if importlib.util.find_spec('coverage') is None: + raise RuntimeError('Could not find the `coverage` Python module.') + + +def _run_module( + *, + json_params: bytes, + profile: str, + plugin_info_dict: dict[str, object], + module_fqn: str, + modlib_path: str, + init_globals: dict[str, t.Any] | None = None, +) -> None: + """Used internally by `run_module` to run an Ansible module after coverage has been enabled (if applicable).""" + basic._ANSIBLE_ARGS = json_params + basic._ANSIBLE_PROFILE = profile + + init_globals = init_globals or {} + init_globals.update(_module_fqn=module_fqn, _modlib_path=modlib_path) + + with PluginExecContext(_ModulePluginWrapper(PluginInfo._from_dict(plugin_info_dict))): + # Run the module. By importing it as '__main__', it executes as a script. + runpy.run_module(mod_name=module_fqn, init_globals=init_globals, run_name='__main__', alter_sys=True) + + # An Ansible module must print its own results and exit. If execution reaches this point, that did not happen.
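# --- Editorial sketch (not part of the patch): the exit contract described in
# the comment above. A new-style module is expected to print its own JSON result
# and terminate via SystemExit; a hypothetical module __main__ body might be:
import json
import sys

def sketch_module_main() -> None:
    print(json.dumps(dict(changed=False, msg='ok')))
    sys.exit(0)  # SystemExit propagates out of runpy.run_module() and ends execution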
+ raise RuntimeError('New-style module did not handle its own exit.') + + +def _handle_exception(exception: BaseException, profile: str) -> t.NoReturn: + """Handle the given exception.""" + result = dict( + failed=True, + exception=_errors.create_error_summary(exception), + ) + + encoder = get_module_encoder(profile, Direction.MODULE_TO_CONTROLLER) + + print(json.dumps(result, cls=encoder)) # pylint: disable=ansible-bad-function + + sys.exit(1) # pylint: disable=ansible-bad-function + + +@dataclasses.dataclass(frozen=True) +class _ModulePluginWrapper(HasPluginInfo): + """Modules aren't plugin instances; this adapter implements the `HasPluginInfo` protocol to allow `PluginExecContext` infra to work with modules.""" + + plugin: PluginInfo + + @property + def _load_name(self) -> str: + return self.plugin.requested_name + + @property + def ansible_name(self) -> str: + return self.plugin.resolved_name + + @property + def plugin_type(self) -> str: + return self.plugin.type diff --git a/lib/ansible/module_utils/_internal/_concurrent/_daemon_threading.py b/lib/ansible/module_utils/_internal/_concurrent/_daemon_threading.py index 0b32a062fed..3a29b981100 100644 --- a/lib/ansible/module_utils/_internal/_concurrent/_daemon_threading.py +++ b/lib/ansible/module_utils/_internal/_concurrent/_daemon_threading.py @@ -1,4 +1,5 @@ """Proxy stdlib threading module that only supports non-joinable daemon threads.""" + # NB: all new local module attrs are _ prefixed to ensure an identical public attribute surface area to the module we're proxying from __future__ import annotations as _annotations diff --git a/lib/ansible/module_utils/_internal/_dataclass_annotation_patch.py b/lib/ansible/module_utils/_internal/_dataclass_annotation_patch.py new file mode 100644 index 00000000000..1d1f913908c --- /dev/null +++ b/lib/ansible/module_utils/_internal/_dataclass_annotation_patch.py @@ -0,0 +1,64 @@ +"""Patch broken ClassVar support in dataclasses when ClassVar is accessed via a module other than `typing`.""" + +# deprecated: description='verify ClassVar support in dataclasses has been fixed in Python before removing this patching code', python_version='3.12' + +from __future__ import annotations + +import dataclasses +import sys +import typing as t + +# trigger the bug by exposing typing.ClassVar via a module reference that is not `typing` +_ts = sys.modules[__name__] +ClassVar = t.ClassVar + + +def patch_dataclasses_is_type() -> None: + if not _is_patch_needed(): + return # pragma: nocover + + try: + real_is_type = dataclasses._is_type # type: ignore[attr-defined] + except AttributeError: # pragma: nocover + raise RuntimeError("unable to patch broken dataclasses ClassVar support") from None + + # patch dataclasses._is_type - impl from https://github.com/python/cpython/blob/4c6d4f5cb33e48519922d635894eef356faddba2/Lib/dataclasses.py#L709-L765 + def _is_type(annotation, cls, a_module, a_type, is_type_predicate): + match = dataclasses._MODULE_IDENTIFIER_RE.match(annotation) # type: ignore[attr-defined] + if match: + ns = None + module_name = match.group(1) + if not module_name: + # No module name, assume the class's module did + # "from dataclasses import InitVar". + ns = sys.modules.get(cls.__module__).__dict__ + else: + # Look up module_name in the class's module. 
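# [editor's note, not part of this change] Concretely, for the annotation
# '_ts.ClassVar[int]' used by _is_patch_needed() below: match.group(1) == '_ts' and
# match.group(2) == 'ClassVar'. The stock CPython check `module.__dict__.get(module_name)
# is a_module` only accepts the `typing` module itself, so ClassVar reached through any
# other module fails the check and the annotation is misclassified as an instance field;
# relaxing that identity check (next line) is the entire fix.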
+ module = sys.modules.get(cls.__module__) + if module and module.__dict__.get(module_name): # this is the patched line; removed `is a_module` + ns = sys.modules.get(a_type.__module__).__dict__ + if ns and is_type_predicate(ns.get(match.group(2)), a_module): + return True + return False + + _is_type._orig_impl = real_is_type # type: ignore[attr-defined] # stash this away to allow unit tests to undo the patch + + dataclasses._is_type = _is_type # type: ignore[attr-defined] + + try: + if _is_patch_needed(): + raise RuntimeError("patching had no effect") # pragma: nocover + except Exception as ex: # pragma: nocover + dataclasses._is_type = real_is_type # type: ignore[attr-defined] + raise RuntimeError("dataclasses ClassVar support is still broken after patching") from ex + + +def _is_patch_needed() -> bool: + @dataclasses.dataclass + class CheckClassVar: + # this is the broken case requiring patching: ClassVar dot-referenced from a module that is not `typing` is treated as an instance field + # DTFIX-RELEASE: add link to CPython bug report to-be-filed (or update associated deprecation comments if we don't) + a_classvar: _ts.ClassVar[int] # type: ignore[name-defined] + a_field: int + + return len(dataclasses.fields(CheckClassVar)) != 1 diff --git a/lib/ansible/module_utils/_internal/_dataclass_validation.py b/lib/ansible/module_utils/_internal/_dataclass_validation.py new file mode 100644 index 00000000000..dcd6472347c --- /dev/null +++ b/lib/ansible/module_utils/_internal/_dataclass_validation.py @@ -0,0 +1,217 @@ +# Copyright (c) 2024 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +"""Code generation of __post_init__ methods for efficient dataclass field type checking at runtime.""" + +from __future__ import annotations + +import atexit +import functools +import itertools +import shutil +import tempfile +import types +import typing as t + +_write_generated_code_to_disk = False + +# deprecated: description='types.UnionType is available in Python 3.10' python_version='3.9' +try: + _union_type: type | None = types.UnionType # type: ignore[attr-defined] + _union_types: tuple = (t.Union, types.UnionType) # type: ignore[attr-defined] +except AttributeError: + _union_type = None # type: ignore[assignment] + _union_types = (t.Union,) # type: ignore[assignment] + + +def inject_post_init_validation(cls: type, allow_subclasses=False) -> None: + """Inject a __post_init__ field validation method on the given dataclass. 
A __post_init__ attribute to override must already exist when the class defines its own __init__.""" + # DTFIX-FUTURE: when cls must have a __post_init__, enforcing it as a no-op would be nice, but is tricky on slotted dataclasses due to double-creation + post_validate_name = '_post_validate' + method_name = '__post_init__' + exec_globals: dict[str, t.Any] = {} + known_types: dict[type, str] = {} + lines: list[str] = [] + field_type_hints = t.get_type_hints(cls) + indent = 1 + + def append_line(line: str) -> None: + """Append a line to the generated source at the current indentation level.""" + lines.append((' ' * indent * 4) + line) + + def register_type(target_type: type) -> str: + """Register the target type and return the local name.""" + target_name = f'{target_type.__module__.replace(".", "_")}_{target_type.__name__}' + + known_types[target_type] = target_name + exec_globals[target_name] = target_type + + return target_name + + def validate_value(target_name: str, target_ref: str, target_type: type) -> None: + """Generate code to validate the specified value.""" + nonlocal indent + + origin_type = t.get_origin(target_type) + + if origin_type is t.ClassVar: + return # ignore annotations which are not fields, indicated by the t.ClassVar annotation + + allowed_types = _get_allowed_types(target_type) + + # check value + + if origin_type is t.Literal: + # DTFIX-FUTURE: support optional literals + + values = t.get_args(target_type) + + append_line(f"""if {target_ref} not in {values}:""") + append_line(f""" raise ValueError(rf"{target_name} must be one of {values} instead of {{{target_ref}!r}}")""") + + allowed_refs = [register_type(allowed_type) for allowed_type in allowed_types] + allowed_names = [repr(allowed_type) for allowed_type in allowed_types] + + if allow_subclasses: + if len(allowed_refs) == 1: + append_line(f"""if not isinstance({target_ref}, {allowed_refs[0]}):""") + else: + append_line(f"""if not isinstance({target_ref}, ({', '.join(allowed_refs)})):""") + else: + if len(allowed_refs) == 1: + append_line(f"""if type({target_ref}) is not {allowed_refs[0]}:""") + else: + append_line(f"""if type({target_ref}) not in ({', '.join(allowed_refs)}):""") + + append_line(f""" raise TypeError(f"{target_name} must be {' or '.join(allowed_names)} instead of {{type({target_ref})}}")""") + + # check elements (for containers) + + if target_ref.startswith('self.'): + local_ref = target_ref[5:] + else: + local_ref = target_ref + + if tuple in allowed_types: + tuple_type = _extract_type(target_type, tuple) + + idx_ref = f'{local_ref}_idx' + item_ref = f'{local_ref}_item' + item_name = f'{target_name}[{{{idx_ref}!r}}]' + item_type, _ellipsis = t.get_args(tuple_type) + + if _ellipsis is not ...: + raise ValueError(f"{cls} tuple fields must be a tuple of a single element type") + + append_line(f"""if isinstance({target_ref}, {known_types[tuple]}):""") + append_line(f""" for {idx_ref}, {item_ref} in enumerate({target_ref}):""") + + indent += 2 + validate_value(target_name=item_name, target_ref=item_ref, target_type=item_type) + indent -= 2 + + if list in allowed_types: + list_type = _extract_type(target_type, list) + + idx_ref = f'{local_ref}_idx' + item_ref = f'{local_ref}_item' + item_name = f'{target_name}[{{{idx_ref}!r}}]' + (item_type,) = t.get_args(list_type) + + append_line(f"""if isinstance({target_ref}, {known_types[list]}):""") + append_line(f""" for {idx_ref}, {item_ref} in enumerate({target_ref}):""") + + indent += 2 + validate_value(target_name=item_name, target_ref=item_ref, target_type=item_type) + indent -= 2 + + if dict in
allowed_types: + dict_type = _extract_type(target_type, dict) + + key_ref, value_ref = f'{local_ref}_key', f'{local_ref}_value' + key_type, value_type = t.get_args(dict_type) + key_name, value_name = f'{target_name!r} key {{{key_ref}!r}}', f'{target_name}[{{{key_ref}!r}}]' + + append_line(f"""if isinstance({target_ref}, {known_types[dict]}):""") + append_line(f""" for {key_ref}, {value_ref} in {target_ref}.items():""") + + indent += 2 + validate_value(target_name=key_name, target_ref=key_ref, target_type=key_type) + validate_value(target_name=value_name, target_ref=value_ref, target_type=value_type) + indent -= 2 + + for field_name in cls.__annotations__: + validate_value(target_name=f'{{type(self).__name__}}.{field_name}', target_ref=f'self.{field_name}', target_type=field_type_hints[field_name]) + + if hasattr(cls, post_validate_name): + append_line(f"self.{post_validate_name}()") + + if not lines: + return # nothing to validate (empty dataclass) + + if '__init__' in cls.__dict__ and not hasattr(cls, method_name): + raise ValueError(f"{cls} must have a {method_name!r} method to override when invoked after the '__init__' method is created") + + if any(hasattr(parent, method_name) for parent in cls.__mro__[1:]): + lines.insert(0, f' super({register_type(cls)}, self).{method_name}()') + + lines.insert(0, f'def {method_name}(self):') + + source = '\n'.join(lines) + '\n' + + if _write_generated_code_to_disk: + tmp = tempfile.NamedTemporaryFile(mode='w+t', suffix=f'-{cls.__module__}.{cls.__name__}.py', delete=False, dir=_get_temporary_directory()) + + tmp.write(source) + tmp.flush() + + filename = tmp.name + else: + filename = f' generated for {cls}' + + code = compile(source, filename, 'exec') + + exec(code, exec_globals) + setattr(cls, method_name, exec_globals[method_name]) + + +@functools.lru_cache(maxsize=1) +def _get_temporary_directory() -> str: + """Create a temporary directory and return its full path. 
The directory will be deleted when the process exits.""" + temp_dir = tempfile.mkdtemp() + + atexit.register(lambda: shutil.rmtree(temp_dir)) + + return temp_dir + + +def _get_allowed_types(target_type: type) -> tuple[type, ...]: + """Return a tuple of types usable in instance checks for the given target_type.""" + origin_type = t.get_origin(target_type) + + if origin_type in _union_types: + allowed_types = tuple(set(itertools.chain.from_iterable(_get_allowed_types(arg) for arg in t.get_args(target_type)))) + elif origin_type is t.Literal: + allowed_types = (str,) # DTFIX-FUTURE: support non-str literal types + elif origin_type: + allowed_types = (origin_type,) + else: + allowed_types = (target_type,) + + return allowed_types + + +def _extract_type(target_type: type, of_type: type) -> type: + """Return `of_type` from `target_type`, where `target_type` may be a union.""" + origin_type = t.get_origin(target_type) + + if origin_type is of_type: # pylint: disable=unidiomatic-typecheck + return target_type + + if origin_type is t.Union or (_union_type and isinstance(target_type, _union_type)): + args = t.get_args(target_type) + extracted_types = [arg for arg in args if type(arg) is of_type or t.get_origin(arg) is of_type] # pylint: disable=unidiomatic-typecheck + (extracted_type,) = extracted_types + return extracted_type + + raise NotImplementedError(f'{target_type} is not supported') diff --git a/lib/ansible/module_utils/_internal/_datatag/__init__.py b/lib/ansible/module_utils/_internal/_datatag/__init__.py new file mode 100644 index 00000000000..aa94ad4f4ce --- /dev/null +++ b/lib/ansible/module_utils/_internal/_datatag/__init__.py @@ -0,0 +1,928 @@ +from __future__ import annotations + +import abc +import collections.abc as c +import copy +import dataclasses +import datetime +import inspect +import sys + +from itertools import chain + +# deprecated: description='typing.Self exists in Python 3.11+' python_version='3.10' +from ansible.module_utils.compat import typing as t + +from ansible.module_utils._internal import _dataclass_validation +from ansible.module_utils._internal._patches import _sys_intern_patch, _socket_patch + +_sys_intern_patch.SysInternPatch.patch() +_socket_patch.GetAddrInfoPatch.patch() # DTFIX-FUTURE: consider replacing this with a socket import shim that installs the patch + +if sys.version_info >= (3, 10): + # Using slots for reduced memory usage and improved performance. + _tag_dataclass_kwargs = dict(frozen=True, repr=False, kw_only=True, slots=True) +else: + # deprecated: description='always use dataclass slots and keyword-only args' python_version='3.9' + _tag_dataclass_kwargs = dict(frozen=True, repr=False) + +_T = t.TypeVar('_T') +_TAnsibleSerializable = t.TypeVar('_TAnsibleSerializable', bound='AnsibleSerializable') +_TAnsibleDatatagBase = t.TypeVar('_TAnsibleDatatagBase', bound='AnsibleDatatagBase') +_TAnsibleTaggedObject = t.TypeVar('_TAnsibleTaggedObject', bound='AnsibleTaggedObject') + +_NO_INSTANCE_STORAGE = t.cast(t.Tuple[str], tuple()) +_ANSIBLE_TAGGED_OBJECT_SLOTS = tuple(('_ansible_tags_mapping',)) + +# shared empty frozenset for default values +_empty_frozenset: t.FrozenSet = frozenset() + + +class AnsibleTagHelper: + """Utility methods for working with Ansible data tags.""" + + # DTFIX-RELEASE: bikeshed the name and location of this class, also, related, how much more of it should be exposed as public API? 
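# [editor's sketch, not part of this change] Typical use of the helpers defined below,
# with the Deprecated tag from _tags.py later in this change:
#
#   value = AnsibleTagHelper.tag('ansible_ssh_user', Deprecated(msg='use ansible_user'))
#   assert value == 'ansible_ssh_user'        # tagged values behave like their native type
#   assert Deprecated.is_tagged_on(value)
#   untagged = AnsibleTagHelper.untag(value)
#   assert type(untagged) is str              # with no tags left, untag reverts to the native type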
+ # it may make sense to move this into another module, but the implementations should remain here (so they can be used without circular imports here) + # if they're in a separate module, is a class even needed, or should they be globals? + # DTFIX-RELEASE: add docstrings to all non-override methods in this class + + @staticmethod + def untag(value: _T, *tag_types: t.Type[AnsibleDatatagBase]) -> _T: + """ + If tags matching any of `tag_types` are present on `value`, return a copy with those tags removed. + If no `tag_types` are specified and the object has tags, return a copy with all tags removed. + Otherwise, the original `value` is returned. + """ + tag_set = AnsibleTagHelper.tags(value) + + if not tag_set: + return value + + if tag_types: + tags_mapping = _AnsibleTagsMapping((type(tag), tag) for tag in tag_set if type(tag) not in tag_types) # pylint: disable=unidiomatic-typecheck + + if len(tags_mapping) == len(tag_set): + return value # if no tags were removed, return the original instance + else: + tags_mapping = None + + if not tags_mapping: + if t.cast(AnsibleTaggedObject, value)._empty_tags_as_native: + return t.cast(AnsibleTaggedObject, value)._native_copy() + + tags_mapping = _EMPTY_INTERNAL_TAGS_MAPPING + + tagged_type = AnsibleTaggedObject._get_tagged_type(type(value)) + + return t.cast(_T, tagged_type._instance_factory(value, tags_mapping)) + + @staticmethod + def tags(value: t.Any) -> t.FrozenSet[AnsibleDatatagBase]: + tags = _try_get_internal_tags_mapping(value) + + if tags is _EMPTY_INTERNAL_TAGS_MAPPING: + return _empty_frozenset + + return frozenset(tags.values()) + + @staticmethod + def tag_types(value: t.Any) -> t.FrozenSet[t.Type[AnsibleDatatagBase]]: + tags = _try_get_internal_tags_mapping(value) + + if tags is _EMPTY_INTERNAL_TAGS_MAPPING: + return _empty_frozenset + + return frozenset(tags) + + @staticmethod + def base_type(type_or_value: t.Any, /) -> type: + """Return the friendly type of the given type or value. If the type is an AnsibleTaggedObject, the native type will be used.""" + if isinstance(type_or_value, type): + the_type = type_or_value + else: + the_type = type(type_or_value) + + if issubclass(the_type, AnsibleTaggedObject): + the_type = type_or_value._native_type + + # DTFIX-RELEASE: provide a way to report the real type for debugging purposes + return the_type + + @staticmethod + def as_native_type(value: _T) -> _T: + """ + Returns an untagged native data type matching the input value, or the original input if the value was not a tagged type. + Containers are not recursively processed. + """ + if isinstance(value, AnsibleTaggedObject): + value = value._native_copy() + + return value + + @staticmethod + @t.overload + def tag_copy(src: t.Any, value: _T) -> _T: ... # pragma: nocover + + @staticmethod + @t.overload + def tag_copy(src: t.Any, value: t.Any, *, value_type: type[_T]) -> _T: ... # pragma: nocover + + @staticmethod + @t.overload + def tag_copy(src: t.Any, value: _T, *, value_type: None = None) -> _T: ... 
# pragma: nocover + + @staticmethod + def tag_copy(src: t.Any, value: _T, *, value_type: t.Optional[type] = None) -> _T: + """Return a copy of `value`, with tags copied from `src`, overwriting any existing tags of the same types.""" + src_tags = AnsibleTagHelper.tags(src) + value_tags = [(tag, tag._get_tag_to_propagate(src, value, value_type=value_type)) for tag in src_tags] + tags = [tag[1] for tag in value_tags if tag[1] is not None] + tag_types_to_remove = [type(tag[0]) for tag in value_tags if tag[1] is None] + + if tag_types_to_remove: + value = AnsibleTagHelper.untag(value, *tag_types_to_remove) + + return AnsibleTagHelper.tag(value, tags, value_type=value_type) + + @staticmethod + @t.overload + def tag(value: _T, tags: t.Union[AnsibleDatatagBase, t.Iterable[AnsibleDatatagBase]]) -> _T: ... # pragma: nocover + + @staticmethod + @t.overload + def tag(value: t.Any, tags: t.Union[AnsibleDatatagBase, t.Iterable[AnsibleDatatagBase]], *, value_type: type[_T]) -> _T: ... # pragma: nocover + + @staticmethod + @t.overload + def tag(value: _T, tags: t.Union[AnsibleDatatagBase, t.Iterable[AnsibleDatatagBase]], *, value_type: None = None) -> _T: ... # pragma: nocover + + @staticmethod + def tag(value: _T, tags: t.Union[AnsibleDatatagBase, t.Iterable[AnsibleDatatagBase]], *, value_type: t.Optional[type] = None) -> _T: + """ + Return a copy of `value`, with `tags` applied, overwriting any existing tags of the same types. + If `value` is an ignored type, or `tags` is empty, the original `value` will be returned. + If `value` is not taggable, a `NotTaggableError` exception will be raised. + If `value_type` was given, that type will be returned instead. + """ + if value_type is None: + value_type_specified = False + value_type = type(value) + else: + value_type_specified = True + + # if no tags to apply, just return what we got + # NB: this only works because the untaggable types are singletons (and thus direct type comparison works) + if not tags or value_type in _untaggable_types: + if value_type_specified: + return value_type(value) + + return value + + tag_list: list[AnsibleDatatagBase] + + # noinspection PyProtectedMember + if type(tags) in _known_tag_types: + tag_list = [tags] # type: ignore[list-item] + else: + tag_list = list(tags) # type: ignore[arg-type] + + for idx, tag in enumerate(tag_list): + # noinspection PyProtectedMember + if type(tag) not in _known_tag_types: + # noinspection PyProtectedMember + raise TypeError(f'tags[{idx}] of type {type(tag)} is not one of {_known_tag_types}') + + existing_internal_tags_mapping = _try_get_internal_tags_mapping(value) + + if existing_internal_tags_mapping is not _EMPTY_INTERNAL_TAGS_MAPPING: + # include the existing tags first so new tags of the same type will overwrite + tag_list = list(chain(existing_internal_tags_mapping.values(), tag_list)) + + tags_mapping = _AnsibleTagsMapping((type(tag), tag) for tag in tag_list) + tagged_type = AnsibleTaggedObject._get_tagged_type(value_type) + + return t.cast(_T, tagged_type._instance_factory(value, tags_mapping)) + + @staticmethod + def try_tag(value: _T, tags: t.Union[AnsibleDatatagBase, t.Iterable[AnsibleDatatagBase]]) -> _T: + """ + Return a copy of `value`, with `tags` applied, overwriting any existing tags of the same types. + If `value` is not taggable or `tags` is empty, the original `value` will be returned. 
+ """ + try: + return AnsibleTagHelper.tag(value, tags) + except NotTaggableError: + return value + + +class AnsibleSerializable(metaclass=abc.ABCMeta): + __slots__ = _NO_INSTANCE_STORAGE + + _known_type_map: t.ClassVar[t.Dict[str, t.Type['AnsibleSerializable']]] = {} + _TYPE_KEY: t.ClassVar[str] = '__ansible_type' + + _type_key: t.ClassVar[str] + + def __init_subclass__(cls, **kwargs) -> None: + # this is needed to call __init_subclass__ on mixins for derived types + super().__init_subclass__(**kwargs) + + cls._type_key = cls.__name__ + + # DTFIX-FUTURE: is there a better way to exclude non-abstract types which are base classes? + if not inspect.isabstract(cls) and not cls.__name__.endswith('Base') and cls.__name__ != 'AnsibleTaggedObject': + AnsibleSerializable._known_type_map[cls._type_key] = cls + + @classmethod + @abc.abstractmethod + def _from_dict(cls: t.Type[_TAnsibleSerializable], d: t.Dict[str, t.Any]) -> object: + """Return an instance of this type, created from the given dictionary.""" + + @abc.abstractmethod + def _as_dict(self) -> t.Dict[str, t.Any]: + """ + Return a serialized version of this instance as a dictionary. + This operation is *NOT* recursive - the returned dictionary may still include custom types. + It is the responsibility of the caller to handle recursion of the returned dict. + """ + + def _serialize(self) -> t.Dict[str, t.Any]: + value = self._as_dict() + value.update({AnsibleSerializable._TYPE_KEY: self._type_key}) + + return value + + @staticmethod + def _deserialize(data: t.Dict[str, t.Any]) -> object: + """Deserialize an object from the supplied data dict, which will be mutated if it contains a type key.""" + type_name = data.pop(AnsibleSerializable._TYPE_KEY, ...) # common usage assumes `data` is an intermediate dict provided by a deserializer + + if type_name is ...: + return None + + type_value = AnsibleSerializable._known_type_map.get(type_name) + + if not type_value: + raise ValueError(f'An unknown {AnsibleSerializable._TYPE_KEY!r} value {type_name!r} was encountered during deserialization.') + + return type_value._from_dict(data) + + def _repr(self, name: str) -> str: + args = self._as_dict() + arg_string = ', '.join((f'{k}={v!r}' for k, v in args.items())) + return f'{name}({arg_string})' + + +class AnsibleSerializableWrapper(AnsibleSerializable, t.Generic[_T], metaclass=abc.ABCMeta): + __slots__ = ('_value',) + + _wrapped_types: t.ClassVar[dict[type, type[AnsibleSerializable]]] = {} + _wrapped_type: t.ClassVar[type] = type(None) + + def __init__(self, value: _T) -> None: + self._value: _T = value + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + + cls._wrapped_type = t.get_args(cls.__orig_bases__[0])[0] + cls._wrapped_types[cls._wrapped_type] = cls + + +class AnsibleSerializableDate(AnsibleSerializableWrapper[datetime.date]): + __slots__ = _NO_INSTANCE_STORAGE + + @classmethod + def _from_dict(cls: t.Type[_TAnsibleSerializable], d: t.Dict[str, t.Any]) -> datetime.date: + return datetime.date.fromisoformat(d['iso8601']) + + def _as_dict(self) -> t.Dict[str, t.Any]: + return dict( + iso8601=self._value.isoformat(), + ) + + +class AnsibleSerializableTime(AnsibleSerializableWrapper[datetime.time]): + __slots__ = _NO_INSTANCE_STORAGE + + @classmethod + def _from_dict(cls: t.Type[_TAnsibleSerializable], d: t.Dict[str, t.Any]) -> datetime.time: + value = datetime.time.fromisoformat(d['iso8601']) + value = value.replace(fold=d['fold']) + + return value + + def _as_dict(self) -> t.Dict[str, t.Any]: + return dict(
iso8601=self._value.isoformat(), + fold=self._value.fold, + ) + + +class AnsibleSerializableDateTime(AnsibleSerializableWrapper[datetime.datetime]): + __slots__ = _NO_INSTANCE_STORAGE + + @classmethod + def _from_dict(cls: t.Type[_TAnsibleSerializable], d: t.Dict[str, t.Any]) -> datetime.datetime: + value = datetime.datetime.fromisoformat(d['iso8601']) + value = value.replace(fold=d['fold']) + + return value + + def _as_dict(self) -> t.Dict[str, t.Any]: + return dict( + iso8601=self._value.isoformat(), + fold=self._value.fold, + ) + + +@dataclasses.dataclass(**_tag_dataclass_kwargs) +class AnsibleSerializableDataclass(AnsibleSerializable, metaclass=abc.ABCMeta): + _validation_allow_subclasses = True + + def _as_dict(self) -> t.Dict[str, t.Any]: + # omit None values when None is the field default + # DTFIX-RELEASE: this implementation means we can never change the default on fields which have None for their default + # other defaults can be changed -- but there's no way to override this behavior either way for other default types + # it's a trip hazard to have the default logic here, rather than per field (or not at all) + # consider either removing the filtering or requiring it to be explicitly set per field using dataclass metadata + fields = ((field, getattr(self, field.name)) for field in dataclasses.fields(self)) + return {field.name: value for field, value in fields if value is not None or field.default is not None} + + @classmethod + def _from_dict(cls, d: t.Dict[str, t.Any]) -> t.Self: + # DTFIX-RELEASE: optimize this to avoid the dataclasses fields metadata and get_origin stuff at runtime + type_hints = t.get_type_hints(cls) + mutated_dict: dict[str, t.Any] | None = None + + for field in dataclasses.fields(cls): + if t.get_origin(type_hints[field.name]) is tuple: # NOTE: only supports bare tuples, not optional or inside a union + if type(field_value := d.get(field.name)) is list: # pylint: disable=unidiomatic-typecheck + if mutated_dict is None: + mutated_dict = d.copy() + + mutated_dict[field.name] = tuple(field_value) + + return cls(**(mutated_dict or d)) + + def __init_subclass__(cls, **kwargs) -> None: + super(AnsibleSerializableDataclass, cls).__init_subclass__(**kwargs) # cannot use super() without arguments when using slots + + _dataclass_validation.inject_post_init_validation(cls, cls._validation_allow_subclasses) # code gen a real __post_init__ method + + +class Tripwire: + """Marker mixin for types that should raise an error when encountered.""" + + __slots__ = _NO_INSTANCE_STORAGE + + def trip(self) -> t.NoReturn: + """Derived types should implement a failure behavior.""" + raise NotImplementedError() + + +@dataclasses.dataclass(**_tag_dataclass_kwargs) +class AnsibleDatatagBase(AnsibleSerializableDataclass, metaclass=abc.ABCMeta): + """ + Base class for data tagging tag types. + New tag types need to be considered very carefully; e.g.: which serialization/runtime contexts they're allowed in, fallback behavior, propagation. + """ + + _validation_allow_subclasses = False + + def __init_subclass__(cls, **kwargs) -> None: + # NOTE: This method is called twice when the datatag type is a dataclass. + super(AnsibleDatatagBase, cls).__init_subclass__(**kwargs) # cannot use super() without arguments when using slots + + # DTFIX-FUTURE: "freeze" this after module init has completed to discourage custom external tag subclasses + + # DTFIX-FUTURE: is there a better way to exclude non-abstract types which are base classes?
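# [editor's note, not part of this change] Concrete tag types are frozen (and, on
# Python 3.10+, slotted) dataclasses derived from this base; e.g. the Deprecated tag
# defined in _tags.py later in this change:
#
#   @dataclasses.dataclass(**_tag_dataclass_kwargs)
#   class Deprecated(AnsibleDatatagBase):
#       msg: str
#       help_text: t.Optional[str] = None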
+ if not inspect.isabstract(cls) and not cls.__name__.endswith('Base'): + existing = _known_tag_type_map.get(cls.__name__) + + if existing: + # When the datatag type is a dataclass, the first instance will be the non-dataclass type. + # It must be removed from the known tag types before adding the dataclass version. + _known_tag_types.remove(existing) + + _known_tag_type_map[cls.__name__] = cls + _known_tag_types.add(cls) + + @classmethod + def is_tagged_on(cls, value: t.Any) -> bool: + return cls in _try_get_internal_tags_mapping(value) + + @classmethod + def first_tagged_on(cls, *values: t.Any) -> t.Any | None: + """Return the first value which is tagged with this type, or None if no match is found.""" + for value in values: + if cls.is_tagged_on(value): + return value + + return None + + @classmethod + def get_tag(cls, value: t.Any) -> t.Optional[t.Self]: + return _try_get_internal_tags_mapping(value).get(cls) + + @classmethod + def get_required_tag(cls, value: t.Any) -> t.Self: + if (tag := cls.get_tag(value)) is None: + # DTFIX-FUTURE: we really should have a way to use AnsibleError with obj in module_utils when it's controller-side + raise ValueError(f'The type {type(value).__name__!r} is not tagged with {cls.__name__!r}.') + + return tag + + @classmethod + def untag(cls, value: _T) -> _T: + """ + If this tag type is present on `value`, return a copy with that tag removed. + Otherwise, the original `value` is returned. + """ + return AnsibleTagHelper.untag(value, cls) + + def tag(self, value: _T) -> _T: + """ + Return a copy of `value` with this tag applied, overwriting any existing tag of the same type. + If `value` is an ignored type, the original `value` will be returned. + If `value` is not taggable, a `NotTaggableError` exception will be raised. + """ + return AnsibleTagHelper.tag(value, self) + + def try_tag(self, value: _T) -> _T: + """ + Return a copy of `value` with this tag applied, overwriting any existing tag of the same type. + If `value` is not taggable, the original `value` will be returned. + """ + return AnsibleTagHelper.try_tag(value, self) + + def _get_tag_to_propagate(self, src: t.Any, value: object, *, value_type: t.Optional[type] = None) -> t.Self | None: + """ + Called by `AnsibleTagHelper.tag_copy` during tag propagation. + Returns an instance of this tag appropriate for propagation to `value`, or `None` if the tag should not be propagated. + Derived implementations may consult the arguments relayed from `tag_copy` to determine if and how the tag should be propagated. + """ + return self + + def __repr__(self) -> str: + return AnsibleSerializable._repr(self, self.__class__.__name__) + + +# used by the datatag Ansible/Jinja test plugin to find tags by name +_known_tag_type_map: t.Dict[str, t.Type[AnsibleDatatagBase]] = {} +_known_tag_types: t.Set[t.Type[AnsibleDatatagBase]] = set() + +if sys.version_info >= (3, 9): + # Include the key and value types in the type hints on Python 3.9 and later. + # Earlier versions do not support subscriptable dict. + # deprecated: description='always use subscriptable dict' python_version='3.8' + class _AnsibleTagsMapping(dict[type[_TAnsibleDatatagBase], _TAnsibleDatatagBase]): + __slots__ = _NO_INSTANCE_STORAGE + +else: + + class _AnsibleTagsMapping(dict): + __slots__ = _NO_INSTANCE_STORAGE + + +class _EmptyROInternalTagsMapping(dict): + """ + Optimizes empty tag mapping by using a shared singleton read-only dict. + Since mappingproxy is not pickle-able and causes other problems, we had to roll our own. 
+ """ + + def __new__(cls): + try: + # noinspection PyUnresolvedReferences + return cls._instance + except AttributeError: + cls._instance = dict.__new__(cls) + + # noinspection PyUnresolvedReferences + return cls._instance + + def __setitem__(self, key, value): + raise NotImplementedError() + + def setdefault(self, __key, __default=None): + raise NotImplementedError() + + def update(self, __m, **kwargs): + raise NotImplementedError() + + +_EMPTY_INTERNAL_TAGS_MAPPING = t.cast(_AnsibleTagsMapping, _EmptyROInternalTagsMapping()) +""" +An empty read-only mapping of tags. +Also used as a sentinel to cheaply determine that a type is not tagged by using a reference equality check. +""" + + +class CollectionWithMro(c.Collection, t.Protocol): + """Used to represent a Collection with __mro__ in a TypeGuard for tools that don't include __mro__ in Collection.""" + + __mro__: tuple[type, ...] + + +# DTFIX-RELEASE: This should probably reside elsewhere. +def is_non_scalar_collection_type(value: type) -> t.TypeGuard[type[CollectionWithMro]]: + """Returns True if the value is a non-scalar collection type, otherwise returns False.""" + return issubclass(value, c.Collection) and not issubclass(value, str) and not issubclass(value, bytes) + + +def _try_get_internal_tags_mapping(value: t.Any) -> _AnsibleTagsMapping: + """Return the internal tag mapping of the given value, or a sentinel value if it is not tagged.""" + # noinspection PyBroadException + try: + # noinspection PyProtectedMember + tags = value._ansible_tags_mapping + except Exception: + # try/except is a cheap way to determine if this is a tagged object without using isinstance + # handling Exception accounts for types that may raise something other than AttributeError + return _EMPTY_INTERNAL_TAGS_MAPPING + + # handle cases where the instance always returns something, such as Marker or MagicMock + if type(tags) is not _AnsibleTagsMapping: # pylint: disable=unidiomatic-typecheck + return _EMPTY_INTERNAL_TAGS_MAPPING + + return tags + + +class NotTaggableError(TypeError): + def __init__(self, value): + super(NotTaggableError, self).__init__('{} is not taggable'.format(value)) + + +@dataclasses.dataclass(**_tag_dataclass_kwargs) +class AnsibleSingletonTagBase(AnsibleDatatagBase): + def __new__(cls): + try: + # noinspection PyUnresolvedReferences + return cls._instance + except AttributeError: + cls._instance = AnsibleDatatagBase.__new__(cls) + + # noinspection PyUnresolvedReferences + return cls._instance + + def _as_dict(self) -> t.Dict[str, t.Any]: + return {} + + +class AnsibleTaggedObject(AnsibleSerializable): + __slots__ = _NO_INSTANCE_STORAGE + + _native_type: t.ClassVar[type] + _item_source: t.ClassVar[t.Optional[t.Callable]] = None + + _tagged_type_map: t.ClassVar[t.Dict[type, t.Type['AnsibleTaggedObject']]] = {} + _tagged_collection_types: t.ClassVar[t.Set[t.Type[c.Collection]]] = set() + _collection_types: t.ClassVar[t.Set[t.Type[c.Collection]]] = set() + + _empty_tags_as_native: t.ClassVar[bool] = True # by default, untag will revert to the native type when no tags remain + _subclasses_native_type: t.ClassVar[bool] = True # by default, tagged types are assumed to subclass the type they augment + + _ansible_tags_mapping: _AnsibleTagsMapping | _EmptyROInternalTagsMapping = _EMPTY_INTERNAL_TAGS_MAPPING + """ + Efficient internal storage of tags, indexed by tag type. + Contains no more than one instance of each tag type. + This is defined as a class attribute to support type hinting and documentation. 
+ It is overwritten with an instance attribute during instance creation. + The instance attribute slot is provided by the derived type. + """ + + def __init_subclass__(cls, **kwargs) -> None: + super().__init_subclass__(**kwargs) + + try: + init_class = cls._init_class # type: ignore[attr-defined] + except AttributeError: + pass + else: + init_class() + + if not cls._subclasses_native_type: + return # NOTE: When not subclassing a native type, the derived type must set cls._native_type itself and cls._empty_tags_as_native to False. + + try: + # Subclasses of tagged types will already have a native type set and won't need to detect it. + # Special types which do not subclass a native type can also have their native type already set. + # Automatic item source selection is only implemented for types that don't set _native_type. + cls._native_type + except AttributeError: + # Direct subclasses of native types won't have cls._native_type set, so detect the native type. + cls._native_type = cls.__bases__[0] + + # Detect the item source if not already set. + if cls._item_source is None and is_non_scalar_collection_type(cls._native_type): + cls._item_source = cls._native_type.__iter__ # type: ignore[attr-defined] + + # Use a collection specific factory for types with item sources. + if cls._item_source: + cls._instance_factory = cls._instance_factory_collection # type: ignore[method-assign] + + new_type_direct_subclass = cls.__mro__[1] + + conflicting_impl = AnsibleTaggedObject._tagged_type_map.get(new_type_direct_subclass) + + if conflicting_impl: + raise TypeError(f'Cannot define type {cls.__name__!r} since {conflicting_impl.__name__!r} already extends {new_type_direct_subclass.__name__!r}.') + + AnsibleTaggedObject._tagged_type_map[new_type_direct_subclass] = cls + + if is_non_scalar_collection_type(cls): + AnsibleTaggedObject._tagged_collection_types.add(cls) + AnsibleTaggedObject._collection_types.update({cls, new_type_direct_subclass}) + + def _native_copy(self) -> t.Any: + """ + Returns a copy of the current instance as its native Python type. + Any dynamic access behaviors that apply to this instance will be used during creation of the copy. + In the case of a container type, this is a shallow copy. + Recursive calls to native_copy are the responsibility of the caller. + """ + return self._native_type(self) # pylint: disable=abstract-class-instantiated + + @classmethod + def _instance_factory(cls, value: t.Any, tags_mapping: _AnsibleTagsMapping) -> t.Self: + # There's no way to indicate cls is callable with a single arg without defining a useless __init__. 
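# [editor's note, not part of this change] e.g. _AnsibleTaggedStr (defined below)
# subclasses str, so __init_subclass__ detected cls._native_type as str via
# cls.__bases__[0]; for that type this factory amounts to:
#
#   instance = _AnsibleTaggedStr('some value')
#   instance._ansible_tags_mapping = tags_mapping
#
# Collection types such as _AnsibleTaggedDict use _instance_factory_collection instead.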
+ instance = cls(value) # type: ignore[call-arg] + instance._ansible_tags_mapping = tags_mapping + + return instance + + @staticmethod + def _get_tagged_type(value_type: type) -> type[AnsibleTaggedObject]: + tagged_type: t.Optional[type[AnsibleTaggedObject]] + + if issubclass(value_type, AnsibleTaggedObject): + tagged_type = value_type + else: + tagged_type = AnsibleTaggedObject._tagged_type_map.get(value_type) + + if not tagged_type: + raise NotTaggableError(value_type) + + return tagged_type + + def _as_dict(self) -> t.Dict[str, t.Any]: + return dict( + value=self._native_copy(), + tags=list(self._ansible_tags_mapping.values()), + ) + + @classmethod + def _from_dict(cls: t.Type[_TAnsibleTaggedObject], d: t.Dict[str, t.Any]) -> _TAnsibleTaggedObject: + return AnsibleTagHelper.tag(**d) + + @classmethod + def _instance_factory_collection( + cls, + value: t.Any, + tags_mapping: _AnsibleTagsMapping, + ) -> t.Self: + if type(value) in AnsibleTaggedObject._collection_types: + # use the underlying iterator to avoid access/iteration side effects (e.g. templating/wrapping on Lazy subclasses) + instance = cls(cls._item_source(value)) # type: ignore[call-arg,misc] + else: + # this is used when the value is a generator + instance = cls(value) # type: ignore[call-arg] + + instance._ansible_tags_mapping = tags_mapping + + return instance + + def _copy_collection(self) -> AnsibleTaggedObject: + """ + Return a shallow copy of this instance, which must be a collection. + This uses the underlying iterator to avoid access/iteration side effects (e.g. templating/wrapping on Lazy subclasses). + """ + return AnsibleTagHelper.tag_copy(self, type(self)._item_source(self), value_type=type(self)) # type: ignore[misc] + + @classmethod + def _new(cls, value: t.Any, *args, **kwargs) -> t.Self: + if type(value) is _AnsibleTagsMapping: # pylint: disable=unidiomatic-typecheck + self = cls._native_type.__new__(cls, *args, **kwargs) + self._ansible_tags_mapping = value + return self + + return cls._native_type.__new__(cls, value, *args, **kwargs) + + def _reduce(self, reduced: t.Union[str, tuple[t.Any, ...]]) -> tuple: + if type(reduced) is not tuple: # pylint: disable=unidiomatic-typecheck + raise TypeError() + + updated: list[t.Any] = list(reduced) + updated[1] = (self._ansible_tags_mapping,) + updated[1] + + return tuple(updated) + + +class _AnsibleTaggedStr(str, AnsibleTaggedObject): + __slots__ = _ANSIBLE_TAGGED_OBJECT_SLOTS + + +class _AnsibleTaggedBytes(bytes, AnsibleTaggedObject): + # nonempty __slots__ not supported for subtype of 'bytes' + pass + + +class _AnsibleTaggedInt(int, AnsibleTaggedObject): + # nonempty __slots__ not supported for subtype of 'int' + pass + + +class _AnsibleTaggedFloat(float, AnsibleTaggedObject): + __slots__ = _ANSIBLE_TAGGED_OBJECT_SLOTS + + +class _AnsibleTaggedDateTime(datetime.datetime, AnsibleTaggedObject): + __slots__ = _ANSIBLE_TAGGED_OBJECT_SLOTS + + @classmethod + def _instance_factory(cls, value: datetime.datetime, tags_mapping: _AnsibleTagsMapping) -> _AnsibleTaggedDateTime: + instance = cls( + year=value.year, + month=value.month, + day=value.day, + hour=value.hour, + minute=value.minute, + second=value.second, + microsecond=value.microsecond, + tzinfo=value.tzinfo, + fold=value.fold, + ) + + instance._ansible_tags_mapping = tags_mapping + + return instance + + def _native_copy(self) -> datetime.datetime: + return datetime.datetime( + year=self.year, + month=self.month, + day=self.day, + hour=self.hour, + minute=self.minute, + second=self.second, + 
microsecond=self.microsecond, + tzinfo=self.tzinfo, + fold=self.fold, + ) + + def __new__(cls, year, *args, **kwargs): + return super()._new(year, *args, **kwargs) + + def __reduce_ex__(self, protocol: t.SupportsIndex) -> tuple: + return super()._reduce(super().__reduce_ex__(protocol)) + + def __repr__(self) -> str: + return self._native_copy().__repr__() + + +class _AnsibleTaggedDate(datetime.date, AnsibleTaggedObject): + __slots__ = _ANSIBLE_TAGGED_OBJECT_SLOTS + + @classmethod + def _instance_factory(cls, value: datetime.date, tags_mapping: _AnsibleTagsMapping) -> _AnsibleTaggedDate: + instance = cls( + year=value.year, + month=value.month, + day=value.day, + ) + + instance._ansible_tags_mapping = tags_mapping + + return instance + + def _native_copy(self) -> datetime.date: + return datetime.date( + year=self.year, + month=self.month, + day=self.day, + ) + + def __new__(cls, year, *args, **kwargs): + return super()._new(year, *args, **kwargs) + + def __reduce__(self) -> tuple: + return super()._reduce(super().__reduce__()) + + def __repr__(self) -> str: + return self._native_copy().__repr__() + + +class _AnsibleTaggedTime(datetime.time, AnsibleTaggedObject): + __slots__ = _ANSIBLE_TAGGED_OBJECT_SLOTS + + @classmethod + def _instance_factory(cls, value: datetime.time, tags_mapping: _AnsibleTagsMapping) -> _AnsibleTaggedTime: + instance = cls( + hour=value.hour, + minute=value.minute, + second=value.second, + microsecond=value.microsecond, + tzinfo=value.tzinfo, + fold=value.fold, + ) + + instance._ansible_tags_mapping = tags_mapping + + return instance + + def _native_copy(self) -> datetime.time: + return datetime.time( + hour=self.hour, + minute=self.minute, + second=self.second, + microsecond=self.microsecond, + tzinfo=self.tzinfo, + fold=self.fold, + ) + + def __new__(cls, hour, *args, **kwargs): + return super()._new(hour, *args, **kwargs) + + def __reduce_ex__(self, protocol: t.SupportsIndex) -> tuple: + return super()._reduce(super().__reduce_ex__(protocol)) + + def __repr__(self) -> str: + return self._native_copy().__repr__() + + +class _AnsibleTaggedDict(dict, AnsibleTaggedObject): + __slots__ = _ANSIBLE_TAGGED_OBJECT_SLOTS + + _item_source: t.ClassVar[t.Optional[t.Callable]] = dict.items + + def __copy__(self): + return super()._copy_collection() + + def copy(self) -> _AnsibleTaggedDict: + return copy.copy(self) + + # NB: Tags are intentionally not preserved for operator methods that return a new instance. In-place operators ignore tags from the `other` instance. + # Propagation of tags in these cases is left to the caller, based on needs specific to their use case. + + +class _AnsibleTaggedList(list, AnsibleTaggedObject): + __slots__ = _ANSIBLE_TAGGED_OBJECT_SLOTS + + def __copy__(self): + return super()._copy_collection() + + def copy(self) -> _AnsibleTaggedList: + return copy.copy(self) + + # NB: Tags are intentionally not preserved for operator methods that return a new instance. In-place operators ignore tags from the `other` instance. + # Propagation of tags in these cases is left to the caller, based on needs specific to their use case. + + +# DTFIX-RELEASE: do we want frozenset too? 
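# [editor's sketch, not part of this change] Tag propagation on the collection types
# above, assuming a hypothetical tag instance `some_tag`:
#
#   tagged = AnsibleTagHelper.tag([1, 2], some_tag)
#   AnsibleTagHelper.tags(tagged.copy())    # frozenset({some_tag}) - copies preserve tags
#   AnsibleTagHelper.tags(tagged + [3])     # frozenset() - operators return plain, untagged lists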
+class _AnsibleTaggedSet(set, AnsibleTaggedObject): + __slots__ = _ANSIBLE_TAGGED_OBJECT_SLOTS + + def __copy__(self): + return super()._copy_collection() + + def copy(self): + return copy.copy(self) + + def __init__(self, value=None, *args, **kwargs): + if type(value) is _AnsibleTagsMapping: # pylint: disable=unidiomatic-typecheck + super().__init__(*args, **kwargs) + else: + super().__init__(value, *args, **kwargs) + + def __new__(cls, value=None, *args, **kwargs): + return super()._new(value, *args, **kwargs) + + def __reduce_ex__(self, protocol: t.SupportsIndex) -> tuple: + return super()._reduce(super().__reduce_ex__(protocol)) + + def __str__(self) -> str: + return self._native_copy().__str__() + + def __repr__(self) -> str: + return self._native_copy().__repr__() + + +class _AnsibleTaggedTuple(tuple, AnsibleTaggedObject): + # nonempty __slots__ not supported for subtype of 'tuple' + + def __copy__(self): + return super()._copy_collection() + + +# This set gets augmented with additional types when some controller-only types are imported. +# While we could proxy or subclass builtin singletons, they're idiomatically compared with "is" reference +# equality, which we can't customize. +_untaggable_types = {type(None), bool} + +# noinspection PyProtectedMember +_ANSIBLE_ALLOWED_VAR_TYPES = frozenset({type(None), bool}) | set(AnsibleTaggedObject._tagged_type_map) | set(AnsibleTaggedObject._tagged_type_map.values()) +"""These are the only types supported by Ansible's variable storage. Subclasses are not permitted.""" + +_ANSIBLE_ALLOWED_NON_SCALAR_COLLECTION_VAR_TYPES = frozenset(item for item in _ANSIBLE_ALLOWED_VAR_TYPES if is_non_scalar_collection_type(item)) +_ANSIBLE_ALLOWED_MAPPING_VAR_TYPES = frozenset(item for item in _ANSIBLE_ALLOWED_VAR_TYPES if issubclass(item, c.Mapping)) +_ANSIBLE_ALLOWED_SCALAR_VAR_TYPES = _ANSIBLE_ALLOWED_VAR_TYPES - _ANSIBLE_ALLOWED_NON_SCALAR_COLLECTION_VAR_TYPES diff --git a/lib/ansible/module_utils/_internal/_datatag/_tags.py b/lib/ansible/module_utils/_internal/_datatag/_tags.py new file mode 100644 index 00000000000..b50e08ee9c3 --- /dev/null +++ b/lib/ansible/module_utils/_internal/_datatag/_tags.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import dataclasses +import datetime +import typing as t + +from ansible.module_utils.common import messages as _messages +from ansible.module_utils._internal import _datatag + + +@dataclasses.dataclass(**_datatag._tag_dataclass_kwargs) +class Deprecated(_datatag.AnsibleDatatagBase): + msg: str + help_text: t.Optional[str] = None + removal_date: t.Optional[datetime.date] = None + removal_version: t.Optional[str] = None + plugin: t.Optional[_messages.PluginInfo] = None + + @classmethod + def _from_dict(cls, d: t.Dict[str, t.Any]) -> Deprecated: + source = d + removal_date = source.get('removal_date') + + if removal_date is not None: + source = source.copy() + source['removal_date'] = datetime.date.fromisoformat(removal_date) + + return cls(**source) + + def _as_dict(self) -> t.Dict[str, t.Any]: + # deprecated: description='no-args super() with slotted dataclass requires 3.14+' python_version='3.13' + # see: https://github.com/python/cpython/pull/124455 + value = super(Deprecated, self)._as_dict() + + if self.removal_date is not None: + value['removal_date'] = self.removal_date.isoformat() + + return value diff --git a/lib/ansible/module_utils/_internal/_debugging.py b/lib/ansible/module_utils/_internal/_debugging.py new file mode 100644 index 00000000000..6fb390ccd62 --- /dev/null +++ 
b/lib/ansible/module_utils/_internal/_debugging.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import argparse +import pathlib +import sys + + +def load_params() -> tuple[bytes, str]: + """Load module arguments and profile when debugging an Ansible module.""" + parser = argparse.ArgumentParser(description="Directly invoke an Ansible module for debugging.") + parser.add_argument('args', nargs='?', help='module args JSON (file path or inline string)') + parser.add_argument('--profile', default='legacy', help='profile for JSON decoding/encoding of args/response') + + parsed_args = parser.parse_args() + + args: str | None = parsed_args.args + profile: str = parsed_args.profile + + if args: + if (args_path := pathlib.Path(args)).is_file(): + buffer = args_path.read_bytes() + else: + buffer = args.encode(errors='surrogateescape') + else: + if sys.stdin.isatty(): + sys.stderr.write('Waiting for Ansible module JSON on STDIN...\n') + sys.stderr.flush() + + buffer = sys.stdin.buffer.read() + + return buffer, profile diff --git a/lib/ansible/module_utils/_internal/_errors.py b/lib/ansible/module_utils/_internal/_errors.py new file mode 100644 index 00000000000..b6e6d749071 --- /dev/null +++ b/lib/ansible/module_utils/_internal/_errors.py @@ -0,0 +1,30 @@ +# Copyright (c) 2024 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +"""Internal error handling logic for targets. Not for use on the controller.""" + +from __future__ import annotations + +from . import _traceback +from ..common.messages import Detail, ErrorSummary + + +def create_error_summary(exception: BaseException) -> ErrorSummary: + """Return an `ErrorSummary` created from the given exception.""" + return ErrorSummary( + details=_create_error_details(exception), + formatted_traceback=_traceback.maybe_extract_traceback(exception, _traceback.TracebackEvent.ERROR), + ) + + +def _create_error_details(exception: BaseException) -> tuple[Detail, ...]: + """Return a `Detail` tuple created from the given exception.""" + target_exception: BaseException | None = exception + error_details: list[Detail] = [] + + while target_exception: + error_details.append(Detail(msg=str(target_exception).strip())) + + target_exception = target_exception.__cause__ + + return tuple(error_details) diff --git a/lib/ansible/module_utils/_internal/_json/__init__.py b/lib/ansible/module_utils/_internal/_json/__init__.py new file mode 100644 index 00000000000..d04c7a243e7 --- /dev/null +++ b/lib/ansible/module_utils/_internal/_json/__init__.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import importlib +import importlib.util +import types + +import typing as t + +from ansible.module_utils._internal._json._profiles import AnsibleProfileJSONEncoder, AnsibleProfileJSONDecoder, _JSONSerializationProfile +from ansible.module_utils import _internal + +_T = t.TypeVar('_T', AnsibleProfileJSONEncoder, AnsibleProfileJSONDecoder) + + +def get_encoder_decoder(profile: str | types.ModuleType, return_type: type[_T]) -> type[_T]: + class_name = 'Encoder' if return_type is AnsibleProfileJSONEncoder else 'Decoder' + + return getattr(get_serialization_module(profile), class_name) + + +def get_module_serialization_profile_name(name: str, controller_to_module: bool) -> str: + if controller_to_module: + name = f'module_{name}_c2m' + else: + name = f'module_{name}_m2c' + + return name + + +def get_module_serialization_profile_module_name(name: str, controller_to_module: bool) -> str:
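# [editor's note, not part of this change] e.g. ('legacy', controller_to_module=True)
# yields the profile name 'module_legacy_c2m', which get_serialization_module_name()
# below resolves to f'{__name__}._profiles._module_legacy_c2m' when such a module exists.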
return get_serialization_module_name(get_module_serialization_profile_name(name, controller_to_module)) + + +def get_serialization_profile(name: str | types.ModuleType) -> _JSONSerializationProfile: + return getattr(get_serialization_module(name), '_Profile') + + +def get_serialization_module(name: str | types.ModuleType) -> types.ModuleType: + return importlib.import_module(get_serialization_module_name(name)) + + +def get_serialization_module_name(name: str | types.ModuleType) -> str: + if isinstance(name, str): + if '.' in name: + return name # name is already fully qualified + + target_name = f'{__name__}._profiles._{name}' + elif isinstance(name, types.ModuleType): + return name.__name__ + else: + raise TypeError(f'Name is {type(name)} instead of {str} or {types.ModuleType}.') + + if importlib.util.find_spec(target_name): + return target_name + + # the value of is_controller can change after import; always pick it up from the module + if _internal.is_controller: + controller_name = f'ansible._internal._json._profiles._{name}' + + if importlib.util.find_spec(controller_name): + return controller_name + + raise ValueError(f'Unknown profile name {name!r}.') diff --git a/lib/ansible/module_utils/_internal/_json/_legacy_encoder.py b/lib/ansible/module_utils/_internal/_json/_legacy_encoder.py new file mode 100644 index 00000000000..2e4e940c708 --- /dev/null +++ b/lib/ansible/module_utils/_internal/_json/_legacy_encoder.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from ansible.module_utils._internal._json import _profiles +from ansible.module_utils._internal._json._profiles import _tagless + + +class LegacyTargetJSONEncoder(_tagless.Encoder): + """Compatibility wrapper over `legacy` profile JSON encoder to support trust stripping and vault value plaintext conversion.""" + + def __init__(self, preprocess_unsafe: bool = False, vault_to_text: bool = False, _decode_bytes: bool = False, **kwargs) -> None: + self._decode_bytes = _decode_bytes + + # NOTE: The preprocess_unsafe and vault_to_text arguments are features of LegacyControllerJSONEncoder. + # They are implemented here to allow callers to pass them without raising an error, but they have no effect. 
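# [editor's sketch, not part of this change] The bytes fallback in default() below mirrors
# the historical `ansible.module_utils.basic.jsonify` behavior; a stdlib-only equivalent:
#
#   import json
#
#   class BytesEncoder(json.JSONEncoder):
#       def default(self, o):
#           if isinstance(o, bytes):
#               return o.decode(errors='surrogateescape')  # undecodable bytes become surrogates
#           return super().default(o)
#
#   json.dumps({'data': b'\xffabc'}, cls=BytesEncoder)  # '{"data": "\\udcffabc"}'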
+ + super().__init__(**kwargs) + + def default(self, o: object) -> object: + if self._decode_bytes: + if type(o) is _profiles._WrappedValue: # pylint: disable=unidiomatic-typecheck + o = o.wrapped + + if isinstance(o, bytes): + return o.decode(errors='surrogateescape') # backward compatibility with `ansible.module_utils.basic.jsonify` + + return super().default(o) diff --git a/lib/ansible/module_utils/_internal/_json/_profiles/__init__.py b/lib/ansible/module_utils/_internal/_json/_profiles/__init__.py new file mode 100644 index 00000000000..332e60c4bb8 --- /dev/null +++ b/lib/ansible/module_utils/_internal/_json/_profiles/__init__.py @@ -0,0 +1,410 @@ +from __future__ import annotations + +import datetime +import functools +import json +import typing as t + +from ansible.module_utils import _internal +from ansible.module_utils.common import messages as _messages +from ansible.module_utils._internal._datatag import ( + AnsibleSerializable, + AnsibleSerializableWrapper, + AnsibleTaggedObject, + Tripwire, + _AnsibleTaggedBytes, + _AnsibleTaggedDate, + _AnsibleTaggedDateTime, + _AnsibleTaggedDict, + _AnsibleTaggedFloat, + _AnsibleTaggedInt, + _AnsibleTaggedList, + _AnsibleTaggedSet, + _AnsibleTaggedStr, + _AnsibleTaggedTime, + _AnsibleTaggedTuple, + AnsibleTagHelper, + _tags, +) + +# transformations to "final" JSON representations can only use: +# str, float, int, bool, None, dict, list +# NOT SUPPORTED: tuple, set -- the representation of these in JSON varies by profile (can raise an error, may be converted to list, etc.) +# This means that any special handling required on JSON types that are not wrapped/tagged must be done in a pre-pass before serialization. +# The final type map cannot contain any JSON types other than tuple or set. + + +_NoneType: t.Final[type] = type(None) + +_json_subclassable_scalar_types: t.Final[tuple[type, ...]] = (str, float, int) +"""Scalar types understood by JSONEncoder which can also be subclassed.""" + +_json_scalar_types: t.Final[tuple[type, ...]] = (str, float, int, bool, _NoneType) +"""Scalar types understood by JSONEncoder.""" + +_json_container_types: t.Final[tuple[type, ...]] = (dict, list, tuple) +"""Container types understood by JSONEncoder.""" + +_json_types: t.Final[tuple[type, ...]] = _json_scalar_types + _json_container_types +"""Types understood by JSONEncoder.""" + +_intercept_containers = frozenset( + { + dict, + list, + tuple, + _AnsibleTaggedDict, + _AnsibleTaggedList, + _AnsibleTaggedTuple, + } +) +"""Container types to intercept in support of scalar interception.""" + +_common_module_types: frozenset[type[AnsibleSerializable]] = frozenset( + { + _AnsibleTaggedBytes, + _AnsibleTaggedDate, + _AnsibleTaggedDateTime, + _AnsibleTaggedDict, + _AnsibleTaggedFloat, + _AnsibleTaggedInt, + _AnsibleTaggedList, + _AnsibleTaggedSet, + _AnsibleTaggedStr, + _AnsibleTaggedTime, + _AnsibleTaggedTuple, + } +) +""" +Types that must be supported for all Ansible module serialization profiles. + +For module-to-controller, all types should support full fidelity serialization. +This allows infrastructure and library code to use these features even when a module does not. + +For controller-to-module, type behavior is profile dependent. 
+
+"""
+
+_common_module_response_types: frozenset[type[AnsibleSerializable]] = frozenset(
+    {
+        _messages.PluginInfo,
+        _messages.Detail,
+        _messages.ErrorSummary,
+        _messages.WarningSummary,
+        _messages.DeprecationSummary,
+        _tags.Deprecated,
+    }
+)
+"""Types that must be supported for all Ansible module-to-controller serialization profiles."""
+
+_T_encoder = t.TypeVar('_T_encoder', bound="AnsibleProfileJSONEncoder")
+_T_decoder = t.TypeVar('_T_decoder', bound="AnsibleProfileJSONDecoder")
+
+
+class _JSONSerializationProfile(t.Generic[_T_encoder, _T_decoder]):
+    serialize_map: t.ClassVar[dict[type, t.Callable]]
+    """
+    Each concrete non-JSON type must be included in this mapping to support serialization.
+    Including a JSON type in the mapping allows for overriding or disabling of serialization of that type.
+    """
+
+    deserialize_map: t.ClassVar[dict[str, t.Callable]]
+    """A mapping of type keys to type dispatchers for deserialization."""
+
+    allowed_ansible_serializable_types: t.ClassVar[frozenset[type[AnsibleSerializable]]] = frozenset()
+    """Each concrete AnsibleSerializable derived type must be included in this set to support serialization."""
+
+    _common_discard_tags: t.ClassVar[dict[type, t.Callable]]
+    """
+    Serialize map for tagged types to have their tags discarded.
+    This is generated by __init_subclass__ and should not be manually updated.
+    """
+
+    _allowed_type_keys: t.ClassVar[frozenset[str]]
+    """
+    The set of type keys allowed during deserialization.
+    This is generated by __init_subclass__ and should not be manually updated.
+    """
+
+    _unwrapped_json_types: t.ClassVar[frozenset[type]]
+    """
+    The set of types that do not need to be wrapped during serialization.
+    This is generated by __init_subclass__ and should not be manually updated.
+    """
+
+    profile_name: t.ClassVar[str]
+    """
+    The user-facing name of the profile, derived from the module name in which the profile resides.
+    Used to load the profile dynamically at runtime.
+    This is generated by __init_subclass__ and should not be manually updated.
+    """
+
+    encode_strings_as_utf8: t.ClassVar[bool] = False
+    r"""
+    When enabled, JSON encoding will result in UTF8 strings being emitted. 
+    Otherwise, non-ASCII strings will be escaped with `\uXXXX` escape sequences.
+    """
+
+    @classmethod
+    def pre_serialize(cls, encoder: _T_encoder, o: t.Any) -> t.Any:
+        return o
+
+    @classmethod
+    def post_deserialize(cls, decoder: _T_decoder, o: t.Any) -> t.Any:
+        return o
+
+    @classmethod
+    def cannot_serialize_error(cls, target: t.Any, /) -> t.NoReturn:
+        raise TypeError(f'Object of type {type(target).__name__!r} is not JSON serializable by the {cls.profile_name!r} profile.')
+
+    @classmethod
+    def cannot_deserialize_error(cls, target_type_name: str, /) -> t.NoReturn:
+        raise TypeError(f'Object of type {target_type_name!r} is not JSON deserializable by the {cls.profile_name!r} profile.')
+
+    @classmethod
+    def unsupported_target_type_error(cls, target_type_name: str, _value: dict) -> t.NoReturn:
+        cls.cannot_deserialize_error(target_type_name)
+
+    @classmethod
+    def discard_tags(cls, value: AnsibleTaggedObject) -> object:
+        return value._native_copy()
+
+    @classmethod
+    def deserialize_serializable(cls, value: dict[str, t.Any]) -> object:
+        type_key = value[AnsibleSerializable._TYPE_KEY]
+
+        if type_key not in cls._allowed_type_keys:
+            cls.cannot_deserialize_error(type_key)
+
+        return AnsibleSerializable._deserialize(value)
+
+    @classmethod
+    def serialize_as_list(cls, value: t.Iterable) -> list:
+        # DTFIX-FUTURE: once we have separate control/data channels for module-to-controller (and back), warn about this conversion
+        return AnsibleTagHelper.tag_copy(value, (item for item in value), value_type=list)
+
+    @classmethod
+    def serialize_as_isoformat(cls, value: datetime.date | datetime.time | datetime.datetime) -> str:
+        return value.isoformat()
+
+    @classmethod
+    def serialize_serializable_object(cls, value: AnsibleSerializable) -> t.Any:
+        return value._serialize()
+
+    @classmethod
+    def post_init(cls) -> None:
+        pass
+
+    @classmethod
+    def maybe_wrap(cls, o: t.Any) -> t.Any:
+        if type(o) in cls._unwrapped_json_types:
+            return o
+
+        return _WrappedValue(o)
+
+    @classmethod
+    def handle_key(cls, k: t.Any) -> t.Any:
+        if not isinstance(k, str):  # DTFIX-FUTURE: optimize this to use all known str-derived types in type map / allowed types
+            raise TypeError(f'Key of type {type(k).__name__!r} is not JSON serializable by the {cls.profile_name!r} profile.')
+
+        return k
+
+    @classmethod
+    def default(cls, o: t.Any) -> t.Any:
+        # Preserve the built-in JSON encoder support for subclasses of scalar types.
+
+        if isinstance(o, _json_subclassable_scalar_types):
+            return o
+
+        # Preserve the built-in JSON encoder support for subclasses of dict and list.
+        # Additionally, add universal support for mappings and sequences/sets by converting them to dict and list, respectively. 
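+        # For example, a hostvars-style Mapping that is not a dict subclass is rebuilt below as a plain dict,
+        # with each key validated by handle_key and each value wrapped so nested values can be intercepted later.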
+ + if _internal.is_intermediate_mapping(o): + return {cls.handle_key(k): cls.maybe_wrap(v) for k, v in o.items()} + + if _internal.is_intermediate_iterable(o): + return [cls.maybe_wrap(v) for v in o] + + return cls.last_chance(o) + + @classmethod + def last_chance(cls, o: t.Any) -> t.Any: + if isinstance(o, Tripwire): + o.trip() + + cls.cannot_serialize_error(o) + + def __init_subclass__(cls, **kwargs) -> None: + cls.deserialize_map = {} + cls._common_discard_tags = {obj: cls.discard_tags for obj in _common_module_types if issubclass(obj, AnsibleTaggedObject)} + + cls.post_init() + + cls.profile_name = cls.__module__.rsplit('.', maxsplit=1)[-1].lstrip('_') + + wrapper_types = set(obj for obj in cls.serialize_map.values() if isinstance(obj, type) and issubclass(obj, AnsibleSerializableWrapper)) + + cls.allowed_ansible_serializable_types |= wrapper_types + + # no current need to preserve tags on controller-only types or custom behavior for anything in `allowed_serializable_types` + cls.serialize_map.update({obj: cls.serialize_serializable_object for obj in cls.allowed_ansible_serializable_types}) + cls.serialize_map.update({obj: func for obj, func in _internal.get_controller_serialize_map().items() if obj not in cls.serialize_map}) + + cls.deserialize_map[AnsibleSerializable._TYPE_KEY] = cls.deserialize_serializable # always recognize tagged types + + cls._allowed_type_keys = frozenset(obj._type_key for obj in cls.allowed_ansible_serializable_types) + + cls._unwrapped_json_types = frozenset( + {obj for obj in cls.serialize_map if not issubclass(obj, _json_types)} # custom types that do not extend JSON-native types + | {obj for obj in _json_scalar_types if obj not in cls.serialize_map} # JSON-native scalars lacking custom handling + ) + + +class _WrappedValue: + __slots__ = ('wrapped',) + + def __init__(self, wrapped: t.Any) -> None: + self.wrapped = wrapped + + +class AnsibleProfileJSONEncoder(json.JSONEncoder): + """Profile based JSON encoder capable of handling Ansible internal types.""" + + _wrap_container_types = (list, set, tuple, dict) + _profile: type[_JSONSerializationProfile] + + profile_name: str + + def __init__(self, **kwargs): + self._wrap_types = self._wrap_container_types + (AnsibleSerializable,) + + if self._profile.encode_strings_as_utf8: + kwargs.update(ensure_ascii=False) + + super().__init__(**kwargs) + + def __init_subclass__(cls, **kwargs) -> None: + cls.profile_name = cls._profile.profile_name + + def encode(self, o): + o = self._profile.maybe_wrap(self._profile.pre_serialize(self, o)) + + return super().encode(o) + + def default(self, o: t.Any) -> t.Any: + o_type = type(o) + + if o_type is _WrappedValue: # pylint: disable=unidiomatic-typecheck + o = o.wrapped + o_type = type(o) + + if mapped_callable := self._profile.serialize_map.get(o_type): + return self._profile.maybe_wrap(mapped_callable(o)) + + # This is our last chance to intercept the values in containers, so they must be wrapped here. + # Only containers natively understood by the built-in JSONEncoder are recognized, since any other container types must be present in serialize_map. + + if o_type is dict: # pylint: disable=unidiomatic-typecheck + return {self._profile.handle_key(k): self._profile.maybe_wrap(v) for k, v in o.items()} + + if o_type is list or o_type is tuple: # pylint: disable=unidiomatic-typecheck + return [self._profile.maybe_wrap(v) for v in o] # JSONEncoder converts tuple to a list, so just make it a list now + + # Any value here is a type not explicitly handled by this encoder. 
+ # The profile default handler is responsible for generating an error or converting the value to a supported type. + + return self._profile.default(o) + + +class AnsibleProfileJSONDecoder(json.JSONDecoder): + """Profile based JSON decoder capable of handling Ansible internal types.""" + + _profile: type[_JSONSerializationProfile] + + profile_name: str + + def __init__(self, **kwargs): + kwargs.update(object_hook=self.object_hook) + + super().__init__(**kwargs) + + def __init_subclass__(cls, **kwargs) -> None: + cls.profile_name = cls._profile.profile_name + + def raw_decode(self, s: str, idx: int = 0) -> tuple[t.Any, int]: + obj, end = super().raw_decode(s, idx) + + if _string_encoding_check_enabled(): + try: + _recursively_check_string_encoding(obj) + except UnicodeEncodeError as ex: + raise _create_encoding_check_error() from ex + + obj = self._profile.post_deserialize(self, obj) + + return obj, end + + def object_hook(self, pairs: dict[str, object]) -> object: + if _string_encoding_check_enabled(): + try: + for key, value in pairs.items(): + key.encode() + _recursively_check_string_encoding(value) + except UnicodeEncodeError as ex: + raise _create_encoding_check_error() from ex + + for mapped_key, mapped_callable in self._profile.deserialize_map.items(): + if mapped_key in pairs: + return mapped_callable(pairs) + + return pairs + + +_check_encoding_setting = 'MODULE_STRICT_UTF8_RESPONSE' +r""" +The setting to control whether strings are checked to verify they can be encoded as valid UTF8. +This is currently only used during deserialization, to prevent string values from entering the controller which will later fail to be encoded as bytes. + +The encoding failure can occur when the string represents one of two kinds of values: +1) It was created through decoding bytes with the `surrogateescape` error handler, and that handler is not being used when encoding. +2) It represents an invalid UTF8 value, such as `"\ud8f3"` in a JSON payload. This cannot be encoded, even using the `surrogateescape` error handler. + +Although this becomes an error during deserialization, there are other opportunities for these values to become strings within Ansible. +Future code changes should further restrict bytes to string conversions to eliminate use of `surrogateescape` where appropriate. +Additional warnings at other boundaries may be needed to give users an opportunity to resolve the issues before they become errors. +""" +# DTFIX-FUTURE: add strict UTF8 string encoding checking to serialization profiles (to match the checks performed during deserialization) +# DTFIX-RELEASE: the surrogateescape note above isn't quite right, for encoding use surrogatepass, which does work +# DTFIX-RELEASE: this config setting should probably be deprecated + + +def _create_encoding_check_error() -> Exception: + """ + Return an AnsibleError for use when a UTF8 string encoding check has failed. + These checks are only performed in the controller context, but since this is module_utils code, dynamic loading of the `errors` module is required. 
+ """ + errors = _internal.import_controller_module('ansible.errors') # bypass AnsiballZ import scanning + + return errors.AnsibleRuntimeError( + message='Refusing to deserialize an invalid UTF8 string value.', + help_text=f'This check can be disabled with the `{_check_encoding_setting}` setting.', + ) + + +@functools.lru_cache +def _string_encoding_check_enabled() -> bool: + """Return True if JSON deserialization should verify strings can be encoded as valid UTF8.""" + if constants := _internal.import_controller_module('ansible.constants'): # bypass AnsiballZ import scanning + return constants.config.get_config_value(_check_encoding_setting) # covers all profile-based deserializers, not just modules + + return False + + +def _recursively_check_string_encoding(value: t.Any) -> None: + """Recursively check the given object to ensure all strings can be encoded as valid UTF8.""" + value_type = type(value) + + if value_type is str: + value.encode() + elif value_type is list: # dict is handled by the JSON deserializer + for item in value: + _recursively_check_string_encoding(item) diff --git a/lib/ansible/module_utils/_internal/_json/_profiles/_fallback_to_str.py b/lib/ansible/module_utils/_internal/_json/_profiles/_fallback_to_str.py new file mode 100644 index 00000000000..92b80ca0d31 --- /dev/null +++ b/lib/ansible/module_utils/_internal/_json/_profiles/_fallback_to_str.py @@ -0,0 +1,73 @@ +""" +Lossy best-effort serialization for Ansible variables; used primarily for callback JSON display. +Any type which is not supported by JSON will be converted to a string. +The string representation of any type that is not native to JSON is subject to change and should not be considered stable. +The decoder provides no special behavior. +""" + +from __future__ import annotations as _annotations + +import datetime as _datetime +import typing as _t + +from json import dumps as _dumps + +from ... import _datatag +from .. 
import _profiles + + +class _Profile(_profiles._JSONSerializationProfile["Encoder", "Decoder"]): + serialize_map: _t.ClassVar[dict[type, _t.Callable]] + + @classmethod + def post_init(cls) -> None: + cls.serialize_map = { + bytes: cls.serialize_bytes_as_str, + set: cls.serialize_as_list, + tuple: cls.serialize_as_list, + _datetime.date: cls.serialize_as_isoformat, + _datetime.time: cls.serialize_as_isoformat, + _datetime.datetime: cls.serialize_as_isoformat, + _datatag._AnsibleTaggedDate: cls.discard_tags, + _datatag._AnsibleTaggedTime: cls.discard_tags, + _datatag._AnsibleTaggedDateTime: cls.discard_tags, + _datatag._AnsibleTaggedStr: cls.discard_tags, + _datatag._AnsibleTaggedInt: cls.discard_tags, + _datatag._AnsibleTaggedFloat: cls.discard_tags, + _datatag._AnsibleTaggedSet: cls.discard_tags, + _datatag._AnsibleTaggedList: cls.discard_tags, + _datatag._AnsibleTaggedTuple: cls.discard_tags, + _datatag._AnsibleTaggedDict: cls.discard_tags, + _datatag._AnsibleTaggedBytes: cls.discard_tags, + } + + @classmethod + def serialize_bytes_as_str(cls, value: bytes) -> str: + return value.decode(errors='surrogateescape') + + @classmethod + def handle_key(cls, k: _t.Any) -> _t.Any: + while mapped_callable := cls.serialize_map.get(type(k)): + k = mapped_callable(k) + + k = cls.default(k) + + if not isinstance(k, str): + k = _dumps(k, cls=Encoder) + + return k + + @classmethod + def last_chance(cls, o: _t.Any) -> _t.Any: + try: + return str(o) + except Exception as ex: + return str(ex) + + +class Encoder(_profiles.AnsibleProfileJSONEncoder): + _profile = _Profile + + +class Decoder(_profiles.AnsibleProfileJSONDecoder): + _profile = _Profile diff --git a/lib/ansible/module_utils/_internal/_json/_profiles/_module_legacy_c2m.py b/lib/ansible/module_utils/_internal/_json/_profiles/_module_legacy_c2m.py new file mode 100644 index 00000000000..a1ec7699037 --- /dev/null +++ b/lib/ansible/module_utils/_internal/_json/_profiles/_module_legacy_c2m.py @@ -0,0 +1,31 @@ +"""Legacy wire format for controller to module communication.""" + +from __future__ import annotations as _annotations + +import datetime as _datetime + +from .. import _profiles + + +class _Profile(_profiles._JSONSerializationProfile["Encoder", "Decoder"]): + @classmethod + def post_init(cls) -> None: + cls.serialize_map = {} + cls.serialize_map.update(cls._common_discard_tags) + cls.serialize_map.update( + { + set: cls.serialize_as_list, # legacy _json_encode_fallback behavior + tuple: cls.serialize_as_list, # JSONEncoder built-in behavior + _datetime.date: cls.serialize_as_isoformat, + _datetime.time: cls.serialize_as_isoformat, # always failed pre-2.18, so okay to include for consistency + _datetime.datetime: cls.serialize_as_isoformat, + } + ) + + +class Encoder(_profiles.AnsibleProfileJSONEncoder): + _profile = _Profile + + +class Decoder(_profiles.AnsibleProfileJSONDecoder): + _profile = _Profile diff --git a/lib/ansible/module_utils/_internal/_json/_profiles/_module_legacy_m2c.py b/lib/ansible/module_utils/_internal/_json/_profiles/_module_legacy_m2c.py new file mode 100644 index 00000000000..78ae0b54992 --- /dev/null +++ b/lib/ansible/module_utils/_internal/_json/_profiles/_module_legacy_m2c.py @@ -0,0 +1,35 @@ +"""Legacy wire format for module to controller communication.""" + +from __future__ import annotations as _annotations + +import datetime as _datetime + +from .. 
import _profiles +from ansible.module_utils.common.text.converters import to_text as _to_text + + +class _Profile(_profiles._JSONSerializationProfile["Encoder", "Decoder"]): + @classmethod + def bytes_to_text(cls, value: bytes) -> str: + return _to_text(value, errors='surrogateescape') + + @classmethod + def post_init(cls) -> None: + cls.allowed_ansible_serializable_types = _profiles._common_module_types | _profiles._common_module_response_types + + cls.serialize_map = { + bytes: cls.bytes_to_text, # legacy behavior from jsonify and container_to_text + set: cls.serialize_as_list, # legacy _json_encode_fallback behavior + tuple: cls.serialize_as_list, # JSONEncoder built-in behavior + _datetime.date: cls.serialize_as_isoformat, # legacy parameters.py does this before serialization + _datetime.time: cls.serialize_as_isoformat, # always failed pre-2.18, so okay to include for consistency + _datetime.datetime: cls.serialize_as_isoformat, # legacy _json_encode_fallback behavior *and* legacy parameters.py does this before serialization + } + + +class Encoder(_profiles.AnsibleProfileJSONEncoder): + _profile = _Profile + + +class Decoder(_profiles.AnsibleProfileJSONDecoder): + _profile = _Profile diff --git a/lib/ansible/module_utils/_internal/_json/_profiles/_module_modern_c2m.py b/lib/ansible/module_utils/_internal/_json/_profiles/_module_modern_c2m.py new file mode 100644 index 00000000000..a1806b37c0b --- /dev/null +++ b/lib/ansible/module_utils/_internal/_json/_profiles/_module_modern_c2m.py @@ -0,0 +1,35 @@ +"""Data tagging aware wire format for controller to module communication.""" + +from __future__ import annotations as _annotations + +import datetime as _datetime + +from ... import _datatag +from .. import _profiles + + +class _Profile(_profiles._JSONSerializationProfile["Encoder", "Decoder"]): + encode_strings_as_utf8 = True + + @classmethod + def post_init(cls) -> None: + cls.serialize_map = {} + cls.serialize_map.update(cls._common_discard_tags) + cls.serialize_map.update( + { + # The bytes type is not supported, use str instead (future module profiles may support a bytes wrapper distinct from `bytes`). + set: cls.serialize_as_list, # legacy _json_encode_fallback behavior + tuple: cls.serialize_as_list, # JSONEncoder built-in behavior + _datetime.date: _datatag.AnsibleSerializableDate, + _datetime.time: _datatag.AnsibleSerializableTime, + _datetime.datetime: _datatag.AnsibleSerializableDateTime, + } + ) + + +class Encoder(_profiles.AnsibleProfileJSONEncoder): + _profile = _Profile + + +class Decoder(_profiles.AnsibleProfileJSONDecoder): + _profile = _Profile diff --git a/lib/ansible/module_utils/_internal/_json/_profiles/_module_modern_m2c.py b/lib/ansible/module_utils/_internal/_json/_profiles/_module_modern_m2c.py new file mode 100644 index 00000000000..a32d2c122b9 --- /dev/null +++ b/lib/ansible/module_utils/_internal/_json/_profiles/_module_modern_m2c.py @@ -0,0 +1,33 @@ +"""Data tagging aware wire format for module to controller communication.""" + +from __future__ import annotations as _annotations + +import datetime as _datetime + +from ... import _datatag +from .. 
import _profiles + + +class _Profile(_profiles._JSONSerializationProfile["Encoder", "Decoder"]): + encode_strings_as_utf8 = True + + @classmethod + def post_init(cls) -> None: + cls.allowed_ansible_serializable_types = _profiles._common_module_types | _profiles._common_module_response_types + + cls.serialize_map = { + # The bytes type is not supported, use str instead (future module profiles may support a bytes wrapper distinct from `bytes`). + set: cls.serialize_as_list, # legacy _json_encode_fallback behavior + tuple: cls.serialize_as_list, # JSONEncoder built-in behavior + _datetime.date: _datatag.AnsibleSerializableDate, + _datetime.time: _datatag.AnsibleSerializableTime, + _datetime.datetime: _datatag.AnsibleSerializableDateTime, + } + + +class Encoder(_profiles.AnsibleProfileJSONEncoder): + _profile = _Profile + + +class Decoder(_profiles.AnsibleProfileJSONDecoder): + _profile = _Profile diff --git a/lib/ansible/module_utils/_internal/_json/_profiles/_tagless.py b/lib/ansible/module_utils/_internal/_json/_profiles/_tagless.py new file mode 100644 index 00000000000..504049d78e8 --- /dev/null +++ b/lib/ansible/module_utils/_internal/_json/_profiles/_tagless.py @@ -0,0 +1,50 @@ +""" +Lossy best-effort serialization for Ansible variables. +Default profile for the `to_json` filter. +Deserialization behavior is identical to JSONDecoder, except known Ansible custom serialization markers will raise an error. +""" + +from __future__ import annotations as _annotations + +import datetime as _datetime +import functools as _functools + +from ... import _datatag +from .. import _profiles + + +class _Profile(_profiles._JSONSerializationProfile["Encoder", "Decoder"]): + @classmethod + def post_init(cls) -> None: + cls.serialize_map = { + # DTFIX-RELEASE: support serialization of every type that is supported in the Ansible variable type system + set: cls.serialize_as_list, + tuple: cls.serialize_as_list, + _datetime.date: cls.serialize_as_isoformat, + _datetime.time: cls.serialize_as_isoformat, + _datetime.datetime: cls.serialize_as_isoformat, + # bytes intentionally omitted as they are not a supported variable type, they were not originally supported by the old AnsibleJSONEncoder + _datatag._AnsibleTaggedDate: cls.discard_tags, + _datatag._AnsibleTaggedTime: cls.discard_tags, + _datatag._AnsibleTaggedDateTime: cls.discard_tags, + _datatag._AnsibleTaggedStr: cls.discard_tags, + _datatag._AnsibleTaggedInt: cls.discard_tags, + _datatag._AnsibleTaggedFloat: cls.discard_tags, + _datatag._AnsibleTaggedSet: cls.discard_tags, + _datatag._AnsibleTaggedList: cls.discard_tags, + _datatag._AnsibleTaggedTuple: cls.discard_tags, + _datatag._AnsibleTaggedDict: cls.discard_tags, + } + + cls.deserialize_map = { + '__ansible_unsafe': _functools.partial(cls.unsupported_target_type_error, '__ansible_unsafe'), + '__ansible_vault': _functools.partial(cls.unsupported_target_type_error, '__ansible_vault'), + } + + +class Encoder(_profiles.AnsibleProfileJSONEncoder): + _profile = _Profile + + +class Decoder(_profiles.AnsibleProfileJSONDecoder): + _profile = _Profile diff --git a/lib/ansible/module_utils/_internal/_patches/__init__.py b/lib/ansible/module_utils/_internal/_patches/__init__.py new file mode 100644 index 00000000000..7e08b04bff3 --- /dev/null +++ b/lib/ansible/module_utils/_internal/_patches/__init__.py @@ -0,0 +1,66 @@ +"""Infrastructure for patching callables with alternative implementations as needed based on patch-specific test criteria.""" + +from __future__ import annotations + +import abc +import typing 
as t + + +@t.runtime_checkable +class PatchedTarget(t.Protocol): + """Runtime-checkable protocol that allows identification of a patched function via `isinstance`.""" + + unpatched_implementation: t.Callable + + +class CallablePatch(abc.ABC): + """Base class for patches that provides abstractions for validation of broken behavior, installation of patches, and validation of fixed behavior.""" + + target_container: t.ClassVar + """The module object containing the function to be patched.""" + + target_attribute: t.ClassVar[str] + """The attribute name on the target module to patch.""" + + unpatched_implementation: t.ClassVar[t.Callable] + """The unpatched implementation. Available only after the patch has been applied.""" + + @classmethod + @abc.abstractmethod + def is_patch_needed(cls) -> bool: + """Returns True if the patch is currently needed. Returns False if the original target does not need the patch or the patch has already been applied.""" + + @abc.abstractmethod + def __call__(self, *args, **kwargs) -> t.Any: + """Invoke the patched or original implementation, depending on whether the patch has been applied or not.""" + + @classmethod + def is_patched(cls) -> bool: + """Returns True if the patch has been applied, otherwise returns False.""" + return isinstance(cls.get_current_implementation(), PatchedTarget) # using a protocol lets us be more resilient to module unload weirdness + + @classmethod + def get_current_implementation(cls) -> t.Any: + """Get the current (possibly patched) implementation from the patch target container.""" + return getattr(cls.target_container, cls.target_attribute) + + @classmethod + def patch(cls) -> None: + """Idempotently apply this patch (if needed).""" + if cls.is_patched(): + return + + cls.unpatched_implementation = cls.get_current_implementation() + + if not cls.is_patch_needed(): + return + + # __call__ requires an instance (otherwise it'll be __new__) + setattr(cls.target_container, cls.target_attribute, cls()) + + if not cls.is_patch_needed(): + return + + setattr(cls.target_container, cls.target_attribute, cls.unpatched_implementation) + + raise RuntimeError(f"Validation of '{cls.target_container.__name__}.{cls.target_attribute}' failed after patching.") diff --git a/lib/ansible/module_utils/_internal/_patches/_dataclass_annotation_patch.py b/lib/ansible/module_utils/_internal/_patches/_dataclass_annotation_patch.py new file mode 100644 index 00000000000..dbb78f7fd75 --- /dev/null +++ b/lib/ansible/module_utils/_internal/_patches/_dataclass_annotation_patch.py @@ -0,0 +1,55 @@ +"""Patches for builtin `dataclasses` module.""" + +# deprecated: description='verify ClassVar support in dataclasses has been fixed in Python before removing this patching code', python_version='3.13' + +from __future__ import annotations + +import dataclasses +import sys +import typing as t + +from . 
import CallablePatch + +# trigger the bug by exposing typing.ClassVar via a module reference that is not `typing` +_ts = sys.modules[__name__] +ClassVar = t.ClassVar + + +class DataclassesIsTypePatch(CallablePatch): + """Patch broken ClassVar support in dataclasses when ClassVar is accessed via a module other than `typing`.""" + + target_container: t.ClassVar = dataclasses + target_attribute = '_is_type' + + @classmethod + def is_patch_needed(cls) -> bool: + @dataclasses.dataclass + class CheckClassVar: + # this is the broken case requiring patching: ClassVar dot-referenced from a module that is not `typing` is treated as an instance field + # DTFIX-RELEASE: add link to CPython bug report to-be-filed (or update associated deprecation comments if we don't) + a_classvar: _ts.ClassVar[int] # type: ignore[name-defined] + a_field: int + + return len(dataclasses.fields(CheckClassVar)) != 1 + + def __call__(self, annotation, cls, a_module, a_type, is_type_predicate) -> bool: + """ + This is a patched copy of `_is_type` from dataclasses.py in Python 3.13. + It eliminates the redundant source module reference equality check for the ClassVar type that triggers the bug. + """ + match = dataclasses._MODULE_IDENTIFIER_RE.match(annotation) # type: ignore[attr-defined] + if match: + ns = None + module_name = match.group(1) + if not module_name: + # No module name, assume the class's module did + # "from dataclasses import InitVar". + ns = sys.modules.get(cls.__module__).__dict__ + else: + # Look up module_name in the class's module. + module = sys.modules.get(cls.__module__) + if module and module.__dict__.get(module_name): # this is the patched line; removed `is a_module` + ns = sys.modules.get(a_type.__module__).__dict__ + if ns and is_type_predicate(ns.get(match.group(2)), a_module): + return True + return False diff --git a/lib/ansible/module_utils/_internal/_patches/_socket_patch.py b/lib/ansible/module_utils/_internal/_patches/_socket_patch.py new file mode 100644 index 00000000000..fd8c2b16f6d --- /dev/null +++ b/lib/ansible/module_utils/_internal/_patches/_socket_patch.py @@ -0,0 +1,34 @@ +"""Patches for builtin socket module.""" + +from __future__ import annotations + +import contextlib +import socket +import typing as t + +from . import CallablePatch + + +class _CustomInt(int): + """Wrapper around `int` to test if subclasses are accepted.""" + + +class GetAddrInfoPatch(CallablePatch): + """Patch `socket.getaddrinfo` so that its `port` arg works with `int` subclasses.""" + + target_container: t.ClassVar = socket + target_attribute = 'getaddrinfo' + + @classmethod + def is_patch_needed(cls) -> bool: + with contextlib.suppress(OSError): + socket.getaddrinfo('127.0.0.1', _CustomInt(22)) + return False + + return True + + def __call__(self, host, port, *args, **kwargs) -> t.Any: + if type(port) is not int and isinstance(port, int): # pylint: disable=unidiomatic-typecheck + port = int(port) + + return type(self).unpatched_implementation(host, port, *args, **kwargs) diff --git a/lib/ansible/module_utils/_internal/_patches/_sys_intern_patch.py b/lib/ansible/module_utils/_internal/_patches/_sys_intern_patch.py new file mode 100644 index 00000000000..1e785d608e2 --- /dev/null +++ b/lib/ansible/module_utils/_internal/_patches/_sys_intern_patch.py @@ -0,0 +1,34 @@ +"""Patches for the builtin `sys` module.""" + +from __future__ import annotations + +import contextlib +import sys +import typing as t + +from . 
import CallablePatch + + +class _CustomStr(str): + """Wrapper around `str` to test if subclasses are accepted.""" + + +class SysInternPatch(CallablePatch): + """Patch `sys.intern` so that subclasses of `str` are accepted.""" + + target_container: t.ClassVar = sys + target_attribute = 'intern' + + @classmethod + def is_patch_needed(cls) -> bool: + with contextlib.suppress(TypeError): + sys.intern(_CustomStr("x")) + return False + + return True + + def __call__(self, value: str): + if type(value) is not str and isinstance(value, str): # pylint: disable=unidiomatic-typecheck + value = str(value) + + return type(self).unpatched_implementation(value) diff --git a/lib/ansible/module_utils/_internal/_plugin_exec_context.py b/lib/ansible/module_utils/_internal/_plugin_exec_context.py new file mode 100644 index 00000000000..332badc29c9 --- /dev/null +++ b/lib/ansible/module_utils/_internal/_plugin_exec_context.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import typing as t + +from ._ambient_context import AmbientContextBase +from ..common.messages import PluginInfo + + +class HasPluginInfo(t.Protocol): + """Protocol to type-annotate and expose PluginLoader-set values.""" + + @property + def _load_name(self) -> str: + """The requested name used to load the plugin.""" + + @property + def ansible_name(self) -> str: + """Fully resolved plugin name.""" + + @property + def plugin_type(self) -> str: + """Plugin type name.""" + + +class PluginExecContext(AmbientContextBase): + """Execution context that wraps all plugin invocations to allow infrastructure introspection of the currently-executing plugin instance.""" + + def __init__(self, executing_plugin: HasPluginInfo) -> None: + self._executing_plugin = executing_plugin + + @property + def executing_plugin(self) -> HasPluginInfo: + return self._executing_plugin + + @property + def plugin_info(self) -> PluginInfo: + return PluginInfo( + requested_name=self._executing_plugin._load_name, + resolved_name=self._executing_plugin.ansible_name, + type=self._executing_plugin.plugin_type, + ) + + @classmethod + def get_current_plugin_info(cls) -> PluginInfo | None: + """Utility method to extract a PluginInfo for the currently executing plugin (or None if no plugin is executing).""" + if ctx := cls.current(optional=True): + return ctx.plugin_info + + return None diff --git a/lib/ansible/module_utils/_internal/_testing.py b/lib/ansible/module_utils/_internal/_testing.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/lib/ansible/module_utils/_internal/_traceback.py b/lib/ansible/module_utils/_internal/_traceback.py new file mode 100644 index 00000000000..1e405eff1f8 --- /dev/null +++ b/lib/ansible/module_utils/_internal/_traceback.py @@ -0,0 +1,89 @@ +# Copyright (c) 2024 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +"""Internal utility code for supporting traceback reporting.""" + +from __future__ import annotations + +import enum +import inspect +import traceback + + +class TracebackEvent(enum.Enum): + """The events for which tracebacks can be enabled.""" + + ERROR = enum.auto() + WARNING = enum.auto() + DEPRECATED = enum.auto() + + +def traceback_for() -> list[str]: + """Return a list of traceback event names (not enums) which are enabled.""" + return [value.name.lower() for value in TracebackEvent if is_traceback_enabled(value)] + + +def is_traceback_enabled(event: TracebackEvent) -> bool: + """Return True if tracebacks are enabled for the specified 
event, otherwise return False.""" + return _is_traceback_enabled(event) + + +def maybe_capture_traceback(event: TracebackEvent) -> str | None: + """ + Optionally capture a traceback for the current call stack, formatted as a string, if the specified traceback event is enabled. + The current and previous frames are omitted to mask the expected call pattern from error/warning handlers. + """ + if not is_traceback_enabled(event): + return None + + tb_lines = [] + + if current_frame := inspect.currentframe(): + # DTFIX-FUTURE: rewrite target-side tracebacks to point at controller-side paths? + frames = inspect.getouterframes(current_frame) + ignore_frame_count = 2 # ignore this function and its caller + tb_lines.append('Traceback (most recent call last):\n') + tb_lines.extend(traceback.format_stack(frames[ignore_frame_count].frame)) + else: + tb_lines.append('Traceback unavailable.\n') + + return ''.join(tb_lines) + + +def maybe_extract_traceback(exception: BaseException, event: TracebackEvent) -> str | None: + """Optionally extract a formatted traceback from the given exception, if the specified traceback event is enabled.""" + + if not is_traceback_enabled(event): + return None + + # deprecated: description='use the single-arg version of format_traceback' python_version='3.9' + tb_lines = traceback.format_exception(type(exception), exception, exception.__traceback__) + + return ''.join(tb_lines) + + +_module_tracebacks_enabled_events: frozenset[TracebackEvent] | None = None +"""Cached enabled TracebackEvent values extracted from `_ansible_tracebacks_for` module arg.""" + + +def _is_module_traceback_enabled(event: TracebackEvent) -> bool: + """Module utility function to lazily load traceback config and determine if traceback collection is enabled for the specified event.""" + global _module_tracebacks_enabled_events + + if _module_tracebacks_enabled_events is None: + try: + # Suboptimal error handling, but since import order can matter, and this is a critical error path, better to fail silently + # than to mask the triggering error by issuing a new error/warning here. + from ..basic import _PARSED_MODULE_ARGS + + _module_tracebacks_enabled_events = frozenset( + TracebackEvent[value.upper()] for value in _PARSED_MODULE_ARGS.get('_ansible_tracebacks_for') + ) # type: ignore[union-attr] + except BaseException: + return True # if things failed early enough that we can't figure this out, assume we want a traceback for troubleshooting + + return event in _module_tracebacks_enabled_events + + +_is_traceback_enabled = _is_module_traceback_enabled +"""Callable to determine if tracebacks are enabled. Overridden on the controller by display. 
Use `is_traceback_enabled` instead of calling this directly.""" diff --git a/lib/ansible/module_utils/api.py b/lib/ansible/module_utils/api.py index 2415c38a839..f8023824ee3 100644 --- a/lib/ansible/module_utils/api.py +++ b/lib/ansible/module_utils/api.py @@ -31,8 +31,7 @@ import itertools import secrets import sys import time - -import ansible.module_utils.compat.typing as t +import typing as t def rate_limit_argument_spec(spec=None): diff --git a/lib/ansible/module_utils/basic.py b/lib/ansible/module_utils/basic.py index fbc5ea17630..731f8ded7d1 100644 --- a/lib/ansible/module_utils/basic.py +++ b/lib/ansible/module_utils/basic.py @@ -4,6 +4,7 @@ from __future__ import annotations +import copy import json import sys import typing as t @@ -25,6 +26,7 @@ if sys.version_info < _PY_MIN: import __main__ import atexit +import dataclasses as _dataclasses import errno import grp import fcntl @@ -51,6 +53,10 @@ try: except ImportError: HAS_SYSLOG = False +# deprecated: description='types.EllipsisType is available in Python 3.10+' python_version='3.9' +if t.TYPE_CHECKING: + from builtins import ellipsis + try: from systemd import journal, daemon as systemd_daemon # Makes sure that systemd.journal has method sendv() @@ -71,8 +77,12 @@ except ImportError: # Python2 & 3 way to get NoneType NoneType = type(None) -from ._text import to_native, to_bytes, to_text -from ansible.module_utils.common.text.converters import ( +from ._internal import _traceback, _errors, _debugging + +from .common.text.converters import ( + to_native, + to_bytes, + to_text, jsonify, container_to_bytes as json_dict_unicode_to_bytes, container_to_text as json_dict_bytes_to_unicode, @@ -87,6 +97,8 @@ from ansible.module_utils.common.text.formatters import ( SIZE_RANGES, ) +from ansible.module_utils.common import json as _common_json + import hashlib @@ -111,6 +123,8 @@ def _get_available_hash_algorithms(): AVAILABLE_HASH_ALGORITHMS = _get_available_hash_algorithms() +from ansible.module_utils.common import json as _json + from ansible.module_utils.six.moves.collections_abc import ( KeysView, Mapping, MutableMapping, @@ -149,11 +163,12 @@ from ansible.module_utils.common.validation import ( safe_eval, ) from ansible.module_utils.common._utils import get_all_subclasses as _get_all_subclasses +from ansible.module_utils.common import messages as _messages from ansible.module_utils.parsing.convert_bool import BOOLEANS, BOOLEANS_FALSE, BOOLEANS_TRUE, boolean from ansible.module_utils.common.warnings import ( deprecate, - get_deprecation_messages, - get_warning_messages, + get_deprecations, + get_warnings, warn, ) @@ -169,7 +184,9 @@ imap = map # multiple AnsibleModules are created. Otherwise each AnsibleModule would # attempt to read from stdin. Other code should not use this directly as it # is an internal implementation detail -_ANSIBLE_ARGS = None +_ANSIBLE_ARGS: bytes | None = None +_ANSIBLE_PROFILE: str | None = None +_PARSED_MODULE_ARGS: dict[str, t.Any] | None = None FILE_COMMON_ARGUMENTS = dict( @@ -307,40 +324,31 @@ def _load_params(): to call this function and consume its outputs than to implement the logic inside it as a copy in your own code. 
""" - global _ANSIBLE_ARGS - if _ANSIBLE_ARGS is not None: - buffer = _ANSIBLE_ARGS - else: - # debug overrides to read args from file or cmdline + global _ANSIBLE_ARGS, _ANSIBLE_PROFILE - # Avoid tracebacks when locale is non-utf8 - # We control the args and we pass them as utf8 - if len(sys.argv) > 1: - if os.path.isfile(sys.argv[1]): - with open(sys.argv[1], 'rb') as fd: - buffer = fd.read() - else: - buffer = sys.argv[1].encode('utf-8', errors='surrogateescape') - # default case, read from stdin - else: - buffer = sys.stdin.buffer.read() - _ANSIBLE_ARGS = buffer + if _ANSIBLE_ARGS is None: + _ANSIBLE_ARGS, _ANSIBLE_PROFILE = _debugging.load_params() - try: - params = json.loads(buffer.decode('utf-8')) - except ValueError: - # This helper is used too early for fail_json to work. - print('\n{"msg": "Error: Module unable to decode stdin/parameters as valid JSON. Unable to parse what parameters were passed", "failed": true}') - sys.exit(1) + buffer = _ANSIBLE_ARGS + profile = _ANSIBLE_PROFILE + + if not profile: + raise Exception("No serialization profile was specified.") try: - return params['ANSIBLE_MODULE_ARGS'] - except KeyError: - # This helper does not have access to fail_json so we have to print - # json output on our own. - print('\n{"msg": "Error: Module unable to locate ANSIBLE_MODULE_ARGS in JSON data from stdin. Unable to figure out what parameters were passed", ' - '"failed": true}') - sys.exit(1) + decoder = _json.get_module_decoder(profile, _json.Direction.CONTROLLER_TO_MODULE) + params = json.loads(buffer.decode(), cls=decoder) + except Exception as ex: + raise Exception("Failed to decode JSON module parameters.") from ex + + if (ansible_module_args := params.get('ANSIBLE_MODULE_ARGS', ...)) is ...: + raise Exception("ANSIBLE_MODULE_ARGS not provided.") + + global _PARSED_MODULE_ARGS + + _PARSED_MODULE_ARGS = copy.deepcopy(ansible_module_args) # AnsibleModule mutates the returned dict, so a copy is needed + + return ansible_module_args def missing_required_lib(library, reason=None, url=None): @@ -506,7 +514,7 @@ class AnsibleModule(object): def deprecate(self, msg, version=None, date=None, collection_name=None): if version is not None and date is not None: raise AssertionError("implementation error -- version and date must not both be set") - deprecate(msg, version=version, date=date, collection_name=collection_name) + deprecate(msg, version=version, date=date) # For compatibility, we accept that neither version nor date is set, # and treat that the same as if version would not have been set if date is not None: @@ -878,8 +886,7 @@ class AnsibleModule(object): raise except Exception as e: path = to_text(b_path) - self.fail_json(path=path, msg='chmod failed', details=to_native(e), - exception=traceback.format_exc()) + self.fail_json(path=path, msg='chmod failed', details=to_native(e)) path_stat = os.lstat(b_path) new_mode = stat.S_IMODE(path_stat.st_mode) @@ -927,8 +934,7 @@ class AnsibleModule(object): if rc != 0 or err: raise Exception("Error while setting attributes: %s" % (out + err)) except Exception as e: - self.fail_json(path=to_text(b_path), msg='chattr failed', - details=to_native(e), exception=traceback.format_exc()) + self.fail_json(path=to_text(b_path), msg='chattr failed', details=to_native(e)) return changed def get_file_attributes(self, path, include_version=True): @@ -1173,8 +1179,7 @@ class AnsibleModule(object): os.environ['LC_ALL'] = best_locale os.environ['LC_MESSAGES'] = best_locale except Exception as e: - self.fail_json(msg="An unknown error was 
encountered while attempting to validate the locale: %s" % - to_native(e), exception=traceback.format_exc()) + self.fail_json(msg="An unknown error was encountered while attempting to validate the locale: %s" % to_native(e)) def _set_internal_properties(self, argument_spec=None, module_parameters=None): if argument_spec is None: @@ -1224,7 +1229,6 @@ class AnsibleModule(object): msg='Failed to log to syslog (%s). To proceed anyway, ' 'disable syslog logging by setting no_target_syslog ' 'to True in your Ansible config.' % to_native(e), - exception=traceback.format_exc(), msg_to_log=msg, ) @@ -1378,8 +1382,15 @@ class AnsibleModule(object): self.fail_json(msg=to_native(e)) def jsonify(self, data): + # deprecated: description='deprecate AnsibleModule.jsonify()' core_version='2.23' + # deprecate( + # msg="The `AnsibleModule.jsonify' method is deprecated.", + # version="2.27", + # # help_text="", # DTFIX-RELEASE: fill in this help text + # ) + try: - return jsonify(data) + return json.dumps(data, cls=_common_json._get_legacy_encoder()) except UnicodeError as e: self.fail_json(msg=to_text(e)) @@ -1408,7 +1419,7 @@ class AnsibleModule(object): else: self.warn(kwargs['warnings']) - warnings = get_warning_messages() + warnings = get_warnings() if warnings: kwargs['warnings'] = warnings @@ -1425,7 +1436,7 @@ class AnsibleModule(object): else: self.deprecate(kwargs['deprecations']) # pylint: disable=ansible-deprecated-no-version - deprecations = get_deprecation_messages() + deprecations = get_deprecations() if deprecations: kwargs['deprecations'] = deprecations @@ -1438,7 +1449,8 @@ class AnsibleModule(object): # return preserved kwargs.update(preserved) - print('\n%s' % self.jsonify(kwargs)) + encoder = _json.get_module_encoder(_ANSIBLE_PROFILE, _json.Direction.MODULE_TO_CONTROLLER) + print('\n%s' % json.dumps(kwargs, cls=encoder)) def exit_json(self, **kwargs) -> t.NoReturn: """ return from the module, without error """ @@ -1447,19 +1459,56 @@ class AnsibleModule(object): self._return_formatted(kwargs) sys.exit(0) - def fail_json(self, msg, **kwargs) -> t.NoReturn: - """ return from the module, with an error message """ + def fail_json(self, msg: str, *, exception: BaseException | str | ellipsis | None = ..., **kwargs) -> t.NoReturn: + """ + Return from the module with an error message and optional exception/traceback detail. + A traceback will only be included in the result if error traceback capturing has been enabled. + + When `exception` is an exception object, its message chain will be automatically combined with `msg` to create the final error message. + The message chain includes the exception's message as well as messages from any __cause__ exceptions. + The traceback from `exception` will be used for the formatted traceback. + + When `exception` is a string, it will be used as the formatted traceback. + + When `exception` is set to `None`, the current call stack will be used for the formatted traceback. + + When `exception` is not specified, a formatted traceback will be retrieved from the current exception. + If no exception is pending, the current call stack will be used instead. + """ + msg = str(msg) # coerce to str instead of raising an error due to an invalid type + + kwargs.update( + failed=True, + msg=msg, + ) - kwargs['failed'] = True - kwargs['msg'] = msg + if isinstance(exception, BaseException): + # Include a `_messages.ErrorDetail` in the result. + # The `msg` is included in the list of errors to ensure it is not lost when looking only at `exception` from the result. 
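+            # Illustrative result shape (hypothetical values): fail_json(msg='copy failed', exception=ex)
+            # yields exception.details == (Detail(msg='copy failed'),) + the details derived from ex,
+            # so the primary msg always leads the detail chain.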
- # Add traceback if debug or high verbosity and it is missing - # NOTE: Badly named as exception, it really always has been a traceback - if 'exception' not in kwargs and sys.exc_info()[2] and (self._debug or self._verbosity >= 3): - kwargs['exception'] = ''.join(traceback.format_tb(sys.exc_info()[2])) + error_summary = _errors.create_error_summary(exception) + error_summary = _dataclasses.replace(error_summary, details=(_messages.Detail(msg=msg),) + error_summary.details) + + kwargs.update(exception=error_summary) + elif _traceback.is_traceback_enabled(_traceback.TracebackEvent.ERROR): + # Include only a formatted traceback string in the result. + # The controller will combine this with `msg` to create an `_messages.ErrorDetail`. + + formatted_traceback: str | None + + if isinstance(exception, str): + formatted_traceback = exception + elif exception is ... and (current_exception := t.cast(t.Optional[BaseException], sys.exc_info()[1])): + formatted_traceback = _traceback.maybe_extract_traceback(current_exception, _traceback.TracebackEvent.ERROR) + else: + formatted_traceback = _traceback.maybe_capture_traceback(_traceback.TracebackEvent.ERROR) + + if formatted_traceback: + kwargs.update(exception=formatted_traceback) self.do_cleanup_files() self._return_formatted(kwargs) + sys.exit(1) def fail_on_missing_params(self, required_params=None): @@ -1611,7 +1660,7 @@ class AnsibleModule(object): if e.errno not in [errno.EPERM, errno.EXDEV, errno.EACCES, errno.ETXTBSY, errno.EBUSY]: # only try workarounds for errno 18 (cross device), 1 (not permitted), 13 (permission denied) # and 26 (text file busy) which happens on vagrant synced folders and other 'exotic' non posix file systems - self.fail_json(msg='Could not replace file: %s to %s: %s' % (src, dest, to_native(e)), exception=traceback.format_exc()) + self.fail_json(msg='Could not replace file: %s to %s: %s' % (src, dest, to_native(e))) else: # Use bytes here. In the shippable CI, this fails with # a UnicodeError with surrogateescape'd strings for an unknown @@ -1624,12 +1673,11 @@ class AnsibleModule(object): tmp_dest_fd, tmp_dest_name = tempfile.mkstemp(prefix=b'.ansible_tmp', dir=b_dest_dir, suffix=b_suffix) except (OSError, IOError) as e: error_msg = 'The destination directory (%s) is not writable by the current user. 
Error was: %s' % (os.path.dirname(dest), to_native(e)) - finally: - if error_msg: - if unsafe_writes: - self._unsafe_writes(b_src, b_dest) - else: - self.fail_json(msg=error_msg, exception=traceback.format_exc()) + + if unsafe_writes: + self._unsafe_writes(b_src, b_dest) + else: + self.fail_json(msg=error_msg) if tmp_dest_name: b_tmp_dest_name = to_bytes(tmp_dest_name, errors='surrogate_or_strict') @@ -1668,12 +1716,12 @@ class AnsibleModule(object): self._unsafe_writes(b_tmp_dest_name, b_dest) else: self.fail_json(msg='Unable to make %s into to %s, failed final rename from %s: %s' % - (src, dest, b_tmp_dest_name, to_native(e)), exception=traceback.format_exc()) + (src, dest, b_tmp_dest_name, to_native(e))) except (shutil.Error, OSError, IOError) as e: if unsafe_writes: self._unsafe_writes(b_src, b_dest) else: - self.fail_json(msg='Failed to replace file: %s to %s: %s' % (src, dest, to_native(e)), exception=traceback.format_exc()) + self.fail_json(msg='Failed to replace file: %s to %s: %s' % (src, dest, to_native(e))) finally: self.cleanup(b_tmp_dest_name) @@ -1713,8 +1761,7 @@ class AnsibleModule(object): if in_src: in_src.close() except (shutil.Error, OSError, IOError) as e: - self.fail_json(msg='Could not write data to file (%s) from (%s): %s' % (dest, src, to_native(e)), - exception=traceback.format_exc()) + self.fail_json(msg='Could not write data to file (%s) from (%s): %s' % (dest, src, to_native(e))) def _clean_args(self, args): @@ -2009,7 +2056,7 @@ class AnsibleModule(object): except Exception as e: self.log("Error Executing CMD:%s Exception:%s" % (self._clean_args(args), to_native(traceback.format_exc()))) if handle_exceptions: - self.fail_json(rc=257, stdout=b'', stderr=b'', msg=to_native(e), exception=traceback.format_exc(), cmd=self._clean_args(args)) + self.fail_json(rc=257, stdout=b'', stderr=b'', msg=to_native(e), cmd=self._clean_args(args)) else: raise e diff --git a/lib/ansible/module_utils/common/_utils.py b/lib/ansible/module_utils/common/_utils.py index deab1fcdf9c..51af1e69e16 100644 --- a/lib/ansible/module_utils/common/_utils.py +++ b/lib/ansible/module_utils/common/_utils.py @@ -1,38 +1,34 @@ # Copyright (c) 2018, Ansible Project # Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) - - """ Modules in _utils are waiting to find a better home. If you need to use them, be prepared for them to move to a different location in the future. """ + from __future__ import annotations +import inspect +import typing as t + +_Type = t.TypeVar('_Type') + + +def get_all_subclasses(cls: type[_Type], *, include_abstract: bool = True, consider_self: bool = False) -> set[type[_Type]]: + """Recursively find all subclasses of a given type, including abstract classes by default.""" + subclasses: set[type[_Type]] = {cls} if consider_self else set() + queue: list[type[_Type]] = [cls] + + while queue: + parent = queue.pop() + + for child in parent.__subclasses__(): + if child in subclasses: + continue + + queue.append(child) + subclasses.add(child) + + if not include_abstract: + subclasses = {sc for sc in subclasses if not inspect.isabstract(sc)} -def get_all_subclasses(cls): - """ - Recursively search and find all subclasses of a given class - - :arg cls: A python class - :rtype: set - :returns: The set of python classes which are the subclasses of `cls`. - - In python, you can use a class's :py:meth:`__subclasses__` method to determine what subclasses - of a class exist. However, `__subclasses__` only goes one level deep. 
This function searches - each child class's `__subclasses__` method to find all of the descendent classes. It then - returns an iterable of the descendent classes. - """ - # Retrieve direct subclasses - subclasses = set(cls.__subclasses__()) - to_visit = list(subclasses) - # Then visit all subclasses - while to_visit: - for sc in to_visit: - # The current class is now visited, so remove it from list - to_visit.remove(sc) - # Appending all subclasses to visit and keep a reference of available class - for ssc in sc.__subclasses__(): - if ssc not in subclasses: - to_visit.append(ssc) - subclasses.add(ssc) return subclasses diff --git a/lib/ansible/module_utils/common/collections.py b/lib/ansible/module_utils/common/collections.py index 28c53e14e2c..f5fae55aa8d 100644 --- a/lib/ansible/module_utils/common/collections.py +++ b/lib/ansible/module_utils/common/collections.py @@ -66,8 +66,7 @@ class ImmutableDict(Hashable, Mapping): def is_string(seq): """Identify whether the input has a string-like type (including bytes).""" - # AnsibleVaultEncryptedUnicode inherits from Sequence, but is expected to be a string like object - return isinstance(seq, (text_type, binary_type)) or getattr(seq, '__ENCRYPTED__', False) + return isinstance(seq, (text_type, binary_type)) def is_iterable(seq, include_strings=False): diff --git a/lib/ansible/module_utils/common/json.py b/lib/ansible/module_utils/common/json.py index fe65a8d701c..3b38c421d05 100644 --- a/lib/ansible/module_utils/common/json.py +++ b/lib/ansible/module_utils/common/json.py @@ -1,84 +1,90 @@ -# -*- coding: utf-8 -*- -# Copyright (c) 2019 Ansible Project -# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) - -from __future__ import annotations - -import json - -import datetime - -from ansible.module_utils.common.text.converters import to_text -from ansible.module_utils.six.moves.collections_abc import Mapping -from ansible.module_utils.common.collections import is_sequence - - -def _is_unsafe(value): - return getattr(value, '__UNSAFE__', False) and not getattr(value, '__ENCRYPTED__', False) - - -def _is_vault(value): - return getattr(value, '__ENCRYPTED__', False) - - -def _preprocess_unsafe_encode(value): - """Recursively preprocess a data structure converting instances of ``AnsibleUnsafe`` - into their JSON dict representations - - Used in ``AnsibleJSONEncoder.iterencode`` - """ - if _is_unsafe(value): - value = {'__ansible_unsafe': to_text(value, errors='surrogate_or_strict', nonstring='strict')} - elif is_sequence(value): - value = [_preprocess_unsafe_encode(v) for v in value] - elif isinstance(value, Mapping): - value = dict((k, _preprocess_unsafe_encode(v)) for k, v in value.items()) - - return value - - -def json_dump(structure): - return json.dumps(structure, cls=AnsibleJSONEncoder, sort_keys=True, indent=4) - - -class AnsibleJSONEncoder(json.JSONEncoder): - """ - Simple encoder class to deal with JSON encoding of Ansible internal types - """ - - def __init__(self, preprocess_unsafe=False, vault_to_text=False, **kwargs): - self._preprocess_unsafe = preprocess_unsafe - self._vault_to_text = vault_to_text - super(AnsibleJSONEncoder, self).__init__(**kwargs) - - # NOTE: ALWAYS inform AWS/Tower when new items get added as they consume them downstream via a callback - def default(self, o): - if getattr(o, '__ENCRYPTED__', False): - # vault object - if self._vault_to_text: - value = to_text(o, errors='surrogate_or_strict') - else: - value = {'__ansible_vault': to_text(o._ciphertext, 
errors='surrogate_or_strict', nonstring='strict')} - elif getattr(o, '__UNSAFE__', False): - # unsafe object, this will never be triggered, see ``AnsibleJSONEncoder.iterencode`` - value = {'__ansible_unsafe': to_text(o, errors='surrogate_or_strict', nonstring='strict')} - elif isinstance(o, Mapping): - # hostvars and other objects - value = dict(o) - elif isinstance(o, (datetime.date, datetime.datetime)): - # date object - value = o.isoformat() - else: - # use default encoder - value = super(AnsibleJSONEncoder, self).default(o) - return value - - def iterencode(self, o, **kwargs): - """Custom iterencode, primarily design to handle encoding ``AnsibleUnsafe`` - as the ``AnsibleUnsafe`` subclasses inherit from string types and - ``json.JSONEncoder`` does not support custom encoders for string types - """ - if self._preprocess_unsafe: - o = _preprocess_unsafe_encode(o) - - return super(AnsibleJSONEncoder, self).iterencode(o, **kwargs) +from __future__ import annotations as _annotations + +import enum as _enum +import json as _stdlib_json +import types as _types + +from ansible.module_utils import _internal +from ansible.module_utils._internal import _json +from ansible.module_utils._internal._json import _legacy_encoder +from ansible.module_utils._internal._json import _profiles +from ansible.module_utils._internal._json._profiles import _tagless +from ansible.module_utils.common import warnings as _warnings + + +def __getattr__(name: str) -> object: + """Handle dynamic module members which are or will be deprecated.""" + if name in ('AnsibleJSONEncoder', '_AnsibleJSONEncoder'): + # deprecated: description='deprecate legacy encoder' core_version='2.23' + # if not name.startswith('_'): # avoid duplicate deprecation warning for imports from ajson + # _warnings.deprecate( + # msg="The `AnsibleJSONEncoder` type is deprecated.", + # version="2.27", + # help_text="Use a profile-based encoder instead.", # DTFIX-RELEASE: improve this help text + # ) + + return _get_legacy_encoder() + + if name in ('AnsibleJSONDecoder', '_AnsibleJSONDecoder'): + # deprecated: description='deprecate legacy decoder' core_version='2.23' + # if not name.startswith('_'): # avoid duplicate deprecation warning for imports from ajson + # _warnings.deprecate( + # msg="The `AnsibleJSONDecoder` type is deprecated.", + # version="2.27", + # help_text="Use a profile-based decoder instead.", # DTFIX-RELEASE: improve this help text + # ) + + return _tagless.Decoder + + if name == 'json_dump': + _warnings.deprecate( + msg="The `json_dump` function is deprecated.", + version="2.23", + help_text="Use `json.dumps` with the appropriate `cls` instead.", + ) + + return _json_dump + + raise AttributeError(name) + + +def _get_legacy_encoder() -> type[_stdlib_json.JSONEncoder]: + """Compatibility hack: previous module_utils AnsibleJSONEncoder impl did controller-side work, controller plugins require a more fully-featured impl.""" + if _internal.is_controller: + return _internal.import_controller_module('ansible._internal._json._legacy_encoder').LegacyControllerJSONEncoder + + return _legacy_encoder.LegacyTargetJSONEncoder + + +def _json_dump(structure): + """JSON dumping function maintained for temporary backward compatibility.""" + return _stdlib_json.dumps(structure, cls=_get_legacy_encoder(), sort_keys=True, indent=4) + + +class Direction(_enum.Enum): + """Enumeration used to select a contextually-appropriate JSON profile for module messaging.""" + + CONTROLLER_TO_MODULE = _enum.auto() + """Encode/decode messages from the Ansible controller 
to an Ansible module.""" + MODULE_TO_CONTROLLER = _enum.auto() + """Encode/decode messages from an Ansible module to the Ansible controller.""" + + +def get_encoder(profile: str | _types.ModuleType, /) -> type[_stdlib_json.JSONEncoder]: + """Return a `JSONEncoder` for the given `profile`.""" + return _json.get_encoder_decoder(profile, _profiles.AnsibleProfileJSONEncoder) + + +def get_decoder(profile: str | _types.ModuleType, /) -> type[_stdlib_json.JSONDecoder]: + """Return a `JSONDecoder` for the given `profile`.""" + return _json.get_encoder_decoder(profile, _profiles.AnsibleProfileJSONDecoder) + + +def get_module_encoder(name: str, direction: Direction, /) -> type[_stdlib_json.JSONEncoder]: + """Return a `JSONEncoder` for the module profile specified by `name` and `direction`.""" + return get_encoder(_json.get_module_serialization_profile_name(name, direction == Direction.CONTROLLER_TO_MODULE)) + + +def get_module_decoder(name: str, direction: Direction, /) -> type[_stdlib_json.JSONDecoder]: + """Return a `JSONDecoder` for the module profile specified by `name` and `direction`.""" + return get_decoder(_json.get_module_serialization_profile_name(name, direction == Direction.CONTROLLER_TO_MODULE)) diff --git a/lib/ansible/module_utils/common/messages.py b/lib/ansible/module_utils/common/messages.py new file mode 100644 index 00000000000..a4ec12f8494 --- /dev/null +++ b/lib/ansible/module_utils/common/messages.py @@ -0,0 +1,108 @@ +""" +Message contract definitions for various target-side types. + +These types and the wire format they implement are currently considered provisional and subject to change without notice. +A future release will remove the provisional status. +""" + +from __future__ import annotations as _annotations + +import sys as _sys +import dataclasses as _dataclasses + +# deprecated: description='typing.Self exists in Python 3.11+' python_version='3.10' +from ..compat import typing as _t + +from ansible.module_utils._internal import _datatag + +if _sys.version_info >= (3, 10): + # Using slots for reduced memory usage and improved performance. + _dataclass_kwargs = dict(frozen=True, kw_only=True, slots=True) +else: + # deprecated: description='always use dataclass slots and keyword-only args' python_version='3.9' + _dataclass_kwargs = dict(frozen=True) + + +@_dataclasses.dataclass(**_dataclass_kwargs) +class PluginInfo(_datatag.AnsibleSerializableDataclass): + """Information about a loaded plugin.""" + + requested_name: str + """The plugin name as requested, before resolving, which may be partially or fully qualified.""" + resolved_name: str + """The resolved canonical plugin name; always fully-qualified for collection plugins.""" + type: str + """The plugin type.""" + + +@_dataclasses.dataclass(**_dataclass_kwargs) +class Detail(_datatag.AnsibleSerializableDataclass): + """Message detail with optional source context and help text.""" + + msg: str + formatted_source_context: _t.Optional[str] = None + help_text: _t.Optional[str] = None + + +@_dataclasses.dataclass(**_dataclass_kwargs) +class SummaryBase(_datatag.AnsibleSerializableDataclass): + """Base class for an error/warning/deprecation summary with details (possibly derived from an exception __cause__ chain) and an optional traceback.""" + + details: _t.Tuple[Detail, ...] 
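+    # Annotation (an assumption, not part of the original patch): `details` is expected to be
+    # ordered from the outermost message to its chained causes; `_format()` below joins the
+    # `Detail.msg` values with ": " to produce the flattened message string.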
+ formatted_traceback: _t.Optional[str] = None + + def _format(self) -> str: + """Returns a string representation of the details.""" + # DTFIX-RELEASE: eliminate this function and use a common message squashing utility such as get_chained_message on instances of this type + return ': '.join(detail.msg for detail in self.details) + + def _post_validate(self) -> None: + if not self.details: + raise ValueError(f'{type(self).__name__}.details cannot be empty') + + +@_dataclasses.dataclass(**_dataclass_kwargs) +class ErrorSummary(SummaryBase): + """Error summary with details (possibly derived from an exception __cause__ chain) and an optional traceback.""" + + +@_dataclasses.dataclass(**_dataclass_kwargs) +class WarningSummary(SummaryBase): + """Warning summary with details (possibly derived from an exception __cause__ chain) and an optional traceback.""" + + +@_dataclasses.dataclass(**_dataclass_kwargs) +class DeprecationSummary(WarningSummary): + """Deprecation summary with details (possibly derived from an exception __cause__ chain) and an optional traceback.""" + + version: _t.Optional[str] = None + date: _t.Optional[str] = None + plugin: _t.Optional[PluginInfo] = None + + @property + def collection_name(self) -> _t.Optional[str]: + if not self.plugin: + return None + + parts = self.plugin.resolved_name.split('.') + + if len(parts) < 2: + return None + + collection_name = '.'.join(parts[:2]) + + # deprecated: description='enable the deprecation message for collection_name' core_version='2.23' + # from ansible.module_utils.datatag import deprecate_value + # collection_name = deprecate_value(collection_name, 'The `collection_name` property is deprecated.', removal_version='2.27') + + return collection_name + + def _as_simple_dict(self) -> _t.Dict[str, _t.Any]: + """Returns a dictionary representation of the deprecation object in the format exposed to playbooks.""" + result = self._as_dict() + result.update( + msg=self._format(), + collection_name=self.collection_name, + ) + + return result diff --git a/lib/ansible/module_utils/common/parameters.py b/lib/ansible/module_utils/common/parameters.py index c80ca6ccf16..fc886463c94 100644 --- a/lib/ansible/module_utils/common/parameters.py +++ b/lib/ansible/module_utils/common/parameters.py @@ -6,13 +6,16 @@ from __future__ import annotations import datetime import os +import typing as t from collections import deque from itertools import chain from ansible.module_utils.common.collections import is_iterable +from ansible.module_utils._internal._datatag import AnsibleSerializable, AnsibleTagHelper from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text from ansible.module_utils.common.warnings import warn +from ansible.module_utils.datatag import native_type_name from ansible.module_utils.errors import ( AliasError, AnsibleFallbackNotFound, @@ -83,7 +86,7 @@ _ADDITIONAL_CHECKS = ( # if adding boolean attribute, also add to PASS_BOOL # some of this dupes defaults from controller config # keep in sync with copy in lib/ansible/module_utils/csharp/Ansible.Basic.cs -PASS_VARS = { +PASS_VARS: dict[str, t.Any] = { 'check_mode': ('check_mode', False), 'debug': ('_debug', False), 'diff': ('_diff', False), @@ -98,6 +101,7 @@ PASS_VARS = { 'socket': ('_socket_path', None), 'syslog_facility': ('_syslog_facility', 'INFO'), 'tmpdir': ('_tmpdir', None), + 'tracebacks_for': ('_tracebacks_for', frozenset()), 'verbosity': ('_verbosity', 0), 'version': ('ansible_version', '0.0'), } @@ -407,6 +411,8 @@ def _remove_values_conditions(value, 
no_log_strings, deferred_removals): dictionary for ``level1``, then the dict for ``level2``, and finally the list for ``level3``. """ + original_value = value + if isinstance(value, (text_type, binary_type)): # Need native str type native_str_value = value @@ -431,31 +437,25 @@ def _remove_values_conditions(value, no_log_strings, deferred_removals): else: value = native_str_value + elif value is True or value is False or value is None: + return value + elif isinstance(value, Sequence): - if isinstance(value, MutableSequence): - new_value = type(value)() - else: - new_value = [] # Need a mutable value + new_value = AnsibleTagHelper.tag_copy(original_value, []) deferred_removals.append((value, new_value)) - value = new_value + return new_value elif isinstance(value, Set): - if isinstance(value, MutableSet): - new_value = type(value)() - else: - new_value = set() # Need a mutable value + new_value = AnsibleTagHelper.tag_copy(original_value, set()) deferred_removals.append((value, new_value)) - value = new_value + return new_value elif isinstance(value, Mapping): - if isinstance(value, MutableMapping): - new_value = type(value)() - else: - new_value = {} # Need a mutable value + new_value = AnsibleTagHelper.tag_copy(original_value, {}) deferred_removals.append((value, new_value)) - value = new_value + return new_value - elif isinstance(value, tuple(chain(integer_types, (float, bool, NoneType)))): + elif isinstance(value, (int, float)): stringy_value = to_native(value, encoding='utf-8', errors='surrogate_or_strict') if stringy_value in no_log_strings: return 'VALUE_SPECIFIED_IN_NO_LOG_PARAMETER' @@ -463,11 +463,15 @@ def _remove_values_conditions(value, no_log_strings, deferred_removals): if omit_me in stringy_value: return 'VALUE_SPECIFIED_IN_NO_LOG_PARAMETER' - elif isinstance(value, (datetime.datetime, datetime.date)): - value = value.isoformat() + elif isinstance(value, (datetime.datetime, datetime.date, datetime.time)): + return value + elif isinstance(value, AnsibleSerializable): + return value else: raise TypeError('Value of unknown type: %s, %s' % (type(value), value)) + value = AnsibleTagHelper.tag_copy(original_value, value) + return value @@ -540,7 +544,7 @@ def _sanitize_keys_conditions(value, no_log_strings, ignore_keys, deferred_remov if isinstance(value, tuple(chain(integer_types, (float, bool, NoneType)))): return value - if isinstance(value, (datetime.datetime, datetime.date)): + if isinstance(value, (datetime.datetime, datetime.date, datetime.time)): return value raise TypeError('Value of unknown type: %s, %s' % (type(value), value)) @@ -569,7 +573,7 @@ def _validate_elements(wanted_type, parameter, values, options_context=None, err msg = "Elements value for option '%s'" % parameter if options_context: msg += " found in '%s'" % " -> ".join(options_context) - msg += " is of type %s and we were unable to convert to %s: %s" % (type(value), wanted_element_type, to_native(e)) + msg += " is of type %s and we were unable to convert to %s: %s" % (native_type_name(value), wanted_element_type, to_native(e)) errors.append(ElementError(msg)) return validated_parameters @@ -628,7 +632,7 @@ def _validate_argument_types(argument_spec, parameters, prefix='', options_conte elements_wanted_type = spec.get('elements', None) if elements_wanted_type: elements = parameters[param] - if wanted_type != 'list' or not isinstance(elements, list): + if not isinstance(parameters[param], list) or not isinstance(elements, list): msg = "Invalid type %s for option '%s'" % (wanted_name, elements) if 
options_context: msg += " found in '%s'." % " -> ".join(options_context) @@ -637,7 +641,7 @@ def _validate_argument_types(argument_spec, parameters, prefix='', options_conte parameters[param] = _validate_elements(elements_wanted_type, param, elements, options_context, errors) except (TypeError, ValueError) as e: - msg = "argument '%s' is of type %s" % (param, type(value)) + msg = "argument '%s' is of type %s" % (param, native_type_name(value)) if options_context: msg += " found in '%s'." % " -> ".join(options_context) msg += " and we were unable to convert to %s: %s" % (wanted_name, to_native(e)) diff --git a/lib/ansible/module_utils/common/respawn.py b/lib/ansible/module_utils/common/respawn.py index d16815b9a17..c0874fb2911 100644 --- a/lib/ansible/module_utils/common/respawn.py +++ b/lib/ansible/module_utils/common/respawn.py @@ -3,12 +3,14 @@ from __future__ import annotations +import dataclasses import os import pathlib import subprocess import sys import typing as t +from ansible.module_utils._internal import _plugin_exec_context from ansible.module_utils.common.text.converters import to_bytes _ANSIBLE_PARENT_PATH = pathlib.Path(__file__).parents[3] @@ -84,29 +86,45 @@ def probe_interpreters_for_module(interpreter_paths, module_name): def _create_payload(): + # FIXME: move this into _ansiballz and skip the template from ansible.module_utils import basic - smuggled_args = getattr(basic, '_ANSIBLE_ARGS') - if not smuggled_args: - raise Exception('unable to access ansible.module_utils.basic._ANSIBLE_ARGS (not launched by AnsiballZ?)') + module_fqn = sys.modules['__main__']._module_fqn modlib_path = sys.modules['__main__']._modlib_path - respawn_code_template = """ -import runpy -import sys - -module_fqn = {module_fqn!r} -modlib_path = {modlib_path!r} -smuggled_args = {smuggled_args!r} + respawn_code_template = """ if __name__ == '__main__': - sys.path.insert(0, modlib_path) + import runpy + import sys - from ansible.module_utils import basic - basic._ANSIBLE_ARGS = smuggled_args + json_params = {json_params!r} + profile = {profile!r} + plugin_info_dict = {plugin_info_dict!r} + module_fqn = {module_fqn!r} + modlib_path = {modlib_path!r} - runpy.run_module(module_fqn, init_globals=dict(_respawned=True), run_name='__main__', alter_sys=True) - """ + sys.path.insert(0, modlib_path) - respawn_code = respawn_code_template.format(module_fqn=module_fqn, modlib_path=modlib_path, smuggled_args=smuggled_args.strip()) + from ansible.module_utils._internal import _ansiballz + + _ansiballz.run_module( + json_params=json_params, + profile=profile, + plugin_info_dict=plugin_info_dict, + module_fqn=module_fqn, + modlib_path=modlib_path, + init_globals=dict(_respawned=True), + ) +""" + + plugin_info = _plugin_exec_context.PluginExecContext.get_current_plugin_info() + + respawn_code = respawn_code_template.format( + json_params=basic._ANSIBLE_ARGS, + profile=basic._ANSIBLE_PROFILE, + plugin_info_dict=dataclasses.asdict(plugin_info), + module_fqn=module_fqn, + modlib_path=modlib_path, + ) return respawn_code diff --git a/lib/ansible/module_utils/common/text/converters.py b/lib/ansible/module_utils/common/text/converters.py index 6bfa8470b69..78fb96ec282 100644 --- a/lib/ansible/module_utils/common/text/converters.py +++ b/lib/ansible/module_utils/common/text/converters.py @@ -6,12 +6,9 @@ from __future__ import annotations import codecs -import datetime import json -from ansible.module_utils.six.moves.collections_abc import Set from ansible.module_utils.six import ( - PY3, binary_type, iteritems, text_type, 
@@ -237,44 +234,21 @@ def to_text(obj, encoding='utf-8', errors=None, nonstring='simplerepr'): return to_text(value, encoding, errors) -#: :py:func:`to_native` -#: Transform a variable into the native str type for the python version -#: -#: On Python2, this is an alias for -#: :func:`~ansible.module_utils.to_bytes`. On Python3 it is an alias for -#: :func:`~ansible.module_utils.to_text`. It makes it easier to -#: transform a variable into the native str type for the python version -#: the code is running on. Use this when constructing the message to -#: send to exceptions or when dealing with an API that needs to take -#: a native string. Example:: -#: -#: try: -#: 1//0 -#: except ZeroDivisionError as e: -#: raise MyException('Encountered and error: %s' % to_native(e)) -if PY3: - to_native = to_text -else: - to_native = to_bytes - - -def _json_encode_fallback(obj): - if isinstance(obj, Set): - return list(obj) - elif isinstance(obj, datetime.datetime): - return obj.isoformat() - raise TypeError("Cannot json serialize %s" % to_native(obj)) +to_native = to_text def jsonify(data, **kwargs): - # After 2.18, we should remove this loop, and hardcode to utf-8 in alignment with requiring utf-8 module responses - for encoding in ("utf-8", "latin-1"): - try: - new_data = container_to_text(data, encoding=encoding) - except UnicodeDecodeError: - continue - return json.dumps(new_data, default=_json_encode_fallback, **kwargs) - raise UnicodeError('Invalid unicode encoding encountered') + from ansible.module_utils.common import json as _common_json + # from ansible.module_utils.common.warnings import deprecate + + # deprecated: description='deprecate jsonify()' core_version='2.23' + # deprecate( + # msg="The `jsonify` function is deprecated.", + # version="2.27", + # # help_text="", # DTFIX-RELEASE: fill in this help text + # ) + + return json.dumps(data, cls=_common_json._get_legacy_encoder(), _decode_bytes=True, **kwargs) def container_to_bytes(d, encoding='utf-8', errors='surrogate_or_strict'): @@ -283,6 +257,7 @@ def container_to_bytes(d, encoding='utf-8', errors='surrogate_or_strict'): Specialized for json return because this only handles, lists, tuples, and dict container types (the containers that the json module returns) """ + # DTFIX-RELEASE: deprecate if isinstance(d, text_type): return to_bytes(d, encoding=encoding, errors=errors) @@ -302,6 +277,7 @@ def container_to_text(d, encoding='utf-8', errors='surrogate_or_strict'): Specialized for json return because this only handles, lists, tuples, and dict container types (the containers that the json module returns) """ + # DTFIX-RELEASE: deprecate if isinstance(d, binary_type): # Warning, can traceback diff --git a/lib/ansible/module_utils/common/validation.py b/lib/ansible/module_utils/common/validation.py index 1098f27336e..952b991395f 100644 --- a/lib/ansible/module_utils/common/validation.py +++ b/lib/ansible/module_utils/common/validation.py @@ -10,16 +10,14 @@ import os import re from ast import literal_eval +from ansible.module_utils.common import json as _common_json from ansible.module_utils.common.text.converters import to_native from ansible.module_utils.common.collections import is_iterable -from ansible.module_utils.common.text.converters import jsonify from ansible.module_utils.common.text.formatters import human_to_bytes from ansible.module_utils.common.warnings import deprecate from ansible.module_utils.parsing.convert_bool import boolean from ansible.module_utils.six import ( - binary_type, string_types, - text_type, ) @@ -385,6 
+383,10 @@ def check_type_str(value, allow_conversion=True, param=None, prefix=''): raise TypeError(to_native(msg)) +def _check_type_str_no_conversion(value) -> str: + return check_type_str(value, allow_conversion=False) + + def check_type_list(value): """Verify that the value is a list or convert to a list @@ -400,6 +402,7 @@ def check_type_list(value): if isinstance(value, list): return value + # DTFIX-RELEASE: deprecate legacy comma split functionality, eventually replace with `_check_type_list_strict` if isinstance(value, string_types): return value.split(",") elif isinstance(value, int) or isinstance(value, float): @@ -408,6 +411,14 @@ def check_type_list(value): raise TypeError('%s cannot be converted to a list' % type(value)) +def _check_type_list_strict(value): + # FUTURE: this impl should replace `check_type_list` + if isinstance(value, list): + return value + + return [value] + + def check_type_dict(value): """Verify that value is a dict or convert it to a dict and return it. @@ -565,14 +576,21 @@ def check_type_bits(value): def check_type_jsonarg(value): - """Return a jsonified string. Sometimes the controller turns a json string - into a dict/list so transform it back into json here - - Raises :class:`TypeError` if unable to convert the value - """ - if isinstance(value, (text_type, binary_type)): + """ + JSON serialize dict/list/tuple, strip str and bytes. + Previously required for cases where Ansible/Jinja classic-mode literal eval pass could inadvertently deserialize objects. + """ + # deprecated: description='deprecate jsonarg type support' core_version='2.23' + # deprecate( + # msg="The `jsonarg` type is deprecated.", + # version="2.27", + # help_text="JSON string arguments should use `str`; structures can be explicitly serialized as JSON with the `to_json` filter.", + # ) + + if isinstance(value, (str, bytes)): return value.strip() - elif isinstance(value, (list, tuple, dict)): - return jsonify(value) + + if isinstance(value, (list, tuple, dict)): + return json.dumps(value, cls=_common_json._get_legacy_encoder(), _decode_bytes=True) + raise TypeError('%s cannot be converted to a json string' % type(value)) diff --git a/lib/ansible/module_utils/common/warnings.py b/lib/ansible/module_utils/common/warnings.py index 14fe516cf5b..fb10b7897d4 100644 --- a/lib/ansible/module_utils/common/warnings.py +++ b/lib/ansible/module_utils/common/warnings.py @@ -2,38 +2,99 @@ # Copyright (c) 2019 Ansible Project # Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) -from __future__ import annotations +from __future__ import annotations as _annotations -from ansible.module_utils.six import string_types +import datetime as _datetime +import typing as _t -_global_warnings = [] -_global_deprecations = [] +from ansible.module_utils._internal import _traceback, _plugin_exec_context +from ansible.module_utils.common import messages as _messages +from ansible.module_utils import _internal +_UNSET = _t.cast(_t.Any, ...)
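The rewrite below keeps the classic `warn()`/`deprecate()` call signatures while storing rich `messages` dataclasses in ordered, de-duplicating dicts instead of bare lists. A minimal target-side usage sketch (the call signatures match the new code below; the message values are invented, and plugin/traceback context is attached automatically during a real module run):

    from ansible.module_utils.common import warnings

    warnings.warn("config file permissions were tightened")  # stored as a WarningSummary
    warnings.deprecate("the 'foo' option is deprecated", version="2.30")  # stored as a DeprecationSummary

    legacy_strings = warnings.get_warning_messages()  # tuple of flattened message strings
    structured = warnings.get_deprecations()  # list of DeprecationSummary instances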
-def warn(warning): - if isinstance(warning, string_types): - _global_warnings.append(warning) - else: - raise TypeError("warn requires a string not a %s" % type(warning)) +def warn(warning: str) -> None: + """Record a warning to be returned with the module result.""" + # DTFIX-RELEASE: shim to controller display warning like `deprecate` + _global_warnings[_messages.WarningSummary( + details=( + _messages.Detail(msg=warning), + ), + formatted_traceback=_traceback.maybe_capture_traceback(_traceback.TracebackEvent.WARNING), + )] = None -def deprecate(msg, version=None, date=None, collection_name=None): - if isinstance(msg, string_types): - # For compatibility, we accept that neither version nor date is set, - # and treat that the same as if version would haven been set - if date is not None: - _global_deprecations.append({'msg': msg, 'date': date, 'collection_name': collection_name}) - else: - _global_deprecations.append({'msg': msg, 'version': version, 'collection_name': collection_name}) - else: - raise TypeError("deprecate requires a string not a %s" % type(msg)) +def deprecate( + msg: str, + version: str | None = None, + date: str | _datetime.date | None = None, + collection_name: str | None = _UNSET, + *, + help_text: str | None = None, + obj: object | None = None, +) -> None: + """ + Record a deprecation warning to be returned with the module result. + The `obj` argument is only useful in a controller context; it is ignored for target-side callers. + """ + if isinstance(date, _datetime.date): + date = str(date) -def get_warning_messages(): - """Return a tuple of warning messages accumulated over this run""" - return tuple(_global_warnings) + # deprecated: description='enable the deprecation message for collection_name' core_version='2.23' + # if collection_name is not _UNSET: + # deprecate('The `collection_name` argument to `deprecate` is deprecated.', version='2.27') + if _internal.is_controller: + _display = _internal.import_controller_module('ansible.utils.display').Display() + _display.deprecated( + msg=msg, + version=version, + date=date, + help_text=help_text, + obj=obj, + ) -def get_deprecation_messages(): - """Return a tuple of deprecations accumulated over this run""" - return tuple(_global_deprecations) + return + + _global_deprecations[_messages.DeprecationSummary( + details=( + _messages.Detail(msg=msg, help_text=help_text), + ), + formatted_traceback=_traceback.maybe_capture_traceback(_traceback.TracebackEvent.DEPRECATED), + version=version, + date=date, + plugin=_plugin_exec_context.PluginExecContext.get_current_plugin_info(), + )] = None + + +def get_warning_messages() -> tuple[str, ...]: + """Return a tuple of warning messages accumulated over this run.""" + # DTFIX-RELEASE: add future deprecation comment + return tuple(item._format() for item in _global_warnings) + + +_DEPRECATION_MESSAGE_KEYS = frozenset({'msg', 'date', 'version', 'collection_name'}) + + +def get_deprecation_messages() -> tuple[dict[str, _t.Any], ...]: + """Return a tuple of deprecation warning messages accumulated over this run.""" + # DTFIX-RELEASE: add future deprecation comment + return tuple({key: value for key, value in item._as_simple_dict().items() if key in _DEPRECATION_MESSAGE_KEYS} for item in _global_deprecations) + + +def get_warnings() -> list[_messages.WarningSummary]: + """Return a list of warning messages accumulated over this run.""" + return list(_global_warnings) + + +def get_deprecations() -> list[_messages.DeprecationSummary]: + """Return a list of deprecations accumulated over 
this run.""" + return list(_global_deprecations) + + +_global_warnings: dict[_messages.WarningSummary, object] = {} +"""Global, ordered, de-duplicated storage of accumulated warnings for the current module run.""" + +_global_deprecations: dict[_messages.DeprecationSummary, object] = {} +"""Global, ordered, de-duplicated storage of accumulated deprecations for the current module run.""" diff --git a/lib/ansible/module_utils/common/yaml.py b/lib/ansible/module_utils/common/yaml.py index 2e1ee52dc0b..838722b6fb4 100644 --- a/lib/ansible/module_utils/common/yaml.py +++ b/lib/ansible/module_utils/common/yaml.py @@ -6,10 +6,15 @@ This file provides ease of use shortcuts for loading and dumping YAML, preferring the YAML compiled C extensions to reduce duplicated code. """ -from __future__ import annotations +from __future__ import annotations as _annotations + +import collections.abc as _c +import typing as _t from functools import partial as _partial +from .._internal import _datatag + HAS_LIBYAML = False try: @@ -19,23 +24,44 @@ except ImportError: else: HAS_YAML = True +# DTFIX-RELEASE: refactor this to share the implementation with the controller version +# use an abstract base class, with __init_subclass__ for representer registration, and instance methods for overridable representers +# then tests can be consolidated instead of having two nearly identical copies + if HAS_YAML: try: from yaml import CSafeLoader as SafeLoader from yaml import CSafeDumper as SafeDumper + from yaml.representer import SafeRepresenter from yaml.cyaml import CParser as Parser # type: ignore[attr-defined] # pylint: disable=unused-import HAS_LIBYAML = True except (ImportError, AttributeError): from yaml import SafeLoader # type: ignore[assignment] from yaml import SafeDumper # type: ignore[assignment] + from yaml.representer import SafeRepresenter # type: ignore[assignment] from yaml.parser import Parser # type: ignore[assignment] # pylint: disable=unused-import + class _AnsibleDumper(SafeDumper): + pass + + def _represent_ansible_tagged_object(self, data: _datatag.AnsibleTaggedObject) -> _t.Any: + return self.represent_data(_datatag.AnsibleTagHelper.as_native_type(data)) + + def _represent_tripwire(self, data: _datatag.Tripwire) -> _t.NoReturn: + data.trip() + + _AnsibleDumper.add_multi_representer(_datatag.AnsibleTaggedObject, _represent_ansible_tagged_object) + + _AnsibleDumper.add_multi_representer(_datatag.Tripwire, _represent_tripwire) + _AnsibleDumper.add_multi_representer(_c.Mapping, SafeRepresenter.represent_dict) + _AnsibleDumper.add_multi_representer(_c.Sequence, SafeRepresenter.represent_list) + yaml_load = _partial(_yaml.load, Loader=SafeLoader) yaml_load_all = _partial(_yaml.load_all, Loader=SafeLoader) - yaml_dump = _partial(_yaml.dump, Dumper=SafeDumper) - yaml_dump_all = _partial(_yaml.dump_all, Dumper=SafeDumper) + yaml_dump = _partial(_yaml.dump, Dumper=_AnsibleDumper) + yaml_dump_all = _partial(_yaml.dump_all, Dumper=_AnsibleDumper) else: SafeLoader = object # type: ignore[assignment,misc] SafeDumper = object # type: ignore[assignment,misc] diff --git a/lib/ansible/module_utils/compat/paramiko.py b/lib/ansible/module_utils/compat/paramiko.py index bf2584d8fee..f654229580d 100644 --- a/lib/ansible/module_utils/compat/paramiko.py +++ b/lib/ansible/module_utils/compat/paramiko.py @@ -2,29 +2,36 @@ # Copyright (c) 2019 Ansible Project # Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) -from __future__ import annotations +from __future__ import
annotations as _annotations -import types # pylint: disable=unused-import -import warnings +import warnings as _warnings -from ansible.module_utils.common.warnings import deprecate +from ansible.module_utils.common.warnings import deprecate as _deprecate -PARAMIKO_IMPORT_ERR = None +_PARAMIKO_IMPORT_ERR = None try: - with warnings.catch_warnings(): + with _warnings.catch_warnings(): # Blowfish has been moved, but the deprecated import is used by paramiko versions older than 2.9.5. # See: https://github.com/paramiko/paramiko/pull/2039 - warnings.filterwarnings('ignore', message='Blowfish has been ', category=UserWarning) + _warnings.filterwarnings('ignore', message='Blowfish has been ', category=UserWarning) # TripleDES has been moved, but the deprecated import is used by paramiko versions older than 3.3.2 and 3.4.1. # See: https://github.com/paramiko/paramiko/pull/2421 - warnings.filterwarnings('ignore', message='TripleDES has been ', category=UserWarning) - import paramiko # pylint: disable=unused-import + _warnings.filterwarnings('ignore', message='TripleDES has been ', category=UserWarning) + import paramiko as _paramiko # paramiko and gssapi are incompatible and raise AttributeError not ImportError # When running in FIPS mode, cryptography raises InternalError # https://bugzilla.redhat.com/show_bug.cgi?id=1778939 except Exception as err: - paramiko = None # type: types.ModuleType | None # type: ignore[no-redef] - PARAMIKO_IMPORT_ERR = err + _paramiko = None # type: ignore[no-redef] + _PARAMIKO_IMPORT_ERR = err -deprecate('The paramiko compat import is deprecated', version='2.21') + +def __getattr__(name: str) -> object: + """Dynamic lookup to issue deprecation warnings for external import of deprecated items.""" + if (res := globals().get(f'_{name}', ...)) is not ...: + _deprecate(f'The {name!r} compat import is deprecated.', version='2.21') + + return res + + raise AttributeError(name) diff --git a/lib/ansible/module_utils/compat/typing.py b/lib/ansible/module_utils/compat/typing.py index d753f72b25e..af118bc723e 100644 --- a/lib/ansible/module_utils/compat/typing.py +++ b/lib/ansible/module_utils/compat/typing.py @@ -6,6 +6,8 @@ from __future__ import annotations # catch *all* exceptions to prevent type annotation support module bugs causing runtime failures # (eg, https://github.com/ansible/ansible/issues/77857) +TYPE_CHECKING = False + try: from typing_extensions import * except Exception: # pylint: disable=broad-except @@ -17,8 +19,7 @@ except Exception: # pylint: disable=broad-except pass -try: - cast # type: ignore[used-before-def] -except NameError: - def cast(typ, val): # type: ignore[no-redef] - return val +# this import and patch occur after typing_extensions/typing imports since the presence of those modules affects dataclasses behavior +from .._internal._patches import _dataclass_annotation_patch + +_dataclass_annotation_patch.DataclassesIsTypePatch.patch() diff --git a/lib/ansible/module_utils/connection.py b/lib/ansible/module_utils/connection.py index b6720125855..19b38b73815 100644 --- a/lib/ansible/module_utils/connection.py +++ b/lib/ansible/module_utils/connection.py @@ -38,7 +38,7 @@ import uuid from functools import partial from ansible.module_utils.common.text.converters import to_bytes, to_text -from ansible.module_utils.common.json import AnsibleJSONEncoder +from ansible.module_utils.common.json import _get_legacy_encoder from ansible.module_utils.six import iteritems @@ -127,7 +127,7 @@ class Connection(object): ) try: - data = json.dumps(req, 
cls=AnsibleJSONEncoder, vault_to_text=True) + data = json.dumps(req, cls=_get_legacy_encoder(), vault_to_text=True) except TypeError as exc: raise ConnectionError( "Failed to encode some variables as JSON for communication with the persistent connection helper. " diff --git a/lib/ansible/module_utils/csharp/Ansible.Basic.cs b/lib/ansible/module_utils/csharp/Ansible.Basic.cs index 5e4d7e5f6b9..7c0cc81e3c5 100644 --- a/lib/ansible/module_utils/csharp/Ansible.Basic.cs +++ b/lib/ansible/module_utils/csharp/Ansible.Basic.cs @@ -79,6 +79,7 @@ namespace Ansible.Basic { "socket", null }, { "syslog_facility", null }, { "target_log_info", "TargetLogInfo"}, + { "tracebacks_for", null}, { "tmpdir", "tmpdir" }, { "verbosity", "Verbosity" }, { "version", "AnsibleVersion" }, diff --git a/lib/ansible/module_utils/datatag.py b/lib/ansible/module_utils/datatag.py new file mode 100644 index 00000000000..0e182e3d042 --- /dev/null +++ b/lib/ansible/module_utils/datatag.py @@ -0,0 +1,46 @@ +"""Public API for data tagging.""" +from __future__ import annotations as _annotations + +import datetime as _datetime +import typing as _t + +from ._internal import _plugin_exec_context, _datatag +from ._internal._datatag import _tags + +_T = _t.TypeVar('_T') + + +def deprecate_value( + value: _T, + msg: str, + *, + help_text: str | None = None, + removal_date: str | _datetime.date | None = None, + removal_version: str | None = None, +) -> _T: + """ + Return `value` tagged with the given deprecation details. + The types `None` and `bool` cannot be deprecated and are returned unmodified. + Raises a `TypeError` if `value` is not a supported type. + If `removal_date` is a string, it must be in the form `YYYY-MM-DD`. + This function is only supported in contexts where an Ansible plugin/module is executing. + """ + if isinstance(removal_date, str): + # The `fromisoformat` method accepts other ISO 8601 formats than `YYYY-MM-DD` starting with Python 3.11. + # That should be considered undocumented behavior of `deprecate_value` rather than an intentional feature. 
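+        # e.g. (illustrative, not part of this patch): deprecate_value(port, "use 'ports' instead", removal_date="2026-01-01")
+        # passes a string here, which the conversion below normalizes to datetime.date before tagging.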
+ removal_date = _datetime.date.fromisoformat(removal_date) + + deprecated = _tags.Deprecated( + msg=msg, + help_text=help_text, + removal_date=removal_date, + removal_version=removal_version, + plugin=_plugin_exec_context.PluginExecContext.get_current_plugin_info(), + ) + + return deprecated.tag(value) + + +def native_type_name(value: object | type, /) -> str: + """Return the type name of `value`, substituting the native Python type name for internal tagged types.""" + return _datatag.AnsibleTagHelper.base_type(value).__name__ diff --git a/lib/ansible/module_utils/facts/ansible_collector.py b/lib/ansible/module_utils/facts/ansible_collector.py index 5b66f0a0eb3..82b6e16746b 100644 --- a/lib/ansible/module_utils/facts/ansible_collector.py +++ b/lib/ansible/module_utils/facts/ansible_collector.py @@ -30,8 +30,7 @@ from __future__ import annotations import fnmatch import sys - -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts import timeout from ansible.module_utils.facts import collector diff --git a/lib/ansible/module_utils/facts/collector.py b/lib/ansible/module_utils/facts/collector.py index f3e144f7dda..6e5591f7de1 100644 --- a/lib/ansible/module_utils/facts/collector.py +++ b/lib/ansible/module_utils/facts/collector.py @@ -31,8 +31,7 @@ from __future__ import annotations from collections import defaultdict import platform - -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts import timeout diff --git a/lib/ansible/module_utils/facts/default_collectors.py b/lib/ansible/module_utils/facts/default_collectors.py index af4391576c0..a1a92431919 100644 --- a/lib/ansible/module_utils/facts/default_collectors.py +++ b/lib/ansible/module_utils/facts/default_collectors.py @@ -27,7 +27,7 @@ # from __future__ import annotations -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts.collector import BaseFactCollector diff --git a/lib/ansible/module_utils/facts/hardware/base.py b/lib/ansible/module_utils/facts/hardware/base.py index 8710ed57fcc..75d6903924c 100644 --- a/lib/ansible/module_utils/facts/hardware/base.py +++ b/lib/ansible/module_utils/facts/hardware/base.py @@ -28,7 +28,7 @@ from __future__ import annotations -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts.collector import BaseFactCollector diff --git a/lib/ansible/module_utils/facts/network/base.py b/lib/ansible/module_utils/facts/network/base.py index 7e13e168b32..ae6f215735b 100644 --- a/lib/ansible/module_utils/facts/network/base.py +++ b/lib/ansible/module_utils/facts/network/base.py @@ -15,7 +15,7 @@ from __future__ import annotations -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts.collector import BaseFactCollector diff --git a/lib/ansible/module_utils/facts/network/fc_wwn.py b/lib/ansible/module_utils/facts/network/fc_wwn.py index fb846cc08a8..58f59806f1f 100644 --- a/lib/ansible/module_utils/facts/network/fc_wwn.py +++ b/lib/ansible/module_utils/facts/network/fc_wwn.py @@ -19,8 +19,7 @@ from __future__ import annotations import sys import glob - -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts.utils import get_file_lines from ansible.module_utils.facts.collector import BaseFactCollector diff --git a/lib/ansible/module_utils/facts/network/iscsi.py b/lib/ansible/module_utils/facts/network/iscsi.py index 48f98a682bd..1ac48206055 100644 --- 
a/lib/ansible/module_utils/facts/network/iscsi.py +++ b/lib/ansible/module_utils/facts/network/iscsi.py @@ -18,8 +18,7 @@ from __future__ import annotations import sys - -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts.utils import get_file_content from ansible.module_utils.facts.network.base import NetworkCollector diff --git a/lib/ansible/module_utils/facts/network/nvme.py b/lib/ansible/module_utils/facts/network/nvme.py index 7eb070dcf5d..192f6f5275b 100644 --- a/lib/ansible/module_utils/facts/network/nvme.py +++ b/lib/ansible/module_utils/facts/network/nvme.py @@ -18,8 +18,7 @@ from __future__ import annotations import sys - -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts.utils import get_file_content from ansible.module_utils.facts.network.base import NetworkCollector diff --git a/lib/ansible/module_utils/facts/other/facter.py b/lib/ansible/module_utils/facts/other/facter.py index 41b3cea7c92..f050e2ca605 100644 --- a/lib/ansible/module_utils/facts/other/facter.py +++ b/lib/ansible/module_utils/facts/other/facter.py @@ -4,8 +4,7 @@ from __future__ import annotations import json - -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts.namespace import PrefixFactNamespace from ansible.module_utils.facts.collector import BaseFactCollector diff --git a/lib/ansible/module_utils/facts/other/ohai.py b/lib/ansible/module_utils/facts/other/ohai.py index db62fe4d73e..4cb2f7a2f0b 100644 --- a/lib/ansible/module_utils/facts/other/ohai.py +++ b/lib/ansible/module_utils/facts/other/ohai.py @@ -16,8 +16,7 @@ from __future__ import annotations import json - -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts.namespace import PrefixFactNamespace diff --git a/lib/ansible/module_utils/facts/system/apparmor.py b/lib/ansible/module_utils/facts/system/apparmor.py index ec29e883e09..d0ead37d34d 100644 --- a/lib/ansible/module_utils/facts/system/apparmor.py +++ b/lib/ansible/module_utils/facts/system/apparmor.py @@ -18,8 +18,7 @@ from __future__ import annotations import os - -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts.collector import BaseFactCollector diff --git a/lib/ansible/module_utils/facts/system/caps.py b/lib/ansible/module_utils/facts/system/caps.py index 365a04592ac..decd754233b 100644 --- a/lib/ansible/module_utils/facts/system/caps.py +++ b/lib/ansible/module_utils/facts/system/caps.py @@ -17,7 +17,7 @@ from __future__ import annotations -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts.collector import BaseFactCollector diff --git a/lib/ansible/module_utils/facts/system/chroot.py b/lib/ansible/module_utils/facts/system/chroot.py index bbf4b39dd3e..85c7a4288c5 100644 --- a/lib/ansible/module_utils/facts/system/chroot.py +++ b/lib/ansible/module_utils/facts/system/chroot.py @@ -3,8 +3,7 @@ from __future__ import annotations import os - -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts.collector import BaseFactCollector diff --git a/lib/ansible/module_utils/facts/system/cmdline.py b/lib/ansible/module_utils/facts/system/cmdline.py index 12376dc0ba1..dc4b8d08256 100644 --- a/lib/ansible/module_utils/facts/system/cmdline.py +++ b/lib/ansible/module_utils/facts/system/cmdline.py @@ -16,8 +16,7 @@ from __future__ import annotations import shlex - -import 
ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts.utils import get_file_content diff --git a/lib/ansible/module_utils/facts/system/date_time.py b/lib/ansible/module_utils/facts/system/date_time.py index 1cef95077be..21b97bce773 100644 --- a/lib/ansible/module_utils/facts/system/date_time.py +++ b/lib/ansible/module_utils/facts/system/date_time.py @@ -19,8 +19,8 @@ from __future__ import annotations import datetime import time +import typing as t -import ansible.module_utils.compat.typing as t from ansible.module_utils.facts.collector import BaseFactCollector diff --git a/lib/ansible/module_utils/facts/system/distribution.py b/lib/ansible/module_utils/facts/system/distribution.py index bd9dacd438f..fff2bce4cf1 100644 --- a/lib/ansible/module_utils/facts/system/distribution.py +++ b/lib/ansible/module_utils/facts/system/distribution.py @@ -8,8 +8,7 @@ from __future__ import annotations import os import platform import re - -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.common.sys_info import get_distribution, get_distribution_version, \ get_distribution_codename @@ -208,7 +207,7 @@ class DistributionFiles: return dist_file_facts - # TODO: FIXME: split distro file parsing into its own module or class + # FIXME: split distro file parsing into its own module or class def parse_distribution_file_Slackware(self, name, data, path, collected_facts): slackware_facts = {} if 'Slackware' not in data: diff --git a/lib/ansible/module_utils/facts/system/dns.py b/lib/ansible/module_utils/facts/system/dns.py index 7ef69d136fc..5da8e5ba351 100644 --- a/lib/ansible/module_utils/facts/system/dns.py +++ b/lib/ansible/module_utils/facts/system/dns.py @@ -15,7 +15,7 @@ from __future__ import annotations -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts.utils import get_file_content diff --git a/lib/ansible/module_utils/facts/system/env.py b/lib/ansible/module_utils/facts/system/env.py index 4547924532e..cf6a22457a9 100644 --- a/lib/ansible/module_utils/facts/system/env.py +++ b/lib/ansible/module_utils/facts/system/env.py @@ -16,8 +16,7 @@ from __future__ import annotations import os - -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.six import iteritems diff --git a/lib/ansible/module_utils/facts/system/fips.py b/lib/ansible/module_utils/facts/system/fips.py index 131434157d4..36b0a37f0c7 100644 --- a/lib/ansible/module_utils/facts/system/fips.py +++ b/lib/ansible/module_utils/facts/system/fips.py @@ -4,7 +4,7 @@ from __future__ import annotations -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts.utils import get_file_content diff --git a/lib/ansible/module_utils/facts/system/loadavg.py b/lib/ansible/module_utils/facts/system/loadavg.py index 37cb554434f..3433c06ee34 100644 --- a/lib/ansible/module_utils/facts/system/loadavg.py +++ b/lib/ansible/module_utils/facts/system/loadavg.py @@ -4,8 +4,7 @@ from __future__ import annotations import os - -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts.collector import BaseFactCollector diff --git a/lib/ansible/module_utils/facts/system/local.py b/lib/ansible/module_utils/facts/system/local.py index 66ec58a2e7d..09d0e18a6d0 100644 --- a/lib/ansible/module_utils/facts/system/local.py +++ b/lib/ansible/module_utils/facts/system/local.py @@ -7,8 +7,7 @@ import glob import json import os import 
stat - -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.common.text.converters import to_text from ansible.module_utils.facts.utils import get_file_content diff --git a/lib/ansible/module_utils/facts/system/lsb.py b/lib/ansible/module_utils/facts/system/lsb.py index 5767536b1d7..93251c31087 100644 --- a/lib/ansible/module_utils/facts/system/lsb.py +++ b/lib/ansible/module_utils/facts/system/lsb.py @@ -18,8 +18,7 @@ from __future__ import annotations import os - -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts.utils import get_file_lines from ansible.module_utils.facts.collector import BaseFactCollector diff --git a/lib/ansible/module_utils/facts/system/pkg_mgr.py b/lib/ansible/module_utils/facts/system/pkg_mgr.py index e9da18647b8..baa07076b8a 100644 --- a/lib/ansible/module_utils/facts/system/pkg_mgr.py +++ b/lib/ansible/module_utils/facts/system/pkg_mgr.py @@ -6,8 +6,7 @@ from __future__ import annotations import os import subprocess - -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts.collector import BaseFactCollector diff --git a/lib/ansible/module_utils/facts/system/platform.py b/lib/ansible/module_utils/facts/system/platform.py index 94819861b4b..cd9f11cdb37 100644 --- a/lib/ansible/module_utils/facts/system/platform.py +++ b/lib/ansible/module_utils/facts/system/platform.py @@ -18,8 +18,7 @@ from __future__ import annotations import re import socket import platform - -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts.utils import get_file_content diff --git a/lib/ansible/module_utils/facts/system/python.py b/lib/ansible/module_utils/facts/system/python.py index 0252c0c96a7..b75d32974e6 100644 --- a/lib/ansible/module_utils/facts/system/python.py +++ b/lib/ansible/module_utils/facts/system/python.py @@ -16,8 +16,7 @@ from __future__ import annotations import sys - -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts.collector import BaseFactCollector diff --git a/lib/ansible/module_utils/facts/system/selinux.py b/lib/ansible/module_utils/facts/system/selinux.py index c110f17e720..1e5ea81ac78 100644 --- a/lib/ansible/module_utils/facts/system/selinux.py +++ b/lib/ansible/module_utils/facts/system/selinux.py @@ -17,7 +17,7 @@ from __future__ import annotations -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts.collector import BaseFactCollector diff --git a/lib/ansible/module_utils/facts/system/service_mgr.py b/lib/ansible/module_utils/facts/system/service_mgr.py index 20257967c1e..ba798e09dfb 100644 --- a/lib/ansible/module_utils/facts/system/service_mgr.py +++ b/lib/ansible/module_utils/facts/system/service_mgr.py @@ -20,8 +20,7 @@ from __future__ import annotations import os import platform import re - -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.common.text.converters import to_native diff --git a/lib/ansible/module_utils/facts/system/ssh_pub_keys.py b/lib/ansible/module_utils/facts/system/ssh_pub_keys.py index 7214dea3de6..295ea135b11 100644 --- a/lib/ansible/module_utils/facts/system/ssh_pub_keys.py +++ b/lib/ansible/module_utils/facts/system/ssh_pub_keys.py @@ -15,7 +15,7 @@ from __future__ import annotations -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts.utils import get_file_content diff 
--git a/lib/ansible/module_utils/facts/system/systemd.py b/lib/ansible/module_utils/facts/system/systemd.py index 3ba2bbfcbdf..cb6f4c7931d 100644 --- a/lib/ansible/module_utils/facts/system/systemd.py +++ b/lib/ansible/module_utils/facts/system/systemd.py @@ -17,7 +17,7 @@ from __future__ import annotations -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts.collector import BaseFactCollector from ansible.module_utils.facts.system.service_mgr import ServiceMgrFactCollector diff --git a/lib/ansible/module_utils/facts/system/user.py b/lib/ansible/module_utils/facts/system/user.py index 64b8fef8be6..cbfd37348eb 100644 --- a/lib/ansible/module_utils/facts/system/user.py +++ b/lib/ansible/module_utils/facts/system/user.py @@ -18,8 +18,7 @@ from __future__ import annotations import getpass import os import pwd - -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts.collector import BaseFactCollector diff --git a/lib/ansible/module_utils/facts/virtual/base.py b/lib/ansible/module_utils/facts/virtual/base.py index 943ce406d86..f03e2289180 100644 --- a/lib/ansible/module_utils/facts/virtual/base.py +++ b/lib/ansible/module_utils/facts/virtual/base.py @@ -18,7 +18,7 @@ from __future__ import annotations -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.facts.collector import BaseFactCollector diff --git a/lib/ansible/module_utils/parsing/convert_bool.py b/lib/ansible/module_utils/parsing/convert_bool.py index 3367b2a09fa..594ede436f2 100644 --- a/lib/ansible/module_utils/parsing/convert_bool.py +++ b/lib/ansible/module_utils/parsing/convert_bool.py @@ -25,4 +25,4 @@ def boolean(value, strict=True): elif normalized_value in BOOLEANS_FALSE or not strict: return False - raise TypeError("The value '%s' is not a valid boolean. Valid booleans include: %s" % (to_text(value), ', '.join(repr(i) for i in BOOLEANS))) + raise TypeError("The value '%s' is not a valid boolean. Valid booleans include: %s" % (to_text(value), ', '.join(repr(i) for i in BOOLEANS))) diff --git a/lib/ansible/module_utils/service.py b/lib/ansible/module_utils/service.py index 6d3ecea4b8d..06ae8392a83 100644 --- a/lib/ansible/module_utils/service.py +++ b/lib/ansible/module_utils/service.py @@ -35,7 +35,6 @@ import platform import select import shlex import subprocess -import traceback from ansible.module_utils.six import PY2, b from ansible.module_utils.common.text.converters import to_bytes, to_text @@ -180,7 +179,9 @@ def daemonize(module, cmd): pipe = os.pipe() pid = fork_process() except (OSError, RuntimeError): - module.fail_json(msg="Error while attempting to fork: %s", exception=traceback.format_exc()) + module.fail_json(msg="Error while attempting to fork.") + except Exception as exc: + module.fail_json(msg=to_text(exc)) # we don't do any locking as this should be a unique module/process if pid == 0: diff --git a/lib/ansible/module_utils/testing.py b/lib/ansible/module_utils/testing.py new file mode 100644 index 00000000000..4f2ed9435a7 --- /dev/null +++ b/lib/ansible/module_utils/testing.py @@ -0,0 +1,31 @@ +""" +Utilities to support unit testing of Ansible Python modules. +Not supported for use cases other than testing. +""" + +from __future__ import annotations as _annotations + +import contextlib as _contextlib +import json as _json +import typing as _t + +from unittest import mock as _mock + +from ansible.module_utils.common import json as _common_json +from . 
import basic as _basic + + +@_contextlib.contextmanager +def patch_module_args(args: dict[str, _t.Any] | None = None) -> _t.Iterator[None]: + """Expose the given module args to `AnsibleModule` instances created within this context.""" + if not isinstance(args, (dict, type(None))): + raise TypeError("The `args` arg must be a dict or None.") + + args = dict(ANSIBLE_MODULE_ARGS=args or {}) + profile = 'legacy' # this should be configurable in the future, once the profile feature is more fully baked + + encoder = _common_json.get_module_encoder(profile, _common_json.Direction.CONTROLLER_TO_MODULE) + args = _json.dumps(args, cls=encoder).encode() + + with _mock.patch.object(_basic, '_ANSIBLE_ARGS', args), _mock.patch.object(_basic, '_ANSIBLE_PROFILE', profile): + yield diff --git a/lib/ansible/module_utils/urls.py b/lib/ansible/module_utils/urls.py index 09ea835d720..423f077104d 100644 --- a/lib/ansible/module_utils/urls.py +++ b/lib/ansible/module_utils/urls.py @@ -1198,7 +1198,7 @@ def fetch_url(module, url, data=None, headers=None, method=None, data={...} resp, info = fetch_url(module, "http://example.com", - data=module.jsonify(data), + data=json.dumps(data), headers={'Content-type': 'application/json'}, method="POST") status_code = info["status"] @@ -1276,7 +1276,7 @@ def fetch_url(module, url, data=None, headers=None, method=None, except (ConnectionError, ValueError) as e: module.fail_json(msg=to_native(e), **info) except MissingModuleError as e: - module.fail_json(msg=to_text(e), exception=e.import_traceback) + module.fail_json(msg=to_text(e)) except urllib.error.HTTPError as e: r = e try: @@ -1307,9 +1307,8 @@ def fetch_url(module, url, data=None, headers=None, method=None, info.update(dict(msg="Connection failure: %s" % to_native(e), status=-1)) except http.client.BadStatusLine as e: info.update(dict(msg="Connection failure: connection was closed before a valid response was received: %s" % to_native(e.line), status=-1)) - except Exception as e: - info.update(dict(msg="An unknown error occurred: %s" % to_native(e), status=-1), - exception=traceback.format_exc()) + except Exception as ex: + info.update(dict(msg="An unknown error occurred: %s" % to_native(ex), status=-1, exception=traceback.format_exc())) finally: tempfile.tempdir = old_tempdir diff --git a/lib/ansible/modules/apt_key.py b/lib/ansible/modules/apt_key.py index 03484c5f091..06648041e32 100644 --- a/lib/ansible/modules/apt_key.py +++ b/lib/ansible/modules/apt_key.py @@ -172,8 +172,6 @@ short_id: import os -from traceback import format_exc - from ansible.module_utils.common.text.converters import to_native from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.common.locale import get_best_parsable_locale @@ -319,7 +317,7 @@ def download_key(module, url): return rsp.read() except Exception: - module.fail_json(msg="error getting key id from url: %s" % url, traceback=format_exc()) + module.fail_json(msg=f"Error getting key id from url: {url}") def get_key_id_from_file(module, filename, data=None): diff --git a/lib/ansible/modules/async_status.py b/lib/ansible/modules/async_status.py index 0a4eeb53ac2..6459f10c02f 100644 --- a/lib/ansible/modules/async_status.py +++ b/lib/ansible/modules/async_status.py @@ -5,6 +5,7 @@ from __future__ import annotations +import sys DOCUMENTATION = r""" --- @@ -111,8 +112,6 @@ import json import os from ansible.module_utils.basic import AnsibleModule -from ansible.module_utils.six import iteritems -from ansible.module_utils.common.text.converters import to_native def 
main(): @@ -163,10 +162,9 @@ def main(): elif 'finished' not in data: data['finished'] = 0 - # Fix error: TypeError: exit_json() keywords must be strings - data = {to_native(k): v for k, v in iteritems(data)} - - module.exit_json(**data) + # just write the module output directly to stdout and exit; bypass other processing done by exit_json since it's already been done + print(f"\n{json.dumps(data)}") # pylint: disable=ansible-bad-function + sys.exit(0) # pylint: disable=ansible-bad-function if __name__ == '__main__': diff --git a/lib/ansible/modules/async_wrapper.py b/lib/ansible/modules/async_wrapper.py index d33ebe196ed..7c2fb257f38 100644 --- a/lib/ansible/modules/async_wrapper.py +++ b/lib/ansible/modules/async_wrapper.py @@ -147,6 +147,8 @@ def jwrite(info): def _run_module(wrapped_cmd, jid): + # DTFIX-FUTURE: needs rework for serialization profiles + jwrite({"started": 1, "finished": 0, "ansible_job_id": jid}) result = {} @@ -188,6 +190,9 @@ def _run_module(wrapped_cmd, jid): module_warnings = result.get('warnings', []) if not isinstance(module_warnings, list): module_warnings = [module_warnings] + + # this relies on the controller's fallback conversion of string warnings to WarningMessageDetail instances, and assumes + # that the module result and warning collection are basic JSON datatypes (eg, no tags or other custom collections). module_warnings.extend(json_warnings) result['warnings'] = module_warnings @@ -257,7 +262,7 @@ def main(): end({ "failed": 1, "msg": "could not create directory: %s - %s" % (jobdir, to_text(e)), - "exception": to_text(traceback.format_exc()), + "exception": to_text(traceback.format_exc()), # NB: task executor compat will coerce to the correct dataclass type }, 1) # immediately exit this process, leaving an orphaned process diff --git a/lib/ansible/modules/command.py b/lib/ansible/modules/command.py index ed71342ab6b..fa2415d73d2 100644 --- a/lib/ansible/modules/command.py +++ b/lib/ansible/modules/command.py @@ -249,6 +249,7 @@ def main(): argument_spec=dict( _raw_params=dict(), _uses_shell=dict(type='bool', default=False), + cmd=dict(), argv=dict(type='list', elements='str'), chdir=dict(type='path'), executable=dict(), @@ -260,12 +261,14 @@ def main(): stdin_add_newline=dict(type='bool', default=True), strip_empty_ends=dict(type='bool', default=True), ), + required_one_of=[['_raw_params', 'cmd', 'argv']], + mutually_exclusive=[['_raw_params', 'cmd', 'argv']], supports_check_mode=True, ) shell = module.params['_uses_shell'] chdir = module.params['chdir'] executable = module.params['executable'] - args = module.params['_raw_params'] + args = module.params['_raw_params'] or module.params['cmd'] argv = module.params['argv'] creates = module.params['creates'] removes = module.params['removes'] @@ -281,16 +284,6 @@ def main(): module.warn("As of Ansible 2.4, the parameter 'executable' is no longer supported with the 'command' module. Not using '%s'." 
% executable) executable = None - if (not args or args.strip() == '') and not argv: - r['rc'] = 256 - r['msg'] = "no command given" - module.fail_json(**r) - - if args and argv: - r['rc'] = 256 - r['msg'] = "only command or argv can be given, not both" - module.fail_json(**r) - if not shell and args: args = shlex.split(args) diff --git a/lib/ansible/modules/copy.py b/lib/ansible/modules/copy.py index fc904ae2768..0e052f76f18 100644 --- a/lib/ansible/modules/copy.py +++ b/lib/ansible/modules/copy.py @@ -291,7 +291,6 @@ import os.path import shutil import stat import tempfile -import traceback from ansible.module_utils.common.text.converters import to_bytes, to_native from ansible.module_utils.basic import AnsibleModule @@ -638,7 +637,7 @@ def main(): module.atomic_move(b_mysrc, dest, unsafe_writes=module.params['unsafe_writes'], keep_dest_attrs=not remote_src) except (IOError, OSError): - module.fail_json(msg="failed to copy: %s to %s" % (src, dest), traceback=traceback.format_exc()) + module.fail_json(msg=f"Failed to copy {src!r} to {dest!r}.") changed = True # If neither have checksums, both src and dest are directories. diff --git a/lib/ansible/modules/cron.py b/lib/ansible/modules/cron.py index 7ee12fe8f82..8abfca172fa 100644 --- a/lib/ansible/modules/cron.py +++ b/lib/ansible/modules/cron.py @@ -277,7 +277,7 @@ class CronTab(object): except Exception: raise CronTabError("Unexpected error:", sys.exc_info()[0]) else: - # using safely quoted shell for now, but this really should be two non-shell calls instead. FIXME + # FIXME: using safely quoted shell for now, but this really should be two non-shell calls instead. (rc, out, err) = self.module.run_command(self._read_user_execute(), use_unsafe_shell=True) if rc != 0 and rc != 1: # 1 can mean that there are no jobs. @@ -328,7 +328,7 @@ class CronTab(object): # Add the entire crontab back to the user crontab if not self.cron_file: - # quoting shell args for now but really this should be two non-shell calls. FIXME + # FIXME: quoting shell args for now but really this should be two non-shell calls. (rc, out, err) = self.module.run_command(self._write_execute(path), use_unsafe_shell=True) os.unlink(path) diff --git a/lib/ansible/modules/deb822_repository.py b/lib/ansible/modules/deb822_repository.py index a27af10786c..d4d6205511e 100644 --- a/lib/ansible/modules/deb822_repository.py +++ b/lib/ansible/modules/deb822_repository.py @@ -230,7 +230,6 @@ import os import re import tempfile import textwrap -import traceback from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import missing_required_lib @@ -248,9 +247,9 @@ HAS_DEBIAN = True DEBIAN_IMP_ERR = None try: from debian.deb822 import Deb822 # type: ignore[import] -except ImportError: +except ImportError as ex: HAS_DEBIAN = False - DEBIAN_IMP_ERR = traceback.format_exc() + DEBIAN_IMP_ERR = ex KEYRINGS_DIR = '/etc/apt/keyrings' diff --git a/lib/ansible/modules/dnf.py b/lib/ansible/modules/dnf.py index 7ab874a941f..07f0384b5c9 100644 --- a/lib/ansible/modules/dnf.py +++ b/lib/ansible/modules/dnf.py @@ -408,10 +408,10 @@ from ansible.module_utils.common.respawn import has_respawned, probe_interpreter from ansible.module_utils.yumdnf import YumDnf, yumdnf_argument_spec -# NOTE dnf Python bindings import is postponed, see DnfModule._ensure_dnf(), -# because we need AnsibleModule object to use get_best_parsable_locale() -# to set proper locale before importing dnf to be able to scrape -# the output in some cases (FIXME?). 
+# FIXME: NOTE dnf Python bindings import is postponed, see DnfModule._ensure_dnf(), +# because we need AnsibleModule object to use get_best_parsable_locale() +# to set proper locale before importing dnf to be able to scrape +# the output in some cases. dnf = None diff --git a/lib/ansible/modules/expect.py b/lib/ansible/modules/expect.py index 90ece7d76f3..1436e231b57 100644 --- a/lib/ansible/modules/expect.py +++ b/lib/ansible/modules/expect.py @@ -120,14 +120,13 @@ EXAMPLES = r""" import datetime import os -import traceback PEXPECT_IMP_ERR = None try: import pexpect HAS_PEXPECT = True -except ImportError: - PEXPECT_IMP_ERR = traceback.format_exc() +except ImportError as ex: + PEXPECT_IMP_ERR = ex HAS_PEXPECT = False from ansible.module_utils.basic import AnsibleModule, missing_required_lib @@ -164,8 +163,7 @@ def main(): ) if not HAS_PEXPECT: - module.fail_json(msg=missing_required_lib("pexpect"), - exception=PEXPECT_IMP_ERR) + module.fail_json(msg=missing_required_lib("pexpect"), exception=PEXPECT_IMP_ERR) chdir = module.params['chdir'] args = module.params['command'] @@ -246,7 +244,7 @@ def main(): '(%s), this module requires pexpect>=3.3. ' 'Error was %s' % (pexpect.__version__, to_native(e))) except pexpect.ExceptionPexpect as e: - module.fail_json(msg='%s' % to_native(e), exception=traceback.format_exc()) + module.fail_json(msg='%s' % to_native(e)) endd = datetime.datetime.now() delta = endd - startd diff --git a/lib/ansible/modules/file.py b/lib/ansible/modules/file.py index b79eca58881..62a191de49e 100644 --- a/lib/ansible/modules/file.py +++ b/lib/ansible/modules/file.py @@ -244,7 +244,19 @@ from ansible.module_utils.common.sentinel import Sentinel module = None -def additional_parameter_handling(module): +class AnsibleModuleError(Exception): + def __init__(self, results): + self.results = results + + def __repr__(self): + return 'AnsibleModuleError(results={0})'.format(self.results) + + +class ParameterError(AnsibleModuleError): + pass + + +def additional_parameter_handling(params): """Additional parameter validation and reformatting""" # When path is a directory, rewrite the pathname to be the file inside of the directory # TODO: Why do we exclude link? Why don't we exclude directory? Should we exclude touch? @@ -256,7 +268,6 @@ def additional_parameter_handling(module): # if state == file: place inside of the directory (use _original_basename) # if state == link: place inside of the directory (use _original_basename. Fallback to src?) # if state == hard: place inside of the directory (use _original_basename. Fallback to src?) 
- params = module.params if (params['state'] not in ("link", "absent") and os.path.isdir(to_bytes(params['path'], errors='surrogate_or_strict'))): basename = None @@ -966,46 +977,49 @@ def main(): supports_check_mode=True, ) - additional_parameter_handling(module) - params = module.params - - state = params['state'] - recurse = params['recurse'] - force = params['force'] - follow = params['follow'] - path = params['path'] - src = params['src'] - - if module.check_mode and state != 'absent': - file_args = module.load_file_common_arguments(module.params) - if file_args['owner']: - check_owner_exists(module, file_args['owner']) - if file_args['group']: - check_group_exists(module, file_args['group']) - - timestamps = {} - timestamps['modification_time'] = keep_backward_compatibility_on_timestamps(params['modification_time'], state) - timestamps['modification_time_format'] = params['modification_time_format'] - timestamps['access_time'] = keep_backward_compatibility_on_timestamps(params['access_time'], state) - timestamps['access_time_format'] = params['access_time_format'] - - # short-circuit for diff_peek - if params['_diff_peek'] is not None: - appears_binary = execute_diff_peek(to_bytes(path, errors='surrogate_or_strict')) - module.exit_json(path=path, changed=False, appears_binary=appears_binary) - - if state == 'file': - result = ensure_file_attributes(path, follow, timestamps) - elif state == 'directory': - result = ensure_directory(path, follow, recurse, timestamps) - elif state == 'link': - result = ensure_symlink(path, src, follow, force, timestamps) - elif state == 'hard': - result = ensure_hardlink(path, src, follow, force, timestamps) - elif state == 'touch': - result = execute_touch(path, follow, timestamps) - elif state == 'absent': - result = ensure_absent(path) + try: + additional_parameter_handling(module.params) + params = module.params + + state = params['state'] + recurse = params['recurse'] + force = params['force'] + follow = params['follow'] + path = params['path'] + src = params['src'] + + if module.check_mode and state != 'absent': + file_args = module.load_file_common_arguments(module.params) + if file_args['owner']: + check_owner_exists(module, file_args['owner']) + if file_args['group']: + check_group_exists(module, file_args['group']) + + timestamps = {} + timestamps['modification_time'] = keep_backward_compatibility_on_timestamps(params['modification_time'], state) + timestamps['modification_time_format'] = params['modification_time_format'] + timestamps['access_time'] = keep_backward_compatibility_on_timestamps(params['access_time'], state) + timestamps['access_time_format'] = params['access_time_format'] + + # short-circuit for diff_peek + if params['_diff_peek'] is not None: + appears_binary = execute_diff_peek(to_bytes(path, errors='surrogate_or_strict')) + module.exit_json(path=path, changed=False, appears_binary=appears_binary) + + if state == 'file': + result = ensure_file_attributes(path, follow, timestamps) + elif state == 'directory': + result = ensure_directory(path, follow, recurse, timestamps) + elif state == 'link': + result = ensure_symlink(path, src, follow, force, timestamps) + elif state == 'hard': + result = ensure_hardlink(path, src, follow, force, timestamps) + elif state == 'touch': + result = execute_touch(path, follow, timestamps) + elif state == 'absent': + result = ensure_absent(path) + except AnsibleModuleError as ex: + module.fail_json(**ex.results) if not module._diff: result.pop('diff', None) diff --git 
a/lib/ansible/modules/get_url.py b/lib/ansible/modules/get_url.py index a794a609346..f742c363349 100644 --- a/lib/ansible/modules/get_url.py +++ b/lib/ansible/modules/get_url.py @@ -372,7 +372,7 @@ import os import re import shutil import tempfile -import traceback + from datetime import datetime, timezone from ansible.module_utils.basic import AnsibleModule @@ -433,7 +433,7 @@ def url_get(module, url, dest, use_proxy, last_mod_time, force, timeout=10, head shutil.copyfileobj(rsp, f) except Exception as e: os.remove(tempname) - module.fail_json(msg="failed to create temporary content file: %s" % to_native(e), elapsed=elapsed, exception=traceback.format_exc()) + module.fail_json(msg="failed to create temporary content file: %s" % to_native(e), elapsed=elapsed) f.close() rsp.close() return tempname, info @@ -690,8 +690,7 @@ def main(): except Exception as e: if os.path.exists(tmpsrc): os.remove(tmpsrc) - module.fail_json(msg="failed to copy %s to %s: %s" % (tmpsrc, dest, to_native(e)), - exception=traceback.format_exc(), **result) + module.fail_json(msg="failed to copy %s to %s: %s" % (tmpsrc, dest, to_native(e)), **result) result['changed'] = True else: result['changed'] = False diff --git a/lib/ansible/modules/getent.py b/lib/ansible/modules/getent.py index 1938af1fcfa..e195b7ef7ea 100644 --- a/lib/ansible/modules/getent.py +++ b/lib/ansible/modules/getent.py @@ -114,8 +114,6 @@ ansible_facts: type: list """ -import traceback - from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.common.text.converters import to_native @@ -156,7 +154,7 @@ def main(): try: rc, out, err = module.run_command(cmd) except Exception as e: - module.fail_json(msg=to_native(e), exception=traceback.format_exc()) + module.fail_json(msg=to_native(e)) msg = "Unexpected failure!" 
dbtree = 'getent_%s' % database diff --git a/lib/ansible/modules/hostname.py b/lib/ansible/modules/hostname.py index 79f9bcb0709..63bbea4a7ce 100644 --- a/lib/ansible/modules/hostname.py +++ b/lib/ansible/modules/hostname.py @@ -68,9 +68,7 @@ EXAMPLES = """ import os import platform import socket -import traceback - -import ansible.module_utils.compat.typing as t +import typing as t from ansible.module_utils.basic import ( AnsibleModule, @@ -209,17 +207,14 @@ class FileStrategy(BaseStrategy): return get_file_content(self.FILE, default='', strip=True) except Exception as e: self.module.fail_json( - msg="failed to read hostname: %s" % to_native(e), - exception=traceback.format_exc()) + msg="failed to read hostname: %s" % to_native(e)) def set_permanent_hostname(self, name): try: with open(self.FILE, 'w+') as f: f.write("%s\n" % name) except Exception as e: - self.module.fail_json( - msg="failed to update hostname: %s" % to_native(e), - exception=traceback.format_exc()) + self.module.fail_json(msg="failed to update hostname: %s" % to_native(e)) class SLESStrategy(FileStrategy): @@ -249,8 +244,7 @@ class RedHatStrategy(BaseStrategy): ) except Exception as e: self.module.fail_json( - msg="failed to read hostname: %s" % to_native(e), - exception=traceback.format_exc()) + msg="failed to read hostname: %s" % to_native(e)) def set_permanent_hostname(self, name): try: @@ -269,9 +263,7 @@ class RedHatStrategy(BaseStrategy): with open(self.NETWORK_FILE, 'w+') as f: f.writelines(lines) except Exception as e: - self.module.fail_json( - msg="failed to update hostname: %s" % to_native(e), - exception=traceback.format_exc()) + self.module.fail_json(msg="failed to update hostname: %s" % to_native(e)) class AlpineStrategy(FileStrategy): @@ -361,9 +353,7 @@ class OpenRCStrategy(BaseStrategy): if line.startswith('hostname='): return line[10:].strip('"') except Exception as e: - self.module.fail_json( - msg="failed to read hostname: %s" % to_native(e), - exception=traceback.format_exc()) + self.module.fail_json(msg="failed to read hostname: %s" % to_native(e)) def set_permanent_hostname(self, name): try: @@ -377,9 +367,7 @@ class OpenRCStrategy(BaseStrategy): with open(self.FILE, 'w') as f: f.write('\n'.join(lines) + '\n') except Exception as e: - self.module.fail_json( - msg="failed to update hostname: %s" % to_native(e), - exception=traceback.format_exc()) + self.module.fail_json(msg="failed to update hostname: %s" % to_native(e)) class OpenBSDStrategy(FileStrategy): @@ -481,9 +469,7 @@ class FreeBSDStrategy(BaseStrategy): if line.startswith('hostname='): return line[10:].strip('"') except Exception as e: - self.module.fail_json( - msg="failed to read hostname: %s" % to_native(e), - exception=traceback.format_exc()) + self.module.fail_json(msg="failed to read hostname: %s" % to_native(e)) def set_permanent_hostname(self, name): try: @@ -500,9 +486,7 @@ class FreeBSDStrategy(BaseStrategy): with open(self.FILE, 'w') as f: f.write('\n'.join(lines) + '\n') except Exception as e: - self.module.fail_json( - msg="failed to update hostname: %s" % to_native(e), - exception=traceback.format_exc()) + self.module.fail_json(msg="failed to update hostname: %s" % to_native(e)) class DarwinStrategy(BaseStrategy): diff --git a/lib/ansible/modules/pip.py b/lib/ansible/modules/pip.py index 028ef3f6e3b..2d520618f12 100644 --- a/lib/ansible/modules/pip.py +++ b/lib/ansible/modules/pip.py @@ -299,7 +299,6 @@ import sys import tempfile import operator import shlex -import traceback from ansible.module_utils.compat.version import 
LooseVersion @@ -309,10 +308,10 @@ HAS_SETUPTOOLS = False try: from packaging.requirements import Requirement as parse_requirement HAS_PACKAGING = True -except Exception: +except Exception as ex: # This is catching a generic Exception, due to packaging on EL7 raising a TypeError on import HAS_PACKAGING = False - PACKAGING_IMP_ERR = traceback.format_exc() + PACKAGING_IMP_ERR = ex try: from pkg_resources import Requirement parse_requirement = Requirement.parse # type: ignore[misc,assignment] diff --git a/lib/ansible/modules/replace.py b/lib/ansible/modules/replace.py index 61e629b26a0..980c0fbdf4a 100644 --- a/lib/ansible/modules/replace.py +++ b/lib/ansible/modules/replace.py @@ -182,7 +182,6 @@ RETURN = r"""#""" import os import re import tempfile -from traceback import format_exc from ansible.module_utils.common.text.converters import to_text, to_bytes from ansible.module_utils.basic import AnsibleModule @@ -258,8 +257,7 @@ def main(): with open(path, 'rb') as f: contents = to_text(f.read(), errors='surrogate_or_strict', encoding=encoding) except (OSError, IOError) as e: - module.fail_json(msg='Unable to read the contents of %s: %s' % (path, to_text(e)), - exception=format_exc()) + module.fail_json(msg='Unable to read the contents of %s: %s' % (path, to_text(e))) pattern = u'' if params['after'] and params['before']: @@ -286,8 +284,7 @@ def main(): try: result = re.subn(mre, params['replace'], section, 0) except re.error as e: - module.fail_json(msg="Unable to process replace due to error: %s" % to_text(e), - exception=format_exc()) + module.fail_json(msg="Unable to process replace due to error: %s" % to_text(e)) if result[1] > 0 and section != result[0]: if pattern: diff --git a/lib/ansible/modules/set_fact.py b/lib/ansible/modules/set_fact.py index ef4989c44fa..29fef156886 100644 --- a/lib/ansible/modules/set_fact.py +++ b/lib/ansible/modules/set_fact.py @@ -66,7 +66,7 @@ notes: - Because of the nature of tasks, set_fact will produce 'static' values for a variable. Unlike normal 'lazy' variables, the value gets evaluated and templated on assignment. - Some boolean values (yes, no, true, false) will always be converted to boolean type, - unless C(DEFAULT_JINJA2_NATIVE) is enabled. This is done so the C(var=value) booleans, + which is done so that C(var=value) booleans work; otherwise it would only be able to create strings, but it also prevents using those values to create YAML strings. Using the setting will restrict k=v to strings, but will allow you to specify string or boolean in YAML. - "To create lists/arrays or dictionary/hashes use YAML notation C(var: [val1, val2])."
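A recurring pattern in the module hunks above (deb822_repository, expect, pip) is worth calling out: instead of capturing a pre-formatted traceback.format_exc() string at import time, the ImportError instance itself is stored and later handed to fail_json(), letting the controller's new traceback handling decide whether to render it. A minimal, self-contained sketch of the pattern, modeled on the expect.py hunk (the empty argument_spec is only for illustration):

from ansible.module_utils.basic import AnsibleModule, missing_required_lib

PEXPECT_IMP_ERR = None
try:
    import pexpect  # optional third-party dependency
    HAS_PEXPECT = True
except ImportError as ex:
    HAS_PEXPECT = False
    PEXPECT_IMP_ERR = ex  # keep the exception instance, not a formatted traceback string


def main():
    module = AnsibleModule(argument_spec={})

    if not HAS_PEXPECT:
        # the traceback is only included in the result when error tracebacks are configured
        module.fail_json(msg=missing_required_lib('pexpect'), exception=PEXPECT_IMP_ERR)


if __name__ == '__main__':
    main()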
diff --git a/lib/ansible/modules/tempfile.py b/lib/ansible/modules/tempfile.py index a9a8d644300..a7163b02ebf 100644 --- a/lib/ansible/modules/tempfile.py +++ b/lib/ansible/modules/tempfile.py @@ -90,7 +90,6 @@ path: from os import close from tempfile import mkstemp, mkdtemp -from traceback import format_exc from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.common.text.converters import to_native @@ -123,7 +122,7 @@ def main(): module.exit_json(changed=True, path=path) except Exception as e: - module.fail_json(msg=to_native(e), exception=format_exc()) + module.fail_json(msg=to_native(e)) if __name__ == '__main__': diff --git a/lib/ansible/modules/unarchive.py b/lib/ansible/modules/unarchive.py index b317dbc737e..06bf9edc865 100644 --- a/lib/ansible/modules/unarchive.py +++ b/lib/ansible/modules/unarchive.py @@ -250,7 +250,6 @@ import pwd import re import stat import time -import traceback from functools import partial from zipfile import ZipFile @@ -698,7 +697,7 @@ class ZipArchive(object): try: mode = AnsibleModule._symbolic_mode_to_octal(st, self.file_args['mode']) except ValueError as e: - self.module.fail_json(path=path, msg="%s" % to_native(e), exception=traceback.format_exc()) + self.module.fail_json(path=path, msg="%s" % to_native(e)) # Only special files require no umask-handling elif ztype == '?': mode = self._permstr_to_octal(permstr, 0) diff --git a/lib/ansible/modules/user.py b/lib/ansible/modules/user.py index 90ecd04b8d9..ff990b07b5d 100644 --- a/lib/ansible/modules/user.py +++ b/lib/ansible/modules/user.py @@ -503,13 +503,13 @@ import socket import subprocess import time import math +import typing as t from ansible.module_utils import distro from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.common.locale import get_best_parsable_locale from ansible.module_utils.common.sys_info import get_platform_subclass -import ansible.module_utils.compat.typing as t class StructSpwdType(ctypes.Structure): diff --git a/lib/ansible/modules/wait_for.py b/lib/ansible/modules/wait_for.py index 7faff8389a5..49bc7cde63c 100644 --- a/lib/ansible/modules/wait_for.py +++ b/lib/ansible/modules/wait_for.py @@ -234,7 +234,7 @@ import re import select import socket import time -import traceback + from datetime import datetime, timedelta, timezone from ansible.module_utils.basic import AnsibleModule, missing_required_lib @@ -248,8 +248,8 @@ try: import psutil HAS_PSUTIL = True # just because we can import it on Linux doesn't mean we will use it -except ImportError: - PSUTIL_IMP_ERR = traceback.format_exc() +except ImportError as ex: + PSUTIL_IMP_ERR = ex class TCPConnectionInfo(object): @@ -616,7 +616,7 @@ def main(): _timedelta_total_seconds(end - datetime.now(timezone.utc)), ) try: - s = socket.create_connection((host, port), min(connect_timeout, alt_connect_timeout)) + s = socket.create_connection((host, int(port)), min(connect_timeout, alt_connect_timeout)) except Exception: # Failed to connect by connect_timeout. 
wait and try again pass diff --git a/lib/ansible/parsing/ajson.py b/lib/ansible/parsing/ajson.py index ff29240afc1..cfa5f7c217e 100644 --- a/lib/ansible/parsing/ajson.py +++ b/lib/ansible/parsing/ajson.py @@ -1,40 +1,22 @@ # Copyright: (c) 2018, Ansible Project # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import annotations +from __future__ import annotations as _annotations -import json +# from ansible.utils.display import Display as _Display -# Imported for backwards compat -from ansible.module_utils.common.json import AnsibleJSONEncoder # pylint: disable=unused-import -from ansible.parsing.vault import VaultLib -from ansible.parsing.yaml.objects import AnsibleVaultEncryptedUnicode -from ansible.utils.unsafe_proxy import wrap_var +# DTFIX-RELEASE: The pylint deprecated checker does not detect `Display().deprecated` calls, of which we have many. +# deprecated: description='deprecate ajson' core_version='2.23' +# _Display().deprecated( +# msg='The `ansible.parsing.ajson` module is deprecated.', +# version='2.27', +# help_text="", # DTFIX-RELEASE: complete this help text +# ) -class AnsibleJSONDecoder(json.JSONDecoder): - - _vaults = {} # type: dict[str, VaultLib] - - def __init__(self, *args, **kwargs): - kwargs['object_hook'] = self.object_hook - super(AnsibleJSONDecoder, self).__init__(*args, **kwargs) - - @classmethod - def set_secrets(cls, secrets): - cls._vaults['default'] = VaultLib(secrets=secrets) - - def object_hook(self, pairs): - for key in pairs: - value = pairs[key] - - if key == '__ansible_vault': - value = AnsibleVaultEncryptedUnicode(value) - if self._vaults: - value.vault = self._vaults['default'] - return value - elif key == '__ansible_unsafe': - return wrap_var(value) - - return pairs +# Imported for backward compat +from ansible.module_utils.common.json import ( # pylint: disable=unused-import + _AnsibleJSONEncoder as AnsibleJSONEncoder, + _AnsibleJSONDecoder as AnsibleJSONDecoder, +) diff --git a/lib/ansible/parsing/dataloader.py b/lib/ansible/parsing/dataloader.py index 47b6cfb12ca..4250f3d5163 100644 --- a/lib/ansible/parsing/dataloader.py +++ b/lib/ansible/parsing/dataloader.py @@ -2,23 +2,27 @@ # Copyright: (c) 2017, Ansible Project # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import annotations import copy import os import os.path +import pathlib import re import tempfile import typing as t from ansible import constants as C from ansible.errors import AnsibleFileNotFound, AnsibleParserError +from ansible._internal._errors import _utils from ansible.module_utils.basic import is_executable +from ansible._internal._datatag._tags import Origin, TrustedAsTemplate, SourceWasEncrypted +from ansible.module_utils._internal._datatag import AnsibleTagHelper from ansible.module_utils.six import binary_type, text_type from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text from ansible.parsing.quoting import unquote from ansible.parsing.utils.yaml import from_yaml -from ansible.parsing.vault import VaultLib, is_encrypted, is_encrypted_file, parse_vaulttext_envelope, PromptVaultSecret +from ansible.parsing.vault import VaultLib, is_encrypted, is_encrypted_file, PromptVaultSecret from ansible.utils.path import unfrackpath from ansible.utils.display import Display @@ -73,11 +78,18 @@ class DataLoader: def set_vault_secrets(self, vault_secrets: list[tuple[str, PromptVaultSecret]] | None) ->
None: self._vault.secrets = vault_secrets - def load(self, data: str, file_name: str = '', show_content: bool = True, json_only: bool = False) -> t.Any: + def load( + self, + data: str, + file_name: str | None = None, # DTFIX-RELEASE: consider deprecating this in favor of tagging Origin on data + show_content: bool = True, # DTFIX-RELEASE: consider future deprecation, but would need RedactAnnotatedSourceContext public + json_only: bool = False, + ) -> t.Any: """Backwards compat for now""" - return from_yaml(data, file_name, show_content, self._vault.secrets, json_only=json_only) + with _utils.RedactAnnotatedSourceContext.when(not show_content): + return from_yaml(data=data, file_name=file_name, json_only=json_only) - def load_from_file(self, file_name: str, cache: str = 'all', unsafe: bool = False, json_only: bool = False) -> t.Any: + def load_from_file(self, file_name: str, cache: str = 'all', unsafe: bool = False, json_only: bool = False, trusted_as_template: bool = False) -> t.Any: """ Loads data from a file, which can contain either JSON or YAML. @@ -98,16 +110,22 @@ class DataLoader: if cache != 'none' and file_name in self._FILE_CACHE: parsed_data = self._FILE_CACHE[file_name] else: - # Read the file contents and load the data structure from them - (b_file_data, show_content) = self._get_file_contents(file_name) + file_data = self.get_text_file_contents(file_name) + + if trusted_as_template: + file_data = TrustedAsTemplate().tag(file_data) + + parsed_data = self.load(data=file_data, file_name=file_name, json_only=json_only) - file_data = to_text(b_file_data, errors='surrogate_or_strict') - parsed_data = self.load(data=file_data, file_name=file_name, show_content=show_content, json_only=json_only) + # only tagging the container, used by include_vars to determine if vars should be shown or not + # this is a temporary measure until a proper data sensitivity system is in place + if SourceWasEncrypted.is_tagged_on(file_data): + parsed_data = SourceWasEncrypted().tag(parsed_data) # Cache the file contents for next time based on the cache option if cache == 'all': self._FILE_CACHE[file_name] = parsed_data - elif cache == 'vaulted' and not show_content: + elif cache == 'vaulted' and SourceWasEncrypted.is_tagged_on(file_data): self._FILE_CACHE[file_name] = parsed_data # Return the parsed data, optionally deep-copied for safety @@ -137,18 +155,44 @@ class DataLoader: path = self.path_dwim(path) return is_executable(path) - def _decrypt_if_vault_data(self, b_vault_data: bytes, b_file_name: bytes | None = None) -> tuple[bytes, bool]: + def _decrypt_if_vault_data(self, b_data: bytes) -> tuple[bytes, bool]: """Decrypt b_vault_data if encrypted and return b_data and the show_content flag""" - if not is_encrypted(b_vault_data): - show_content = True - return b_vault_data, show_content + if encrypted_source := is_encrypted(b_data): + b_data = self._vault.decrypt(b_data) - b_ciphertext, b_version, cipher_name, vault_id = parse_vaulttext_envelope(b_vault_data) - b_data = self._vault.decrypt(b_vault_data, filename=b_file_name) + return b_data, not encrypted_source - show_content = False - return b_data, show_content + def get_text_file_contents(self, file_name: str, encoding: str | None = None) -> str: + """ + Returns an `Origin` tagged string with the content of the specified (DWIM-expanded for relative) file path, decrypting if necessary. + Callers must only specify `encoding` when the user can configure it, as error messages in that case will imply configurability.
+ If `encoding` is not specified, UTF-8 will be used. + """ + bytes_content, source_was_plaintext = self._get_file_contents(file_name) + + if encoding is None: + encoding = 'utf-8' + help_text = 'This file must be UTF-8 encoded.' + else: + help_text = 'Ensure the correct encoding was specified.' + + try: + str_content = bytes_content.decode(encoding=encoding, errors='strict') + except UnicodeDecodeError: + str_content = bytes_content.decode(encoding=encoding, errors='surrogateescape') + + display.deprecated( + msg=f"File {file_name!r} could not be decoded as {encoding!r}. Invalid content has been escaped.", + version="2.23", + # obj intentionally omitted since there's no value in showing its contents + help_text=help_text, + ) + + if not source_was_plaintext: + str_content = SourceWasEncrypted().tag(str_content) + + return AnsibleTagHelper.tag_copy(bytes_content, str_content) def _get_file_contents(self, file_name: str) -> tuple[bytes, bool]: """ @@ -163,21 +207,22 @@ class DataLoader: :raises AnsibleParserError: if we were unable to read the file :return: Returns a byte string of the file contents """ - if not file_name or not isinstance(file_name, (binary_type, text_type)): - raise AnsibleParserError("Invalid filename: '%s'" % to_native(file_name)) + if not file_name or not isinstance(file_name, str): + raise TypeError(f"Invalid filename {file_name!r}.") - b_file_name = to_bytes(self.path_dwim(file_name)) - # This is what we really want but have to fix unittests to make it pass - # if not os.path.exists(b_file_name) or not os.path.isfile(b_file_name): - if not self.path_exists(b_file_name): - raise AnsibleFileNotFound("Unable to retrieve file contents", file_name=file_name) + file_name = self.path_dwim(file_name) try: - with open(b_file_name, 'rb') as f: - data = f.read() - return self._decrypt_if_vault_data(data, b_file_name) - except (IOError, OSError) as e: - raise AnsibleParserError("an error occurred while trying to read the file '%s': %s" % (file_name, to_native(e)), orig_exc=e) + data = pathlib.Path(file_name).read_bytes() + except FileNotFoundError as ex: + # DTFIX-FUTURE: why not just let the builtin one fly? + raise AnsibleFileNotFound("Unable to retrieve file contents.", file_name=file_name) from ex + except (IOError, OSError) as ex: + raise AnsibleParserError(f"An error occurred while trying to read the file {file_name!r}.") from ex + + data = Origin(path=file_name).tag(data) + + return self._decrypt_if_vault_data(data) def get_basedir(self) -> str: """ returns the current basedir """ @@ -194,8 +239,8 @@ class DataLoader: make relative paths work like folks expect. 
""" - given = unquote(given) given = to_text(given, errors='surrogate_or_strict') + given = unquote(given) if given.startswith(to_text(os.path.sep)) or given.startswith(u'~'): path = given @@ -392,19 +437,19 @@ class DataLoader: # if the file is encrypted and no password was specified, # the decrypt call would throw an error, but we check first # since the decrypt function doesn't know the file name - data = f.read() + data = Origin(path=real_path).tag(f.read()) if not self._vault.secrets: raise AnsibleParserError("A vault password or secret must be specified to decrypt %s" % to_native(file_path)) - data = self._vault.decrypt(data, filename=real_path) + data = self._vault.decrypt(data) # Make a temp file real_path = self._create_content_tempfile(data) self._tempfiles.add(real_path) return real_path - except (IOError, OSError) as e: - raise AnsibleParserError("an error occurred while trying to read the file '%s': %s" % (to_native(real_path), to_native(e)), orig_exc=e) + except (IOError, OSError) as ex: + raise AnsibleParserError(f"an error occurred while trying to read the file {to_text(real_path)!r}.") from ex def cleanup_tmp_file(self, file_path: str) -> None: """ diff --git a/lib/ansible/parsing/mod_args.py b/lib/ansible/parsing/mod_args.py index aed543d0953..c19d56e91df 100644 --- a/lib/ansible/parsing/mod_args.py +++ b/lib/ansible/parsing/mod_args.py @@ -19,12 +19,14 @@ from __future__ import annotations import ansible.constants as C from ansible.errors import AnsibleParserError, AnsibleError, AnsibleAssertionError +from ansible.module_utils._internal._datatag import AnsibleTagHelper from ansible.module_utils.six import string_types from ansible.module_utils.common.sentinel import Sentinel from ansible.module_utils.common.text.converters import to_text from ansible.parsing.splitter import parse_kv, split_args +from ansible.parsing.vault import EncryptedString from ansible.plugins.loader import module_loader, action_loader -from ansible.template import Templar +from ansible._internal._templating._engine import TemplateEngine from ansible.utils.fqcn import add_internal_fqcns @@ -129,9 +131,7 @@ class ModuleArgsParser: self._task_attrs.update(['local_action', 'static']) self._task_attrs = frozenset(self._task_attrs) - self.resolved_action = None - - def _split_module_string(self, module_string): + def _split_module_string(self, module_string: str) -> tuple[str, str]: """ when module names are expressed like: action: copy src=a dest=b @@ -141,9 +141,11 @@ class ModuleArgsParser: tokens = split_args(module_string) if len(tokens) > 1: - return (tokens[0].strip(), " ".join(tokens[1:])) + result = (tokens[0].strip(), " ".join(tokens[1:])) else: - return (tokens[0].strip(), "") + result = (tokens[0].strip(), "") + + return AnsibleTagHelper.tag_copy(module_string, result[0]), AnsibleTagHelper.tag_copy(module_string, result[1]) def _normalize_parameters(self, thing, action=None, additional_args=None): """ @@ -157,9 +159,9 @@ class ModuleArgsParser: # than those which may be parsed/normalized next final_args = dict() if additional_args: - if isinstance(additional_args, string_types): - templar = Templar(loader=None) - if templar.is_template(additional_args): + if isinstance(additional_args, (str, EncryptedString)): + # DTFIX-RELEASE: should this be is_possibly_template? 
+ if TemplateEngine().is_template(additional_args): final_args['_variable_params'] = additional_args else: raise AnsibleParserError("Complex args containing variables cannot use bare variables (without Jinja2 delimiters), " @@ -224,6 +226,8 @@ class ModuleArgsParser: # form is like: copy: src=a dest=b check_raw = action in FREEFORM_ACTIONS args = parse_kv(thing, check_raw=check_raw) + elif isinstance(thing, EncryptedString): + args = dict(_raw_params=thing) elif thing is None: # this can happen with modules which take no params, like ping: args = None @@ -276,8 +280,6 @@ class ModuleArgsParser: task, dealing with all sorts of levels of fuzziness. """ - thing = None - action = None delegate_to = self._task_ds.get('delegate_to', Sentinel) args = dict() @@ -292,7 +294,7 @@ class ModuleArgsParser: if 'action' in self._task_ds: # an old school 'action' statement thing = self._task_ds['action'] - action, args = self._normalize_parameters(thing, action=action, additional_args=additional_args) + action, args = self._normalize_parameters(thing, additional_args=additional_args) # local_action if 'local_action' in self._task_ds: @@ -301,12 +303,7 @@ class ModuleArgsParser: raise AnsibleParserError("action and local_action are mutually exclusive", obj=self._task_ds) thing = self._task_ds.get('local_action', '') delegate_to = 'localhost' - action, args = self._normalize_parameters(thing, action=action, additional_args=additional_args) - - if action is not None and not skip_action_validation: - context = _get_action_context(action, self._collection_list) - if context is not None and context.resolved: - self.resolved_action = context.resolved_fqcn + action, args = self._normalize_parameters(thing, additional_args=additional_args) # module: is the more new-style invocation @@ -315,14 +312,13 @@ class ModuleArgsParser: # walk the filtered input dictionary to see if we recognize a module name for item, value in non_task_ds.items(): - context = None - is_action_candidate = False if item in BUILTIN_TASKS: is_action_candidate = True elif skip_action_validation: is_action_candidate = True else: try: + # DTFIX-FUTURE: extract to a helper method, shared with Task.post_validate_args context = _get_action_context(item, self._collection_list) except AnsibleError as e: if e.obj is None: @@ -336,9 +332,6 @@ class ModuleArgsParser: if action is not None: raise AnsibleParserError("conflicting action statements: %s, %s" % (action, item), obj=self._task_ds) - if context is not None and context.resolved: - self.resolved_action = context.resolved_fqcn - action = item thing = value action, args = self._normalize_parameters(thing, action=action, additional_args=additional_args) @@ -353,14 +346,5 @@ class ModuleArgsParser: else: raise AnsibleParserError("no module/action detected in task.", obj=self._task_ds) - elif args.get('_raw_params', '') != '' and action not in RAW_PARAM_MODULES: - templar = Templar(loader=None) - raw_params = args.pop('_raw_params') - if templar.is_template(raw_params): - args['_variable_params'] = raw_params - else: - raise AnsibleParserError( - "this task '%s' has extra params, which is only allowed in the following modules: %s" % (action, ", ".join(RAW_PARAM_MODULES_SIMPLE)), - obj=self._task_ds) - return (action, args, delegate_to) + return action, args, delegate_to diff --git a/lib/ansible/parsing/plugin_docs.py b/lib/ansible/parsing/plugin_docs.py index c18230806b7..f986ec67f46 100644 --- a/lib/ansible/parsing/plugin_docs.py +++ b/lib/ansible/parsing/plugin_docs.py @@ -4,13 +4,15 @@ from 
__future__ import annotations import ast -import tokenize + +import yaml from ansible import constants as C from ansible.errors import AnsibleError, AnsibleParserError from ansible.module_utils.common.text.converters import to_text, to_native from ansible.parsing.yaml.loader import AnsibleLoader from ansible.utils.display import Display +from ansible._internal._datatag import _tags display = Display() @@ -23,13 +25,6 @@ string_to_vars = { } -def _var2string(value): - """ reverse lookup of the dict above """ - for k, v in string_to_vars.items(): - if v == value: - return k - - def _init_doc_dict(): """ initialize a return dict for docs with the expected structure """ return {k: None for k in string_to_vars.values()} @@ -43,13 +38,14 @@ def read_docstring_from_yaml_file(filename, verbose=True, ignore_errors=True): try: with open(filename, 'rb') as yamlfile: - file_data = AnsibleLoader(yamlfile.read(), file_name=filename).get_single_data() - except Exception as e: - msg = "Unable to parse yaml file '%s': %s" % (filename, to_native(e)) + file_data = yaml.load(yamlfile, Loader=AnsibleLoader) + except Exception as ex: + msg = f"Unable to parse yaml file {filename}" + # DTFIX-RELEASE: find a better pattern for this (can we use the new optional error behavior?) if not ignore_errors: - raise AnsibleParserError(msg, orig_exc=e) + raise AnsibleParserError(f'{msg}.') from ex elif verbose: - display.error(msg) + display.error(f'{msg}: {ex}') if file_data: for key in string_to_vars: @@ -58,74 +54,11 @@ def read_docstring_from_yaml_file(filename, verbose=True, ignore_errors=True): return data -def read_docstring_from_python_module(filename, verbose=True, ignore_errors=True): - """ - Use tokenization to search for assignment of the documentation variables in the given file. - Parse from YAML and return the resulting python structure or None together with examples as plain text. 
- """ - - seen = set() - data = _init_doc_dict() - - next_string = None - with tokenize.open(filename) as f: - tokens = tokenize.generate_tokens(f.readline) - for token in tokens: - - # found label that looks like variable - if token.type == tokenize.NAME: - - # label is expected value, in correct place and has not been seen before - if token.start == 1 and token.string in string_to_vars and token.string not in seen: - # next token that is string has the docs - next_string = string_to_vars[token.string] - continue - - # previous token indicated this string is a doc string - if next_string is not None and token.type == tokenize.STRING: - - # ensure we only process one case of it - seen.add(token.string) - - value = token.string - - # strip string modifiers/delimiters - if value.startswith(('r', 'b')): - value = value.lstrip('rb') - - if value.startswith(("'", '"')): - value = value.strip("'\"") - - # actually use the data - if next_string == 'plainexamples': - # keep as string, can be yaml, but we let caller deal with it - data[next_string] = to_text(value) - else: - # yaml load the data - try: - data[next_string] = AnsibleLoader(value, file_name=filename).get_single_data() - except Exception as e: - msg = "Unable to parse docs '%s' in python file '%s': %s" % (_var2string(next_string), filename, to_native(e)) - if not ignore_errors: - raise AnsibleParserError(msg, orig_exc=e) - elif verbose: - display.error(msg) - - next_string = None - - # if nothing else worked, fall back to old method - if not seen: - data = read_docstring_from_python_file(filename, verbose, ignore_errors) - - return data - - def read_docstring_from_python_file(filename, verbose=True, ignore_errors=True): """ Use ast to search for assignment of the DOCUMENTATION and EXAMPLES variables in the given file. Parse DOCUMENTATION from YAML and return the YAML doc or None together with EXAMPLES, as plain text. 
""" - data = _init_doc_dict() try: @@ -153,16 +86,18 @@ def read_docstring_from_python_file(filename, verbose=True, ignore_errors=True): data[varkey] = to_text(child.value.value) else: # string should be yaml if already not a dict - data[varkey] = AnsibleLoader(child.value.value, file_name=filename).get_single_data() + child_value = _tags.Origin(path=filename, line_num=child.value.lineno).tag(child.value.value) + data[varkey] = yaml.load(child_value, Loader=AnsibleLoader) display.debug('Documentation assigned: %s' % varkey) - except Exception as e: - msg = "Unable to parse documentation in python file '%s': %s" % (filename, to_native(e)) + except Exception as ex: + msg = f"Unable to parse documentation in python file {filename!r}" + # DTFIX-RELEASE: better pattern to conditionally raise/display if not ignore_errors: - raise AnsibleParserError(msg, orig_exc=e) + raise AnsibleParserError(f'{msg}.') from ex elif verbose: - display.error(msg) + display.error(f'{msg}: {ex}.') return data @@ -174,7 +109,7 @@ def read_docstring(filename, verbose=True, ignore_errors=True): if filename.endswith(C.YAML_DOC_EXTENSIONS): docstring = read_docstring_from_yaml_file(filename, verbose=verbose, ignore_errors=ignore_errors) elif filename.endswith(C.PYTHON_DOC_EXTENSIONS): - docstring = read_docstring_from_python_module(filename, verbose=verbose, ignore_errors=ignore_errors) + docstring = read_docstring_from_python_file(filename, verbose=verbose, ignore_errors=ignore_errors) elif not ignore_errors: raise AnsibleError("Unknown documentation format: %s" % to_native(filename)) @@ -221,6 +156,6 @@ def read_docstub(filename): in_documentation = True short_description = r''.join(doc_stub).strip().rstrip('.') - data = AnsibleLoader(short_description, file_name=filename).get_single_data() + data = yaml.load(_tags.Origin(path=str(filename)).tag(short_description), Loader=AnsibleLoader) return data diff --git a/lib/ansible/parsing/splitter.py b/lib/ansible/parsing/splitter.py index 3f61347a4ac..18ef976496e 100644 --- a/lib/ansible/parsing/splitter.py +++ b/lib/ansible/parsing/splitter.py @@ -22,6 +22,8 @@ import re from ansible.errors import AnsibleParserError from ansible.module_utils.common.text.converters import to_text +from ansible.module_utils._internal._datatag import AnsibleTagHelper +from ansible._internal._datatag._tags import Origin, TrustedAsTemplate from ansible.parsing.quoting import unquote @@ -52,6 +54,13 @@ def parse_kv(args, check_raw=False): they will simply be ignored. 
""" + tags = [] + if origin_tag := Origin.get_tag(args): + # NB: adjusting the column number is left as an exercise for the reader + tags.append(origin_tag) + if trusted_tag := TrustedAsTemplate.get_tag(args): + tags.append(trusted_tag) + args = to_text(args, nonstring='passthru') options = {} @@ -90,6 +99,12 @@ def parse_kv(args, check_raw=False): if len(raw_params) > 0: options[u'_raw_params'] = join_args(raw_params) + if tags: + options = {AnsibleTagHelper.tag(k, tags): AnsibleTagHelper.tag(v, tags) for k, v in options.items()} + + if origin_tag: + options = origin_tag.tag(options) + return options diff --git a/lib/ansible/parsing/utils/jsonify.py b/lib/ansible/parsing/utils/jsonify.py deleted file mode 100644 index 0ebd7564094..00000000000 --- a/lib/ansible/parsing/utils/jsonify.py +++ /dev/null @@ -1,36 +0,0 @@ -# (c) 2012-2014, Michael DeHaan -# -# This file is part of Ansible -# -# Ansible is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# Ansible is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with Ansible. If not, see . - -from __future__ import annotations - -import json - - -def jsonify(result, format=False): - """ format JSON output (uncompressed or uncompressed) """ - - if result is None: - return "{}" - - indent = None - if format: - indent = 4 - - try: - return json.dumps(result, sort_keys=True, indent=indent, ensure_ascii=False) - except UnicodeDecodeError: - return json.dumps(result, sort_keys=True, indent=indent) diff --git a/lib/ansible/parsing/utils/yaml.py b/lib/ansible/parsing/utils/yaml.py index 9462eba8aa9..f1cd142dc0e 100644 --- a/lib/ansible/parsing/utils/yaml.py +++ b/lib/ansible/parsing/utils/yaml.py @@ -6,77 +6,48 @@ from __future__ import annotations import json +import typing as t -from yaml import YAMLError +import yaml -from ansible.errors import AnsibleParserError -from ansible.errors.yaml_strings import YAML_SYNTAX_ERROR -from ansible.module_utils.common.text.converters import to_native +from ansible.errors import AnsibleJSONParserError +from ansible._internal._errors import _utils +from ansible.parsing.vault import VaultSecret from ansible.parsing.yaml.loader import AnsibleLoader -from ansible.parsing.yaml.objects import AnsibleBaseYAMLObject -from ansible.parsing.ajson import AnsibleJSONDecoder +from ansible._internal._yaml._errors import AnsibleYAMLParserError +from ansible._internal._datatag._tags import Origin +from ansible._internal._json._profiles import _legacy -__all__ = ('from_yaml',) +def from_yaml( + data: str, + file_name: str | None = None, + show_content: bool = True, + vault_secrets: list[tuple[str, VaultSecret]] | None = None, # deprecated: description='Deprecate vault_secrets, it has no effect.' 
core_version='2.23' + json_only: bool = False, +) -> t.Any: + """Creates a Python data structure from the given data, which can be either a JSON or YAML string.""" + # FUTURE: provide Ansible-specific top-level APIs to expose JSON and YAML serialization/deserialization to hide the error handling logic + # once those are in place, defer deprecate this entire function + origin = Origin.get_or_create_tag(data, file_name) -def _handle_error(json_exc, yaml_exc, file_name, show_content): - """ - Optionally constructs an object (AnsibleBaseYAMLObject) to encapsulate the - file name/position where a YAML exception occurred, and raises an AnsibleParserError - to display the syntax exception information. - """ + data = origin.tag(data) - # if the YAML exception contains a problem mark, use it to construct - # an object the error class can use to display the faulty line - err_obj = None - if hasattr(yaml_exc, 'problem_mark'): - err_obj = AnsibleBaseYAMLObject() - err_obj.ansible_pos = (file_name, yaml_exc.problem_mark.line + 1, yaml_exc.problem_mark.column + 1) - - n_yaml_syntax_error = YAML_SYNTAX_ERROR % to_native(getattr(yaml_exc, 'problem', u'')) - n_err_msg = 'We were unable to read either as JSON nor YAML, these are the errors we got from each:\n' \ - 'JSON: %s\n\n%s' % (to_native(json_exc), n_yaml_syntax_error) - - raise AnsibleParserError(n_err_msg, obj=err_obj, show_content=show_content, orig_exc=yaml_exc) - - -def _safe_load(stream, file_name=None, vault_secrets=None): - """ Implements yaml.safe_load(), except using our custom loader class. """ - - loader = AnsibleLoader(stream, file_name, vault_secrets) - try: - return loader.get_single_data() - finally: + with _utils.RedactAnnotatedSourceContext.when(not show_content): try: - loader.dispose() - except AttributeError: - pass # older versions of yaml don't have dispose function, ignore - - -def from_yaml(data, file_name='', show_content=True, vault_secrets=None, json_only=False): - """ - Creates a python datastructure from the given data, which can be either - a JSON or YAML string. - """ - new_data = None - - try: - # in case we have to deal with vaults - AnsibleJSONDecoder.set_secrets(vault_secrets) - - # we first try to load this data as JSON. - # Fixes issues with extra vars json strings not being parsed correctly by the yaml parser - new_data = json.loads(data, cls=AnsibleJSONDecoder) - except Exception as json_exc: + # we first try to load this data as JSON. + # Fixes issues with extra vars json strings not being parsed correctly by the yaml parser + return json.loads(data, cls=_legacy.Decoder) + except Exception as ex: + json_ex = ex if json_only: - raise AnsibleParserError(to_native(json_exc), orig_exc=json_exc) + AnsibleJSONParserError.handle_exception(json_ex, origin=origin) - # must not be JSON, let the rest try try: - new_data = _safe_load(data, file_name=file_name, vault_secrets=vault_secrets) - except YAMLError as yaml_exc: - _handle_error(json_exc, yaml_exc, file_name, show_content) - - return new_data + return yaml.load(data, Loader=AnsibleLoader) # type: ignore[arg-type] + except Exception as yaml_ex: + # DTFIX-RELEASE: how can we indicate in Origin that the data is in-memory only, to support context information -- is that useful? 
+ # we'd need to pass data to handle_exception so it could be used as the content instead of reading from disk + AnsibleYAMLParserError.handle_exception(yaml_ex, origin=origin) diff --git a/lib/ansible/parsing/vault/__init__.py b/lib/ansible/parsing/vault/__init__.py index e3121b5dbb9..0cf19dcb4d5 100644 --- a/lib/ansible/parsing/vault/__init__.py +++ b/lib/ansible/parsing/vault/__init__.py @@ -19,6 +19,7 @@ from __future__ import annotations import errno import fcntl +import functools import os import random import shlex @@ -27,11 +28,18 @@ import subprocess import sys import tempfile import warnings +import typing as t from binascii import hexlify from binascii import unhexlify from binascii import Error as BinasciiError +from ansible.module_utils._internal._datatag import ( + AnsibleTagHelper, AnsibleTaggedObject, _AnsibleTagsMapping, _EmptyROInternalTagsMapping, _EMPTY_INTERNAL_TAGS_MAPPING, +) +from ansible._internal._templating import _jinja_common +from ansible._internal._datatag._tags import Origin, VaultedValue, TrustedAsTemplate + HAS_CRYPTOGRAPHY = False CRYPTOGRAPHY_BACKEND = None try: @@ -141,11 +149,13 @@ def _parse_vaulttext_envelope(b_vaulttext_envelope, default_vault_id=None): vault_id = to_text(b_tmpheader[3].strip()) b_ciphertext = b''.join(b_tmpdata[1:]) + # DTFIX-RELEASE: possible candidate for propagate_origin + b_ciphertext = AnsibleTagHelper.tag_copy(b_vaulttext_envelope, b_ciphertext) return b_ciphertext, b_version, cipher_name, vault_id -def parse_vaulttext_envelope(b_vaulttext_envelope, default_vault_id=None, filename=None): +def parse_vaulttext_envelope(b_vaulttext_envelope, default_vault_id=None): """Parse the vaulttext envelope When data is saved, it has a header prepended and is formatted into 80 @@ -153,11 +163,8 @@ def parse_vaulttext_envelope(b_vaulttext_envelope, default_vault_id=None, filena and then removes the header and the inserted newlines. The string returned is suitable for processing by the Cipher classes. - :arg b_vaulttext: byte str containing the data from a save file - :kwarg default_vault_id: The vault_id name to use if the vaulttext does not provide one. - :kwarg filename: The filename that the data came from. This is only - used to make better error messages in case the data cannot be - decrypted. This is optional. + :arg b_vaulttext_envelope: byte str containing the data from a save file + :arg default_vault_id: The vault_id name to use if the vaulttext does not provide one. :returns: A tuple of byte str of the vaulttext suitable to pass to parse_vaultext, a byte str of the vault format version, the name of the cipher used, and the vault_id. 
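The `# DTFIX-RELEASE: possible candidate for propagate_origin` comments above all annotate the same idiom: when a new value is derived from a tagged one, `AnsibleTagHelper.tag_copy()` carries the tags (notably `Origin`) over to the derived value so later errors can still point at the original source. A hedged sketch of the idiom, using the internal `_datatag` helpers only as the hunks above use them (internal APIs, subject to change; the path is illustrative):

from ansible.module_utils._internal._datatag import AnsibleTagHelper
from ansible._internal._datatag._tags import Origin

# tag raw bytes with their source, as DataLoader._get_file_contents() now does
source = Origin(path='/etc/ansible/vault.yml').tag(b'vault envelope bytes')

# derive a new value and copy the tags across, as _parse_vaulttext_envelope() does
derived = AnsibleTagHelper.tag_copy(source, b'ciphertext without the envelope')

assert Origin.get_tag(derived).path == '/etc/ansible/vault.yml'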
@@ -168,12 +175,8 @@ def parse_vaulttext_envelope(b_vaulttext_envelope, default_vault_id=None, filena try: return _parse_vaulttext_envelope(b_vaulttext_envelope, default_vault_id) - except Exception as exc: - msg = "Vault envelope format error" - if filename: - msg += ' in %s' % (filename) - msg += ': %s' % exc - raise AnsibleVaultFormatError(msg) + except Exception as ex: + raise AnsibleVaultFormatError("Vault envelope format error.", obj=b_vaulttext_envelope) from ex def format_vaulttext_envelope(b_ciphertext, cipher_name, version=None, vault_id=None): @@ -219,9 +222,10 @@ def format_vaulttext_envelope(b_ciphertext, cipher_name, version=None, vault_id= def _unhexlify(b_data): try: - return unhexlify(b_data) - except (BinasciiError, TypeError) as exc: - raise AnsibleVaultFormatError('Vault format unhexlify error: %s' % exc) + # DTFIX-RELEASE: possible candidate for propagate_origin + return AnsibleTagHelper.tag_copy(b_data, unhexlify(b_data)) + except (BinasciiError, TypeError) as ex: + raise AnsibleVaultFormatError('Vault format unhexlify error.', obj=b_data) from ex def _parse_vaulttext(b_vaulttext): @@ -247,9 +251,8 @@ def parse_vaulttext(b_vaulttext): return _parse_vaulttext(b_vaulttext) except AnsibleVaultFormatError: raise - except Exception as exc: - msg = "Vault vaulttext format error: %s" % exc - raise AnsibleVaultFormatError(msg) + except Exception as ex: + raise AnsibleVaultFormatError("Vault vaulttext format error.", obj=b_vaulttext) from ex def verify_secret_is_not_empty(secret, msg=None): @@ -414,7 +417,7 @@ class FileVaultSecret(VaultSecret): except (OSError, IOError) as e: raise AnsibleError("Could not read vault password file %s: %s" % (filename, e)) - b_vault_data, dummy = self.loader._decrypt_if_vault_data(vault_pass, filename) + b_vault_data, dummy = self.loader._decrypt_if_vault_data(vault_pass) vault_pass = b_vault_data.strip(b'\r\n') @@ -633,58 +636,44 @@ class VaultLib: vault_id=vault_id) return b_vaulttext - def decrypt(self, vaulttext, filename=None, obj=None): + def decrypt(self, vaulttext): """Decrypt a piece of vault encrypted data. :arg vaulttext: a string to decrypt. Since vault encrypted data is an ascii text format this can be either a byte str or unicode string. - :kwarg filename: a filename that the data came from. This is only - used to make better error messages in case the data cannot be - decrypted. - :returns: a byte string containing the decrypted data and the vault-id that was used - + :returns: a byte string containing the decrypted data """ - plaintext, vault_id, vault_secret = self.decrypt_and_get_vault_id(vaulttext, filename=filename, obj=obj) + plaintext, vault_id, vault_secret = self.decrypt_and_get_vault_id(vaulttext) return plaintext - def decrypt_and_get_vault_id(self, vaulttext, filename=None, obj=None): + def decrypt_and_get_vault_id(self, vaulttext): """Decrypt a piece of vault encrypted data. :arg vaulttext: a string to decrypt. Since vault encrypted data is an ascii text format this can be either a byte str or unicode string. - :kwarg filename: a filename that the data came from. This is only - used to make better error messages in case the data cannot be - decrypted. 
:returns: a byte string containing the decrypted data and the vault-id vault-secret that was used - """ - b_vaulttext = to_bytes(vaulttext, errors='strict', encoding='utf-8') + origin = Origin.get_tag(vaulttext) + + b_vaulttext = to_bytes(vaulttext, nonstring='error') # enforce vaulttext is str/bytes, keep type check if removing type conversion if self.secrets is None: - msg = "A vault password must be specified to decrypt data" - if filename: - msg += " in file %s" % to_native(filename) - raise AnsibleVaultError(msg) + raise AnsibleVaultError("A vault password must be specified to decrypt data.", obj=vaulttext) if not is_encrypted(b_vaulttext): - msg = "input is not vault encrypted data. " - if filename: - msg += "%s is not a vault encrypted file" % to_native(filename) - raise AnsibleError(msg) + raise AnsibleVaultError("Input is not vault encrypted data.", obj=vaulttext) - b_vaulttext, dummy, cipher_name, vault_id = parse_vaulttext_envelope(b_vaulttext, filename=filename) + b_vaulttext, dummy, cipher_name, vault_id = parse_vaulttext_envelope(b_vaulttext) # create the cipher object, note that the cipher used for decrypt can # be different than the cipher used for encrypt if cipher_name in CIPHER_ALLOWLIST: this_cipher = CIPHER_MAPPING[cipher_name]() else: - raise AnsibleError("{0} cipher could not be found".format(cipher_name)) - - b_plaintext = None + raise AnsibleVaultError(f"Cipher {cipher_name!r} could not be found.", obj=vaulttext) if not self.secrets: - raise AnsibleVaultError('Attempting to decrypt but no vault secrets found') + raise AnsibleVaultError('Attempting to decrypt but no vault secrets found.', obj=vaulttext) # WARNING: Currently, the vault id is not required to match the vault id in the vault blob to # decrypt a vault properly. The vault id in the vault blob is not part of the encrypted @@ -697,15 +686,13 @@ class VaultLib: # we check it first. 
vault_id_matchers = [] - vault_id_used = None - vault_secret_used = None if vault_id: display.vvvvv(u'Found a vault_id (%s) in the vaulttext' % to_text(vault_id)) vault_id_matchers.append(vault_id) _matches = match_secrets(self.secrets, vault_id_matchers) if _matches: - display.vvvvv(u'We have a secret associated with vault id (%s), will try to use to decrypt %s' % (to_text(vault_id), to_text(filename))) + display.vvvvv(u'We have a secret associated with vault id (%s), will try to use to decrypt %s' % (to_text(vault_id), to_text(origin))) else: display.vvvvv(u'Found a vault_id (%s) in the vault text, but we do not have a associated secret (--vault-id)' % to_text(vault_id)) @@ -719,45 +706,32 @@ class VaultLib: # for vault_secret_id in vault_secret_ids: for vault_secret_id, vault_secret in matched_secrets: - display.vvvvv(u'Trying to use vault secret=(%s) id=%s to decrypt %s' % (to_text(vault_secret), to_text(vault_secret_id), to_text(filename))) + display.vvvvv(u'Trying to use vault secret=(%s) id=%s to decrypt %s' % (to_text(vault_secret), to_text(vault_secret_id), to_text(origin))) try: # secret = self.secrets[vault_secret_id] display.vvvv(u'Trying secret %s for vault_id=%s' % (to_text(vault_secret), to_text(vault_secret_id))) b_plaintext = this_cipher.decrypt(b_vaulttext, vault_secret) + # DTFIX-RELEASE: possible candidate for propagate_origin + b_plaintext = AnsibleTagHelper.tag_copy(vaulttext, b_plaintext) if b_plaintext is not None: vault_id_used = vault_secret_id vault_secret_used = vault_secret file_slug = '' - if filename: - file_slug = ' of "%s"' % filename + if origin: + file_slug = ' of "%s"' % origin display.vvvvv( u'Decrypt%s successful with secret=%s and vault_id=%s' % (to_text(file_slug), to_text(vault_secret), to_text(vault_secret_id)) ) break - except AnsibleVaultFormatError as exc: - exc.obj = obj - msg = u"There was a vault format error" - if filename: - msg += u' in %s' % (to_text(filename)) - msg += u': %s' % to_text(exc) - display.warning(msg, formatted=True) + except AnsibleVaultFormatError: raise except AnsibleError as e: display.vvvv(u'Tried to use the vault secret (%s) to decrypt (%s) but it failed. Error: %s' % - (to_text(vault_secret_id), to_text(filename), e)) + (to_text(vault_secret_id), to_text(origin), e)) continue else: - msg = "Decryption failed (no vault secrets were found that could decrypt)" - if filename: - msg += " on %s" % to_native(filename) - raise AnsibleVaultError(msg) - - if b_plaintext is None: - msg = "Decryption failed" - if filename: - msg += " on %s" % to_native(filename) - raise AnsibleError(msg) + raise AnsibleVaultError("Decryption failed (no vault secrets were found that could decrypt).", obj=vaulttext) return b_plaintext, vault_id_used, vault_secret_used @@ -916,7 +890,7 @@ class VaultEditor: ciphertext = self.read_data(filename) try: - plaintext = self.vault.decrypt(ciphertext, filename=filename) + plaintext = self.vault.decrypt(ciphertext) except AnsibleError as e: raise AnsibleError("%s for %s" % (to_native(e), to_native(filename))) self.write_data(plaintext, output_file or filename, shred=False) @@ -956,7 +930,7 @@ class VaultEditor: # Figure out the vault id from the file, to select the right secret to re-encrypt it # (duplicates parts of decrypt, but alas...) 
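The removed `filename` plumbing is superseded by `Origin` tags carried on the data itself; a short sketch using the tagging calls that appear elsewhere in this diff (the path and payload are placeholders):

from ansible._internal._datatag._tags import Origin

# Illustrative: data read from a vault file is tagged with its source, and
# Origin.get_tag() recovers that source later for messages like those above.
data = Origin(path='/tmp/secret.yml').tag(b'$ANSIBLE_VAULT;1.1;AES256\n...')
origin = Origin.get_tag(data)  # Origin(path='/tmp/secret.yml'), or None when untagged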
- dummy, dummy, cipher_name, vault_id = parse_vaulttext_envelope(b_vaulttext, filename=filename) + dummy, dummy, cipher_name, vault_id = parse_vaulttext_envelope(b_vaulttext) # vault id here may not be the vault id actually used for decrypting # as when the edited file has no vault-id but is decrypted by non-default id in secrets @@ -974,7 +948,7 @@ class VaultEditor: vaulttext = to_text(b_vaulttext) try: - plaintext = self.vault.decrypt(vaulttext, filename=filename) + plaintext = self.vault.decrypt(vaulttext) return plaintext except AnsibleError as e: raise AnsibleVaultError("%s for %s" % (to_native(e), to_native(filename))) @@ -1024,10 +998,12 @@ class VaultEditor: try: if filename == '-': - data = sys.stdin.buffer.read() + data = Origin(description='<stdin>').tag(sys.stdin.buffer.read()) else: + filename = os.path.abspath(filename) + with open(filename, "rb") as fh: - data = fh.read() + data = Origin(path=filename).tag(fh.read()) except Exception as e: msg = to_native(e) if not msg: @@ -1170,6 +1146,7 @@ class VaultAES256: return b_derivedkey @classmethod + @functools.cache # Concurrent first-use by multiple threads will all execute the method body. def _gen_key_initctr(cls, b_password, b_salt): # 16 for AES 128, 32 for AES256 key_length = 32 @@ -1302,3 +1279,258 @@ class VaultAES256: CIPHER_MAPPING = { u'AES256': VaultAES256, } + + +class VaultSecretsContext: + """Provides context-style access to vault secrets.""" + _current: t.ClassVar[t.Self | None] = None + + def __init__(self, secrets: list[tuple[str, VaultSecret]]) -> None: + self.secrets = secrets + + @classmethod + def initialize(cls, value: t.Self) -> None: + """ + Initialize VaultSecretsContext with the specified instance and secrets (since it's not a lazy or per-thread context). + This method will fail if called more than once. + """ + if cls._current: + raise RuntimeError(f"The {cls.__name__} context is already initialized.") + + cls._current = value + + @classmethod + def current(cls, optional: bool = False) -> t.Self: + """Access vault secrets, if initialized, a la `AmbientContextBase.current()`.""" + if not cls._current and not optional: + raise ReferenceError(f"A required {cls.__name__} context is not active.") + + return cls._current + + +@t.final +class EncryptedString(AnsibleTaggedObject): + """ + An encrypted string which supports tagging and on-demand decryption. + All methods provided by Python's built-in `str` are supported, all of which operate on the decrypted value. + Any attempt to use this value when it cannot be decrypted will raise an exception. + Despite supporting `str` methods, access to an instance of this type through templating is recommended over direct access.
+ """ + + __slots__ = ('_ciphertext', '_plaintext', '_ansible_tags_mapping') + + _subclasses_native_type: t.ClassVar[bool] = False + _empty_tags_as_native: t.ClassVar[bool] = False + + _ciphertext: str + _plaintext: str | None + _ansible_tags_mapping: _AnsibleTagsMapping | _EmptyROInternalTagsMapping + + def __init__(self, *, ciphertext: str) -> None: + if type(ciphertext) is not str: # pylint: disable=unidiomatic-typecheck + raise TypeError(f'ciphertext must be {str} instead of {type(ciphertext)}') + + object.__setattr__(self, '_ciphertext', ciphertext) + object.__setattr__(self, '_plaintext', None) + object.__setattr__(self, '_ansible_tags_mapping', _EMPTY_INTERNAL_TAGS_MAPPING) + + @classmethod + def _instance_factory(cls, value: t.Any, tags_mapping: _AnsibleTagsMapping) -> EncryptedString: + instance = EncryptedString.__new__(EncryptedString) + + # In 2.18 and earlier, vaulted values were not trusted. + # This maintains backwards compatibility with that. + # Additionally, supporting templating on vaulted values could be problematic for a few cases: + # 1) There's no way to compose YAML tags, so you can't use `!unsafe` and `!vault` together. + # 2) It would make composing `EncryptedString` with a possible future `TemplateString` more difficult. + tags_mapping.pop(TrustedAsTemplate, None) + + object.__setattr__(instance, '_ciphertext', value._ciphertext) + object.__setattr__(instance, '_plaintext', value._plaintext) + object.__setattr__(instance, '_ansible_tags_mapping', tags_mapping) + + return instance + + def __setstate__(self, state: tuple[None, dict[str, t.Any]]) -> None: + for key, value in state[1].items(): + object.__setattr__(self, key, value) + + def __delattr__(self, item: str) -> t.NoReturn: + raise AttributeError(f'{self.__class__.__name__!r} object is read-only') + + def __setattr__(self, key: str, value: object) -> t.NoReturn: + raise AttributeError(f'{self.__class__.__name__!r} object is read-only') + + @classmethod + def _init_class(cls) -> None: + """ + Add proxies for the specified `str` methods. + These proxies operate on the plaintext, which is decrypted on-demand. + """ + cls._native_type = cls + + operator_method_names = ( + '__eq__', + '__ge__', + '__gt__', + '__le__', + '__lt__', + '__ne__', + ) + + method_names = ( + '__add__', + '__contains__', + '__format__', + '__getitem__', + '__hash__', + '__iter__', + '__len__', + '__mod__', + '__mul__', + '__rmod__', + '__rmul__', + 'capitalize', + 'casefold', + 'center', + 'count', + 'encode', + 'endswith', + 'expandtabs', + 'find', + 'format', + 'format_map', + 'index', + 'isalnum', + 'isalpha', + 'isascii', + 'isdecimal', + 'isdigit', + 'isidentifier', + 'islower', + 'isnumeric', + 'isprintable', + 'isspace', + 'istitle', + 'isupper', + 'join', + 'ljust', + 'lower', + 'lstrip', + 'maketrans', # static, but implemented for simplicty/consistency + 'partition', + 'removeprefix', + 'removesuffix', + 'replace', + 'rfind', + 'rindex', + 'rjust', + 'rpartition', + 'rsplit', + 'rstrip', + 'split', + 'splitlines', + 'startswith', + 'strip', + 'swapcase', + 'title', + 'translate', + 'upper', + 'zfill', + ) + + for method_name in operator_method_names: + setattr(cls, method_name, functools.partialmethod(cls._proxy_str_operator_method, getattr(str, method_name))) + + for method_name in method_names: + setattr(cls, method_name, functools.partialmethod(cls._proxy_str_method, getattr(str, method_name))) + + def _decrypt(self) -> str: + """ + Attempt to decrypt the ciphertext and return the plaintext, which will be cached. 
+ If decryption fails an exception will be raised and no result will be cached. + """ + if self._plaintext is None: + vault = VaultLib(secrets=VaultSecretsContext.current().secrets) + # use the utility method to ensure that origin tags are available + plaintext = to_text(vault.decrypt(VaultHelper.get_ciphertext(self, with_tags=True))) # raises if the ciphertext cannot be decrypted + + # propagate source value tags plus VaultedValue for round-tripping ciphertext + plaintext = AnsibleTagHelper.tag(plaintext, AnsibleTagHelper.tags(self) | {VaultedValue(ciphertext=self._ciphertext)}) + + object.__setattr__(self, '_plaintext', plaintext) + + return self._plaintext + + def _as_dict(self) -> t.Dict[str, t.Any]: + return dict( + value=self._ciphertext, + tags=list(self._ansible_tags_mapping.values()), + ) + + def _native_copy(self) -> str: + return AnsibleTagHelper.untag(self._decrypt()) + + def _proxy_str_operator_method(self, method: t.Callable, other) -> t.Any: + obj = self._decrypt() + + if type(other) is EncryptedString: # pylint: disable=unidiomatic-typecheck + other = other._decrypt() + + return method(obj, other) + + def _proxy_str_method(self, method: t.Callable, *args, **kwargs) -> t.Any: + obj = self._decrypt() + return method(obj, *args, **kwargs) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(ciphertext={self._ciphertext!r})' + + def __str__(self) -> str: + return self._decrypt() + + def __float__(self) -> float: + return float(self._decrypt()) + + def __int__(self) -> int: + return int(self._decrypt()) + + def __radd__(self, other: t.Any) -> str: + return other + self._decrypt() + + def __fspath__(self) -> str: + return self._decrypt() + + +class VaultHelper: + """Vault specific utility methods.""" + + @staticmethod + def get_ciphertext(value: t.Any, *, with_tags: bool) -> str | None: + """ + If the given value is an `EncryptedString`, `VaultExceptionMarker` or tagged with `VaultedValue`, return the ciphertext, otherwise return `None`. + Tags on the value other than `VaultedValue` will be included on the ciphertext if `with_tags` is `True`, otherwise it will be tagless. + """ + value_type = type(value) + ciphertext: str | None + tags = AnsibleTagHelper.tags(value) + + if value_type is _jinja_common.VaultExceptionMarker: + ciphertext = value._marker_undecryptable_ciphertext + tags = AnsibleTagHelper.tags(ciphertext) # ciphertext has tags but value does not + elif value_type is EncryptedString: + ciphertext = value._ciphertext + elif value_type in _jinja_common.Marker.concrete_subclasses: # avoid wasteful raise/except of Marker when calling get_tag below + ciphertext = None + elif vaulted_value := VaultedValue.get_tag(value): + ciphertext = vaulted_value.ciphertext + else: + ciphertext = None + + if ciphertext: + if with_tags: + ciphertext = VaultedValue.untag(AnsibleTagHelper.tag(ciphertext, tags)) + else: + ciphertext = AnsibleTagHelper.untag(ciphertext) + + return ciphertext diff --git a/lib/ansible/parsing/yaml/__init__.py b/lib/ansible/parsing/yaml/__init__.py index 64fee52484f..e69de29bb2d 100644 --- a/lib/ansible/parsing/yaml/__init__.py +++ b/lib/ansible/parsing/yaml/__init__.py @@ -1,18 +0,0 @@ -# (c) 2012-2014, Michael DeHaan -# -# This file is part of Ansible -# -# Ansible is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. 
-# -# Ansible is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with Ansible. If not, see . - -from __future__ import annotations diff --git a/lib/ansible/parsing/yaml/constructor.py b/lib/ansible/parsing/yaml/constructor.py deleted file mode 100644 index 300dad38ca9..00000000000 --- a/lib/ansible/parsing/yaml/constructor.py +++ /dev/null @@ -1,178 +0,0 @@ -# (c) 2012-2014, Michael DeHaan -# -# This file is part of Ansible -# -# Ansible is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# Ansible is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with Ansible. If not, see . - -from __future__ import annotations - -from yaml.constructor import SafeConstructor, ConstructorError -from yaml.nodes import MappingNode - -from ansible import constants as C -from ansible.module_utils.common.text.converters import to_bytes, to_native -from ansible.parsing.yaml.objects import AnsibleMapping, AnsibleSequence, AnsibleUnicode, AnsibleVaultEncryptedUnicode -from ansible.parsing.vault import VaultLib -from ansible.utils.display import Display -from ansible.utils.unsafe_proxy import wrap_var - -display = Display() - - -class AnsibleConstructor(SafeConstructor): - def __init__(self, file_name=None, vault_secrets=None): - self._ansible_file_name = file_name - super(AnsibleConstructor, self).__init__() - self._vaults = {} - self.vault_secrets = vault_secrets or [] - self._vaults['default'] = VaultLib(secrets=self.vault_secrets) - - def construct_yaml_map(self, node): - data = AnsibleMapping() - yield data - value = self.construct_mapping(node) - data.update(value) - data.ansible_pos = self._node_position_info(node) - - def construct_mapping(self, node, deep=False): - # Most of this is from yaml.constructor.SafeConstructor. We replicate - # it here so that we can warn users when they have duplicate dict keys - # (pyyaml silently allows overwriting keys) - if not isinstance(node, MappingNode): - raise ConstructorError(None, None, - "expected a mapping node, but found %s" % node.id, - node.start_mark) - self.flatten_mapping(node) - mapping = AnsibleMapping() - - # Add our extra information to the returned value - mapping.ansible_pos = self._node_position_info(node) - - for key_node, value_node in node.value: - key = self.construct_object(key_node, deep=deep) - try: - hash(key) - except TypeError as exc: - raise ConstructorError("while constructing a mapping", node.start_mark, - "found unacceptable key (%s)" % exc, key_node.start_mark) - - if key in mapping: - msg = (u'While constructing a mapping from {1}, line {2}, column {3}, found a duplicate dict key ({0}).' 
- u' Using last defined value only.'.format(key, *mapping.ansible_pos)) - if C.DUPLICATE_YAML_DICT_KEY == 'warn': - display.warning(msg) - elif C.DUPLICATE_YAML_DICT_KEY == 'error': - raise ConstructorError(context=None, context_mark=None, - problem=to_native(msg), - problem_mark=node.start_mark, - note=None) - else: - # when 'ignore' - display.debug(msg) - - value = self.construct_object(value_node, deep=deep) - mapping[key] = value - - return mapping - - def construct_yaml_str(self, node): - # Override the default string handling function - # to always return unicode objects - value = self.construct_scalar(node) - ret = AnsibleUnicode(value) - - ret.ansible_pos = self._node_position_info(node) - - return ret - - def construct_vault_encrypted_unicode(self, node): - value = self.construct_scalar(node) - b_ciphertext_data = to_bytes(value) - # could pass in a key id here to choose the vault to associate with - # TODO/FIXME: plugin vault selector - vault = self._vaults['default'] - if vault.secrets is None: - raise ConstructorError(context=None, context_mark=None, - problem="found !vault but no vault password provided", - problem_mark=node.start_mark, - note=None) - ret = AnsibleVaultEncryptedUnicode(b_ciphertext_data) - ret.vault = vault - ret.ansible_pos = self._node_position_info(node) - return ret - - def construct_yaml_seq(self, node): - data = AnsibleSequence() - yield data - data.extend(self.construct_sequence(node)) - data.ansible_pos = self._node_position_info(node) - - def construct_yaml_unsafe(self, node): - try: - constructor = getattr(node, 'id', 'object') - if constructor is not None: - constructor = getattr(self, 'construct_%s' % constructor) - except AttributeError: - constructor = self.construct_object - - value = constructor(node) - - return wrap_var(value) - - def _node_position_info(self, node): - # the line number where the previous token has ended (plus empty lines) - # Add one so that the first line is line 1 rather than line 0 - column = node.start_mark.column + 1 - line = node.start_mark.line + 1 - - # in some cases, we may have pre-read the data and then - # passed it to the load() call for YAML, in which case we - # want to override the default datasource (which would be - # '') to the actual filename we read in - datasource = self._ansible_file_name or node.start_mark.name - - return (datasource, line, column) - - -AnsibleConstructor.add_constructor( - u'tag:yaml.org,2002:map', - AnsibleConstructor.construct_yaml_map) # type: ignore[type-var] - -AnsibleConstructor.add_constructor( - u'tag:yaml.org,2002:python/dict', - AnsibleConstructor.construct_yaml_map) # type: ignore[type-var] - -AnsibleConstructor.add_constructor( - u'tag:yaml.org,2002:str', - AnsibleConstructor.construct_yaml_str) # type: ignore[type-var] - -AnsibleConstructor.add_constructor( - u'tag:yaml.org,2002:python/unicode', - AnsibleConstructor.construct_yaml_str) # type: ignore[type-var] - -AnsibleConstructor.add_constructor( - u'tag:yaml.org,2002:seq', - AnsibleConstructor.construct_yaml_seq) # type: ignore[type-var] - -AnsibleConstructor.add_constructor( - u'!unsafe', - AnsibleConstructor.construct_yaml_unsafe) # type: ignore[type-var] - -AnsibleConstructor.add_constructor( - u'!vault', - AnsibleConstructor.construct_vault_encrypted_unicode) # type: ignore[type-var] - -AnsibleConstructor.add_constructor( - u'!vault-encrypted', - AnsibleConstructor.construct_vault_encrypted_unicode) # type: ignore[type-var] diff --git a/lib/ansible/parsing/yaml/dumper.py b/lib/ansible/parsing/yaml/dumper.py index 
4888e4fd10c..c51ac605e3f 100644 --- a/lib/ansible/parsing/yaml/dumper.py +++ b/lib/ansible/parsing/yaml/dumper.py @@ -1,120 +1,10 @@ -# (c) 2012-2014, Michael DeHaan -# -# This file is part of Ansible -# -# Ansible is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# Ansible is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with Ansible. If not, see . +from __future__ import annotations as _annotations -from __future__ import annotations +import typing as _t -import yaml +from ansible._internal._yaml import _dumper -from ansible.module_utils.six import text_type, binary_type -from ansible.module_utils.common.yaml import SafeDumper -from ansible.parsing.yaml.objects import AnsibleUnicode, AnsibleSequence, AnsibleMapping, AnsibleVaultEncryptedUnicode -from ansible.utils.unsafe_proxy import AnsibleUnsafeText, AnsibleUnsafeBytes, NativeJinjaUnsafeText, NativeJinjaText -from ansible.template import AnsibleUndefined -from ansible.vars.hostvars import HostVars, HostVarsVars -from ansible.vars.manager import VarsWithSources - -class AnsibleDumper(SafeDumper): - """ - A simple stub class that allows us to add representers - for our overridden object types. - """ - - -def represent_hostvars(self, data): - return self.represent_dict(dict(data)) - - -# Note: only want to represent the encrypted data -def represent_vault_encrypted_unicode(self, data): - return self.represent_scalar(u'!vault', data._ciphertext.decode(), style='|') - - -def represent_unicode(self, data): - return yaml.representer.SafeRepresenter.represent_str(self, text_type(data)) - - -def represent_binary(self, data): - return yaml.representer.SafeRepresenter.represent_binary(self, binary_type(data)) - - -def represent_undefined(self, data): - # Here bool will ensure _fail_with_undefined_error happens - # if the value is Undefined. 
- # This happens because Jinja sets __bool__ on StrictUndefined - return bool(data) - - -AnsibleDumper.add_representer( - AnsibleUnicode, - represent_unicode, -) - -AnsibleDumper.add_representer( - AnsibleUnsafeText, - represent_unicode, -) - -AnsibleDumper.add_representer( - AnsibleUnsafeBytes, - represent_binary, -) - -AnsibleDumper.add_representer( - HostVars, - represent_hostvars, -) - -AnsibleDumper.add_representer( - HostVarsVars, - represent_hostvars, -) - -AnsibleDumper.add_representer( - VarsWithSources, - represent_hostvars, -) - -AnsibleDumper.add_representer( - AnsibleSequence, - yaml.representer.SafeRepresenter.represent_list, -) - -AnsibleDumper.add_representer( - AnsibleMapping, - yaml.representer.SafeRepresenter.represent_dict, -) - -AnsibleDumper.add_representer( - AnsibleVaultEncryptedUnicode, - represent_vault_encrypted_unicode, -) - -AnsibleDumper.add_representer( - AnsibleUndefined, - represent_undefined, -) - -AnsibleDumper.add_representer( - NativeJinjaUnsafeText, - represent_unicode, -) - -AnsibleDumper.add_representer( - NativeJinjaText, - represent_unicode, -) +def AnsibleDumper(*args, **kwargs) -> _t.Any: + """Compatibility factory function; returns an Ansible YAML dumper instance.""" + return _dumper.AnsibleDumper(*args, **kwargs) diff --git a/lib/ansible/parsing/yaml/loader.py b/lib/ansible/parsing/yaml/loader.py index b9bd3e1c6e3..ee878b9fca1 100644 --- a/lib/ansible/parsing/yaml/loader.py +++ b/lib/ansible/parsing/yaml/loader.py @@ -1,43 +1,10 @@ -# (c) 2012-2014, Michael DeHaan -# -# This file is part of Ansible -# -# Ansible is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# Ansible is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with Ansible. If not, see . 
+from __future__ import annotations as _annotations -from __future__ import annotations +import typing as _t -from yaml.resolver import Resolver +from ansible._internal._yaml import _loader -from ansible.parsing.yaml.constructor import AnsibleConstructor -from ansible.module_utils.common.yaml import HAS_LIBYAML, Parser -if HAS_LIBYAML: - class AnsibleLoader(Parser, AnsibleConstructor, Resolver): # type: ignore[misc] # pylint: disable=inconsistent-mro - def __init__(self, stream, file_name=None, vault_secrets=None): - Parser.__init__(self, stream) - AnsibleConstructor.__init__(self, file_name=file_name, vault_secrets=vault_secrets) - Resolver.__init__(self) -else: - from yaml.composer import Composer - from yaml.reader import Reader - from yaml.scanner import Scanner - - class AnsibleLoader(Reader, Scanner, Parser, Composer, AnsibleConstructor, Resolver): # type: ignore[misc,no-redef] # pylint: disable=inconsistent-mro - def __init__(self, stream, file_name=None, vault_secrets=None): - Reader.__init__(self, stream) - Scanner.__init__(self) - Parser.__init__(self) - Composer.__init__(self) - AnsibleConstructor.__init__(self, file_name=file_name, vault_secrets=vault_secrets) - Resolver.__init__(self) +def AnsibleLoader(*args, **kwargs) -> _t.Any: + """Compatibility factory function; returns an Ansible YAML loader instance.""" + return _loader.AnsibleLoader(*args, **kwargs) diff --git a/lib/ansible/parsing/yaml/objects.py index f3ebcb8fc07..d8d6a2a646d 100644 --- a/lib/ansible/parsing/yaml/objects.py +++ b/lib/ansible/parsing/yaml/objects.py @@ -1,359 +1,56 @@ -# (c) 2012-2014, Michael DeHaan -# -# This file is part of Ansible -# -# Ansible is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# Ansible is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with Ansible. If not, see . +"""Backwards compatibility types, which will be deprecated in a future release.
Do not use these in new code.""" -from __future__ import annotations +from __future__ import annotations as _annotations -import sys as _sys +import typing as _t -from collections.abc import Sequence +from ansible.module_utils._internal import _datatag +from ansible.module_utils.common.text import converters as _converters +from ansible.parsing import vault as _vault -from ansible.module_utils.six import text_type -from ansible.module_utils.common.text.converters import to_bytes, to_text, to_native +class _AnsibleMapping(dict): + """Backwards compatibility type.""" -class AnsibleBaseYAMLObject(object): - """ - the base class used to sub-class python built-in objects - so that we can add attributes to them during yaml parsing + def __new__(cls, value): + return _datatag.AnsibleTagHelper.tag_copy(value, dict(value)) - """ - _data_source = None - _line_number = 0 - _column_number = 0 - def _get_ansible_position(self): - return (self._data_source, self._line_number, self._column_number) +class _AnsibleUnicode(str): + """Backwards compatibility type.""" - def _set_ansible_position(self, obj): - try: - (src, line, col) = obj - except (TypeError, ValueError): - raise AssertionError( - 'ansible_pos can only be set with a tuple/list ' - 'of three values: source, line number, column number' - ) - self._data_source = src - self._line_number = line - self._column_number = col + def __new__(cls, value): + return _datatag.AnsibleTagHelper.tag_copy(value, str(value)) - ansible_pos = property(_get_ansible_position, _set_ansible_position) +class _AnsibleSequence(list): + """Backwards compatibility type.""" -class AnsibleMapping(AnsibleBaseYAMLObject, dict): - """ sub class for dictionaries """ - pass + def __new__(cls, value): + return _datatag.AnsibleTagHelper.tag_copy(value, list(value)) -class AnsibleUnicode(AnsibleBaseYAMLObject, text_type): - """ sub class for unicode objects """ - pass +class _AnsibleVaultEncryptedUnicode: + """Backwards compatibility type.""" + def __new__(cls, ciphertext: str | bytes): + encrypted_string = _vault.EncryptedString(ciphertext=_converters.to_text(_datatag.AnsibleTagHelper.untag(ciphertext))) -class AnsibleSequence(AnsibleBaseYAMLObject, list): - """ sub class for lists """ - pass + return _datatag.AnsibleTagHelper.tag_copy(ciphertext, encrypted_string) -class AnsibleVaultEncryptedUnicode(Sequence, AnsibleBaseYAMLObject): - """Unicode like object that is not evaluated (decrypted) until it needs to be""" - __UNSAFE__ = True - __ENCRYPTED__ = True - yaml_tag = u'!vault' +def __getattr__(name: str) -> _t.Any: + """Inject import-time deprecation warnings.""" + if (value := globals().get(f'_{name}', None)) and name.startswith('Ansible'): + # deprecated: description='enable deprecation of everything in this module', core_version='2.23' + # from ansible.utils.display import Display + # + # Display().deprecated( + # msg=f"Importing {name!r} is deprecated.", + # help_text="Instances of this type cannot be created and will not be encountered.", + # version="2.27", + # ) - @classmethod - def from_plaintext(cls, seq, vault, secret): - if not vault: - raise vault.AnsibleVaultError('Error creating AnsibleVaultEncryptedUnicode, invalid vault (%s) provided' % vault) + return value - ciphertext = vault.encrypt(seq, secret) - avu = cls(ciphertext) - avu.vault = vault - return avu - - def __init__(self, ciphertext): - """A AnsibleUnicode with a Vault attribute that can decrypt it. - - ciphertext is a byte string (str on PY2, bytestring on PY3). 
- - The .data attribute is a property that returns the decrypted plaintext - of the ciphertext as a PY2 unicode or PY3 string object. - """ - super(AnsibleVaultEncryptedUnicode, self).__init__() - - # after construction, calling code has to set the .vault attribute to a vaultlib object - self.vault = None - self._ciphertext = to_bytes(ciphertext) - - @property - def data(self): - if not self.vault: - return to_text(self._ciphertext) - return to_text(self.vault.decrypt(self._ciphertext, obj=self)) - - @data.setter - def data(self, value): - self._ciphertext = to_bytes(value) - - def is_encrypted(self): - return self.vault and self.vault.is_encrypted(self._ciphertext) - - def __eq__(self, other): - if self.vault: - return other == self.data - return False - - def __ne__(self, other): - if self.vault: - return other != self.data - return True - - def __reversed__(self): - # This gets inherited from ``collections.Sequence`` which returns a generator - # make this act more like the string implementation - return to_text(self[::-1], errors='surrogate_or_strict') - - def __str__(self): - return to_native(self.data, errors='surrogate_or_strict') - - def __unicode__(self): - return to_text(self.data, errors='surrogate_or_strict') - - def encode(self, encoding=None, errors=None): - return to_bytes(self.data, encoding=encoding, errors=errors) - - # Methods below are a copy from ``collections.UserString`` - # Some are copied as is, where others are modified to not - # auto wrap with ``self.__class__`` - def __repr__(self): - return repr(self.data) - - def __int__(self, base=10): - return int(self.data, base=base) - - def __float__(self): - return float(self.data) - - def __complex__(self): - return complex(self.data) - - def __hash__(self): - return hash(self.data) - - # This breaks vault, do not define it, we cannot satisfy this - # def __getnewargs__(self): - # return (self.data[:],) - - def __lt__(self, string): - if isinstance(string, AnsibleVaultEncryptedUnicode): - return self.data < string.data - return self.data < string - - def __le__(self, string): - if isinstance(string, AnsibleVaultEncryptedUnicode): - return self.data <= string.data - return self.data <= string - - def __gt__(self, string): - if isinstance(string, AnsibleVaultEncryptedUnicode): - return self.data > string.data - return self.data > string - - def __ge__(self, string): - if isinstance(string, AnsibleVaultEncryptedUnicode): - return self.data >= string.data - return self.data >= string - - def __contains__(self, char): - if isinstance(char, AnsibleVaultEncryptedUnicode): - char = char.data - return char in self.data - - def __len__(self): - return len(self.data) - - def __getitem__(self, index): - return self.data[index] - - def __getslice__(self, start, end): - start = max(start, 0) - end = max(end, 0) - return self.data[start:end] - - def __add__(self, other): - if isinstance(other, AnsibleVaultEncryptedUnicode): - return self.data + other.data - elif isinstance(other, text_type): - return self.data + other - return self.data + to_text(other) - - def __radd__(self, other): - if isinstance(other, text_type): - return other + self.data - return to_text(other) + self.data - - def __mul__(self, n): - return self.data * n - - __rmul__ = __mul__ - - def __mod__(self, args): - return self.data % args - - def __rmod__(self, template): - return to_text(template) % self - - # the following methods are defined in alphabetical order: - def capitalize(self): - return self.data.capitalize() - - def casefold(self): - return 
self.data.casefold() - - def center(self, width, *args): - return self.data.center(width, *args) - - def count(self, sub, start=0, end=_sys.maxsize): - if isinstance(sub, AnsibleVaultEncryptedUnicode): - sub = sub.data - return self.data.count(sub, start, end) - - def endswith(self, suffix, start=0, end=_sys.maxsize): - return self.data.endswith(suffix, start, end) - - def expandtabs(self, tabsize=8): - return self.data.expandtabs(tabsize) - - def find(self, sub, start=0, end=_sys.maxsize): - if isinstance(sub, AnsibleVaultEncryptedUnicode): - sub = sub.data - return self.data.find(sub, start, end) - - def format(self, *args, **kwds): - return self.data.format(*args, **kwds) - - def format_map(self, mapping): - return self.data.format_map(mapping) - - def index(self, sub, start=0, end=_sys.maxsize): - return self.data.index(sub, start, end) - - def isalpha(self): - return self.data.isalpha() - - def isalnum(self): - return self.data.isalnum() - - def isascii(self): - return self.data.isascii() - - def isdecimal(self): - return self.data.isdecimal() - - def isdigit(self): - return self.data.isdigit() - - def isidentifier(self): - return self.data.isidentifier() - - def islower(self): - return self.data.islower() - - def isnumeric(self): - return self.data.isnumeric() - - def isprintable(self): - return self.data.isprintable() - - def isspace(self): - return self.data.isspace() - - def istitle(self): - return self.data.istitle() - - def isupper(self): - return self.data.isupper() - - def join(self, seq): - return self.data.join(seq) - - def ljust(self, width, *args): - return self.data.ljust(width, *args) - - def lower(self): - return self.data.lower() - - def lstrip(self, chars=None): - return self.data.lstrip(chars) - - maketrans = str.maketrans - - def partition(self, sep): - return self.data.partition(sep) - - def replace(self, old, new, maxsplit=-1): - if isinstance(old, AnsibleVaultEncryptedUnicode): - old = old.data - if isinstance(new, AnsibleVaultEncryptedUnicode): - new = new.data - return self.data.replace(old, new, maxsplit) - - def rfind(self, sub, start=0, end=_sys.maxsize): - if isinstance(sub, AnsibleVaultEncryptedUnicode): - sub = sub.data - return self.data.rfind(sub, start, end) - - def rindex(self, sub, start=0, end=_sys.maxsize): - return self.data.rindex(sub, start, end) - - def rjust(self, width, *args): - return self.data.rjust(width, *args) - - def rpartition(self, sep): - return self.data.rpartition(sep) - - def rstrip(self, chars=None): - return self.data.rstrip(chars) - - def split(self, sep=None, maxsplit=-1): - return self.data.split(sep, maxsplit) - - def rsplit(self, sep=None, maxsplit=-1): - return self.data.rsplit(sep, maxsplit) - - def splitlines(self, keepends=False): - return self.data.splitlines(keepends) - - def startswith(self, prefix, start=0, end=_sys.maxsize): - return self.data.startswith(prefix, start, end) - - def strip(self, chars=None): - return self.data.strip(chars) - - def swapcase(self): - return self.data.swapcase() - - def title(self): - return self.data.title() - - def translate(self, *args): - return self.data.translate(*args) - - def upper(self): - return self.data.upper() - - def zfill(self, width): - return self.data.zfill(width) + raise AttributeError(f'module {__name__!r} has no attribute {name!r}') diff --git a/lib/ansible/playbook/__init__.py b/lib/ansible/playbook/__init__.py index e125df1ba9a..3f28654cced 100644 --- a/lib/ansible/playbook/__init__.py +++ b/lib/ansible/playbook/__init__.py @@ -66,7 +66,7 @@ class Playbook: 
self._file_name = file_name try: - ds = self._loader.load_from_file(os.path.basename(file_name)) + ds = self._loader.load_from_file(os.path.basename(file_name), trusted_as_template=True) except UnicodeDecodeError as e: raise AnsibleParserError("Could not read playbook (%s) due to encoding issues: %s" % (file_name, to_native(e))) diff --git a/lib/ansible/playbook/attribute.py b/lib/ansible/playbook/attribute.py index ee797c27ef4..3dbbef555ba 100644 --- a/lib/ansible/playbook/attribute.py +++ b/lib/ansible/playbook/attribute.py @@ -17,7 +17,12 @@ from __future__ import annotations -from ansible.module_utils.common.sentinel import Sentinel +import typing as t + +from ansible.utils.sentinel import Sentinel + +if t.TYPE_CHECKING: + from ansible.playbook.base import FieldAttributeBase _CONTAINERS = frozenset(('list', 'dict', 'set')) @@ -105,7 +110,7 @@ class Attribute: def __ge__(self, other): return other.priority >= self.priority - def __get__(self, obj, obj_type=None): + def __get__(self, obj: FieldAttributeBase, obj_type=None): method = f'_get_attr_{self.name}' if hasattr(obj, method): # NOTE this appears to be not used in the codebase, @@ -127,7 +132,7 @@ class Attribute: return value - def __set__(self, obj, value): + def __set__(self, obj: FieldAttributeBase, value): setattr(obj, f'_{self.name}', value) if self.alias is not None: setattr(obj, f'_{self.alias}', value) @@ -180,7 +185,7 @@ class FieldAttribute(Attribute): class ConnectionFieldAttribute(FieldAttribute): def __get__(self, obj, obj_type=None): - from ansible.module_utils.compat.paramiko import paramiko + from ansible.module_utils.compat.paramiko import _paramiko as paramiko from ansible.utils.ssh_functions import check_for_controlpersist value = super().__get__(obj, obj_type) diff --git a/lib/ansible/playbook/base.py b/lib/ansible/playbook/base.py index a762548fddf..890401654d5 100644 --- a/lib/ansible/playbook/base.py +++ b/lib/ansible/playbook/base.py @@ -9,14 +9,16 @@ import itertools import operator import os +import typing as t + from copy import copy as shallowcopy from functools import cache -from jinja2.exceptions import UndefinedError - from ansible import constants as C from ansible import context -from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleAssertionError +from ansible.errors import AnsibleError, AnsibleParserError, AnsibleAssertionError, AnsibleValueOmittedError, AnsibleFieldAttributeError +from ansible.module_utils.datatag import native_type_name +from ansible._internal._datatag._tags import Origin from ansible.module_utils.six import string_types from ansible.module_utils.parsing.convert_bool import boolean from ansible.module_utils.common.sentinel import Sentinel @@ -26,7 +28,8 @@ from ansible.playbook.attribute import Attribute, FieldAttribute, ConnectionFiel from ansible.plugins.loader import module_loader, action_loader from ansible.utils.collection_loader._collection_finder import _get_collection_metadata, AnsibleCollectionRef from ansible.utils.display import Display -from ansible.utils.vars import combine_vars, isidentifier, get_unique_id +from ansible.utils.vars import combine_vars, get_unique_id, validate_variable_name +from ansible._internal._templating._engine import TemplateEngine display = Display() @@ -96,12 +99,13 @@ class FieldAttributeBase: fattributes[attr.alias] = attr return fattributes - def __init__(self): + def __init__(self) -> None: # initialize the data loader and variable manager, which will be provided # later when the object is actually 
loaded self._loader = None self._variable_manager = None + self._origin: Origin | None = None # other internal params self._validated = False @@ -111,9 +115,6 @@ class FieldAttributeBase: # every object gets a random uuid: self._uuid = get_unique_id() - # init vars, avoid using defaults in field declaration as it lives across plays - self.vars = dict() - @property def finalized(self): return self._finalized @@ -148,6 +149,7 @@ class FieldAttributeBase: # the variable manager class is used to manage and merge variables # down to a single dictionary for reference in templating, etc. self._variable_manager = variable_manager + self._origin = Origin.get_tag(ds) # the data loader class is used to parse data from strings and files if loader is not None: @@ -191,7 +193,11 @@ class FieldAttributeBase: return self._variable_manager def _post_validate_debugger(self, attr, value, templar): - value = templar.template(value) + try: + value = templar.template(value) + except AnsibleValueOmittedError: + value = self.set_to_context(attr.name) + valid_values = frozenset(('always', 'on_failed', 'on_unreachable', 'on_skipped', 'never')) if value and isinstance(value, string_types) and value not in valid_values: raise AnsibleParserError("'%s' is not a valid value for debugger. Must be one of %s" % (value, ', '.join(valid_values)), obj=self.get_ds()) @@ -206,7 +212,7 @@ class FieldAttributeBase: valid_attrs = frozenset(self.fattributes) for key in ds: if key not in valid_attrs: - raise AnsibleParserError("'%s' is not a valid attribute for a %s" % (key, self.__class__.__name__), obj=ds) + raise AnsibleParserError("'%s' is not a valid attribute for a %s" % (key, self.__class__.__name__), obj=key) def validate(self, all_vars=None): """ validation that is done at parse time, not load time """ @@ -244,7 +250,8 @@ class FieldAttributeBase: raise AnsibleParserError( "The field 'module_defaults' is supposed to be a dictionary or list of dictionaries, " "the keys of which must be static action, module, or group names. Only the values may contain " - "templates. For example: {'ping': \"{{ ping_defaults }}\"}" + "templates. For example: {'ping': \"{{ ping_defaults }}\"}", + obj=defaults_dict, ) validated_defaults_dict = {} @@ -419,14 +426,15 @@ class FieldAttributeBase: try: new_me = self.__class__() - except RuntimeError as e: - raise AnsibleError("Exceeded maximum object depth. This may have been caused by excessive role recursion", orig_exc=e) + except RecursionError as ex: + raise AnsibleError("Exceeded maximum object depth. 
This may have been caused by excessive role recursion.") from ex for name in self.fattributes: setattr(new_me, name, shallowcopy(getattr(self, f'_{name}', Sentinel))) new_me._loader = self._loader new_me._variable_manager = self._variable_manager + new_me._origin = self._origin new_me._validated = self._validated new_me._finalized = self._finalized new_me._uuid = self._uuid @@ -438,6 +446,12 @@ class FieldAttributeBase: return new_me def get_validated_value(self, name, attribute, value, templar): + try: + return self._get_validated_value(name, attribute, value, templar) + except (TypeError, ValueError): + raise AnsibleError(f"The value {value!r} could not be converted to {attribute.isa!r}.", obj=value) + + def _get_validated_value(self, name, attribute, value, templar): if attribute.isa == 'string': value = to_text(value) elif attribute.isa == 'int': @@ -466,28 +480,23 @@ class FieldAttributeBase: if attribute.listof is not None: for item in value: if not isinstance(item, attribute.listof): - raise AnsibleParserError("the field '%s' should be a list of %s, " - "but the item '%s' is a %s" % (name, attribute.listof, item, type(item)), obj=self.get_ds()) - elif attribute.required and attribute.listof == string_types: + type_names = ' or '.join(f'{native_type_name(attribute_type)!r}' for attribute_type in attribute.listof) + + raise AnsibleParserError( + message=f"Keyword {name!r} items must be of type {type_names}, not {native_type_name(item)!r}.", + obj=Origin.first_tagged_on(item, value, self.get_ds()), + ) + elif attribute.required and attribute.listof == (str,): if item is None or item.strip() == "": - raise AnsibleParserError("the field '%s' is required, and cannot have empty values" % (name,), obj=self.get_ds()) - elif attribute.isa == 'set': - if value is None: - value = set() - elif not isinstance(value, (list, set)): - if isinstance(value, string_types): - value = value.split(',') - else: - # Making a list like this handles strings of - # text and bytes properly - value = [value] - if not isinstance(value, set): - value = set(value) + raise AnsibleParserError( + message=f"Keyword {name!r} is required, and cannot have empty values.", + obj=Origin.first_tagged_on(item, value, self.get_ds()), + ) elif attribute.isa == 'dict': if value is None: value = dict() elif not isinstance(value, dict): - raise TypeError("%s is not a dictionary" % value) + raise AnsibleError(f"{value!r} is not a dictionary") elif attribute.isa == 'class': if not isinstance(value, attribute.class_type): raise TypeError("%s is not a valid %s (got a %s instead)" % (name, attribute.class_type, type(value))) @@ -496,19 +505,22 @@ class FieldAttributeBase: raise AnsibleAssertionError(f"Unknown value for attribute.isa: {attribute.isa}") return value - def set_to_context(self, name): + def set_to_context(self, name: str) -> t.Any: """ set to parent inherited value or Sentinel as appropriate""" attribute = self.fattributes[name] if isinstance(attribute, NonInheritableFieldAttribute): # setting to sentinel will trigger 'default/default()' on getter - setattr(self, name, Sentinel) + value = Sentinel else: try: - setattr(self, name, self._get_parent_attribute(name, omit=True)) + value = self._get_parent_attribute(name, omit=True) except AttributeError: # mostly playcontext as only tasks/handlers/blocks really resolve parent - setattr(self, name, Sentinel) + value = Sentinel + + setattr(self, name, value) + return value def post_validate(self, templar): """ @@ -517,91 +529,101 @@ class FieldAttributeBase: any _post_validate_ 
functions. """ - # save the omit value for later checking - omit_value = templar.available_variables.get('omit') + for name in self.fattributes: + value = self.post_validate_attribute(name, templar=templar) - for (name, attribute) in self.fattributes.items(): - if attribute.static: - value = getattr(self, name) + if value is not Sentinel: + # and assign the massaged value back to the attribute field + setattr(self, name, value) - # we don't template 'vars' but allow template as values for later use - if name not in ('vars',) and templar.is_template(value): - display.warning('"%s" is not templatable, but we found: %s, ' - 'it will not be templated and will be used "as is".' % (name, value)) - continue + self._finalized = True - if getattr(self, name) is None: - if not attribute.required: - continue - else: - raise AnsibleParserError("the field '%s' is required but was not set" % name) - elif not attribute.always_post_validate and self.__class__.__name__ not in ('Task', 'Handler', 'PlayContext'): - # Intermediate objects like Play() won't have their fields validated by - # default, as their values are often inherited by other objects and validated - # later, so we don't want them to fail out early - continue + def post_validate_attribute(self, name: str, *, templar: TemplateEngine): + attribute: FieldAttribute = self.fattributes[name] - try: - # Run the post-validator if present. These methods are responsible for - # using the given templar to template the values, if required. - method = getattr(self, '_post_validate_%s' % name, None) - if method: - value = method(attribute, getattr(self, name), templar) - elif attribute.isa == 'class': - value = getattr(self, name) - else: + # DTFIX-FUTURE: this can probably be used in many getattr cases below, but the value may be out-of-date in some cases + original_value = getattr(self, name) # we save this original (likely Origin-tagged) value to pass as `obj` for errors + + if attribute.static: + value = getattr(self, name) + + # we don't template 'vars' but allow template as values for later use + if name not in ('vars',) and templar.is_template(value): + display.warning('"%s" is not templatable, but we found: %s, ' + 'it will not be templated and will be used "as is".' % (name, value)) + return Sentinel + + if getattr(self, name) is None: + if not attribute.required: + return Sentinel + + raise AnsibleFieldAttributeError(f'The field {name!r} is required but was not set.', obj=self.get_ds()) + + from .role_include import IncludeRole + + if not attribute.always_post_validate and isinstance(self, IncludeRole) and self.statically_loaded: # import_role + # normal field attributes should not go through post validation on import_role/import_tasks + # only import_role is checked here because import_tasks never reaches this point + return Sentinel + + # FIXME: compare types, not strings + if not attribute.always_post_validate and self.__class__.__name__ not in ('Task', 'Handler', 'PlayContext', 'IncludeRole', 'TaskInclude'): + # Intermediate objects like Play() won't have their fields validated by + # default, as their values are often inherited by other objects and validated + # later, so we don't want them to fail out early + return Sentinel + + try: + # Run the post-validator if present. These methods are responsible for + # using the given templar to template the values, if required. 
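For comparison, a hypothetical `_post_validate_<name>` hook matching the dispatch below; the field name `banner` is invented, while `_post_validate_debugger` earlier in this diff is the real model:

from ansible.errors import AnsibleValueOmittedError

def _post_validate_banner(self, attr, value, templar):
    # Hypothetical hook: template the value, falling back to the inherited or
    # default value when the template resolves to omit.
    try:
        return templar.template(value)
    except AnsibleValueOmittedError:
        return self.set_to_context(attr.name)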
+ method = getattr(self, '_post_validate_%s' % name, None) + + if method: + value = method(attribute, getattr(self, name), templar) + elif attribute.isa == 'class': + value = getattr(self, name) + else: + try: # if the attribute contains a variable, template it now value = templar.template(getattr(self, name)) + except AnsibleValueOmittedError: + # If this evaluated to the omit value, set the value back to inherited by context + # or default specified in the FieldAttribute and move on + value = self.set_to_context(name) - # If this evaluated to the omit value, set the value back to inherited by context - # or default specified in the FieldAttribute and move on - if omit_value is not None and value == omit_value: - self.set_to_context(name) - continue + if value is Sentinel: + return value - # and make sure the attribute is of the type it should be - if value is not None: - value = self.get_validated_value(name, attribute, value, templar) + # and make sure the attribute is of the type it should be + if value is not None: + value = self.get_validated_value(name, attribute, value, templar) - # and assign the massaged value back to the attribute field - setattr(self, name, value) - except (TypeError, ValueError) as e: - value = getattr(self, name) - raise AnsibleParserError(f"the field '{name}' has an invalid value ({value!r}), and could not be converted to {attribute.isa}.", - obj=self.get_ds(), orig_exc=e) - except (AnsibleUndefinedVariable, UndefinedError) as e: - if templar._fail_on_undefined_errors and name != 'name': - if name == 'args': - msg = "The task includes an option with an undefined variable." - else: - msg = f"The field '{name}' has an invalid value, which includes an undefined variable." - raise AnsibleParserError(msg, obj=self.get_ds(), orig_exc=e) + # returning the value results in assigning the massaged value back to the attribute field + return value + except Exception as ex: + if name == 'args': + raise # no useful information to contribute, raise the original exception - self._finalized = True + raise AnsibleFieldAttributeError(f'Error processing keyword {name!r}.', obj=original_value) from ex def _load_vars(self, attr, ds): """ Vars in a play must be specified as a dictionary. 
""" - def _validate_variable_keys(ds): - for key in ds: - if not isidentifier(key): - raise TypeError("'%s' is not a valid variable name" % key) - try: if isinstance(ds, dict): - _validate_variable_keys(ds) + for key in ds: + validate_variable_name(key) return combine_vars(self.vars, ds) elif ds is None: return {} else: raise ValueError - except ValueError as e: - raise AnsibleParserError("Vars in a %s must be specified as a dictionary" % self.__class__.__name__, - obj=ds, orig_exc=e) - except TypeError as e: - raise AnsibleParserError("Invalid variable name in vars specified for %s: %s" % (self.__class__.__name__, e), obj=ds, orig_exc=e) + except ValueError as ex: + raise AnsibleParserError(f"Vars in a {self.__class__.__name__} must be specified as a dictionary.", obj=ds) from ex + except TypeError as ex: + raise AnsibleParserError(f"Invalid variable name in vars specified for {self.__class__.__name__}.", obj=ds) from ex def _extend_value(self, value, new_value, prepend=False): """ @@ -654,6 +676,8 @@ class FieldAttributeBase: setattr(self, attr, obj) else: setattr(self, attr, value) + else: + setattr(self, attr, value) # overridden dump_attrs in derived types may dump attributes which are not field attributes # from_attrs is only used to create a finalized task # from attrs from the Worker/TaskExecutor @@ -713,7 +737,7 @@ class Base(FieldAttributeBase): remote_user = FieldAttribute(isa='string', default=context.cliargs_deferred_get('remote_user')) # variables - vars = NonInheritableFieldAttribute(isa='dict', priority=100, static=True) + vars = NonInheritableFieldAttribute(isa='dict', priority=100, static=True, default=dict) # module default params module_defaults = FieldAttribute(isa='list', extend=True, prepend=True) @@ -743,17 +767,43 @@ class Base(FieldAttributeBase): # used to hold sudo/su stuff DEPRECATED_ATTRIBUTES = [] # type: list[str] - def get_path(self): + def update_result_no_log(self, templar: TemplateEngine, result: dict[str, t.Any]) -> None: + """Set the post-validated no_log value for the result, falling back to a default on validation/templating failure with a warning.""" + + if self.finalized: + no_log = self.no_log + else: + try: + no_log = self.post_validate_attribute('no_log', templar=templar) + except Exception as ex: + display.error_as_warning('Invalid no_log value for task, output will be masked.', exception=ex) + no_log = True + + result_no_log = result.get('_ansible_no_log', False) + + if not isinstance(result_no_log, bool): + display.warning(f'Invalid _ansible_no_log value of type {type(result_no_log).__name__!r} in task result, output will be masked.') + no_log = True + + no_log = no_log or result_no_log + + result.update(_ansible_no_log=no_log) + + def get_path(self) -> str: """ return the absolute path of the playbook object and its line number """ + origin = self._origin - path = "" - try: - path = "%s:%s" % (self._ds._data_source, self._ds._line_number) - except AttributeError: + if not origin: try: - path = "%s:%s" % (self._parent._play._ds._data_source, self._parent._play._ds._line_number) + origin = self._parent._play._origin except AttributeError: pass + + if origin and origin.path: + path = f"{origin.path}:{origin.line_num or 1}" + else: + path = "" + return path def get_dep_chain(self): diff --git a/lib/ansible/playbook/block.py b/lib/ansible/playbook/block.py index 464ff3879c5..a47bdc31e45 100644 --- a/lib/ansible/playbook/block.py +++ b/lib/ansible/playbook/block.py @@ -113,6 +113,8 @@ class Block(Base, Conditional, CollectionSearch, Taggable, 
Notifiable, Delegatab return super(Block, self).preprocess_data(ds) + # FIXME: these do nothing but augment the exception message; DRY and nuke + def _load_block(self, attr, ds): try: return load_list_of_tasks( @@ -125,8 +127,8 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab loader=self._loader, use_handlers=self._use_handlers, ) - except AssertionError as e: - raise AnsibleParserError("A malformed block was encountered while loading a block", obj=self._ds, orig_exc=e) + except AssertionError as ex: + raise AnsibleParserError("A malformed block was encountered while loading a block", obj=self._ds) from ex def _load_rescue(self, attr, ds): try: @@ -140,8 +142,8 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab loader=self._loader, use_handlers=self._use_handlers, ) - except AssertionError as e: - raise AnsibleParserError("A malformed block was encountered while loading rescue.", obj=self._ds, orig_exc=e) + except AssertionError as ex: + raise AnsibleParserError("A malformed block was encountered while loading rescue.", obj=self._ds) from ex def _load_always(self, attr, ds): try: @@ -155,8 +157,8 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab loader=self._loader, use_handlers=self._use_handlers, ) - except AssertionError as e: - raise AnsibleParserError("A malformed block was encountered while loading always", obj=self._ds, orig_exc=e) + except AssertionError as ex: + raise AnsibleParserError("A malformed block was encountered while loading always", obj=self._ds) from ex def _validate_always(self, attr, name, value): if value and not self.block: diff --git a/lib/ansible/playbook/collectionsearch.py b/lib/ansible/playbook/collectionsearch.py index c6ab50907bf..d5bc9450ef2 100644 --- a/lib/ansible/playbook/collectionsearch.py +++ b/lib/ansible/playbook/collectionsearch.py @@ -6,11 +6,8 @@ from __future__ import annotations from ansible.module_utils.six import string_types from ansible.playbook.attribute import FieldAttribute from ansible.utils.collection_loader import AnsibleCollectionConfig -from ansible.template import is_template from ansible.utils.display import Display -from jinja2.nativetypes import NativeEnvironment - display = Display() @@ -35,8 +32,7 @@ def _ensure_default_collection(collection_list=None): class CollectionSearch: # this needs to be populated before we can resolve tasks/roles/etc - collections = FieldAttribute(isa='list', listof=string_types, priority=100, default=_ensure_default_collection, - always_post_validate=True, static=True) + collections = FieldAttribute(isa='list', listof=string_types, priority=100, default=_ensure_default_collection, always_post_validate=True, static=True) def _load_collections(self, attr, ds): # We are always a mixin with Base, so we can validate this untemplated @@ -49,14 +45,4 @@ class CollectionSearch: if not ds: # don't return an empty collection list, just return None return None - # This duplicates static attr checking logic from post_validate() - # because if the user attempts to template a collection name, it may - # error before it ever gets to the post_validate() warning (e.g. trying - # to import a role from the collection). - env = NativeEnvironment() - for collection_name in ds: - if is_template(collection_name, env): - display.warning('"collections" is not templatable, but we found: %s, ' - 'it will not be templated and will be used "as is".' 
% (collection_name)) - return ds diff --git a/lib/ansible/playbook/conditional.py b/lib/ansible/playbook/conditional.py index 21a9cf4c17c..ac59259acb3 100644 --- a/lib/ansible/playbook/conditional.py +++ b/lib/ansible/playbook/conditional.py @@ -17,12 +17,7 @@ from __future__ import annotations -import typing as t - -from ansible.errors import AnsibleError, AnsibleUndefinedVariable -from ansible.module_utils.common.text.converters import to_native from ansible.playbook.attribute import FieldAttribute -from ansible.template import Templar from ansible.utils.display import Display display = Display() @@ -36,78 +31,9 @@ class Conditional: when = FieldAttribute(isa='list', default=list, extend=True, prepend=True) - def __init__(self, loader=None): - # when used directly, this class needs a loader, but we want to - # make sure we don't trample on the existing one if this class - # is used as a mix-in with a playbook base class - if not hasattr(self, '_loader'): - if loader is None: - raise AnsibleError("a loader must be specified when using Conditional() directly") - else: - self._loader = loader + def __init__(self, *args, **kwargs): super().__init__() def _validate_when(self, attr, name, value): if not isinstance(value, list): setattr(self, name, [value]) - - def evaluate_conditional(self, templar: Templar, all_vars: dict[str, t.Any]) -> bool: - """ - Loops through the conditionals set on this object, returning - False if any of them evaluate as such. - """ - return self.evaluate_conditional_with_result(templar, all_vars)[0] - - def evaluate_conditional_with_result(self, templar: Templar, all_vars: dict[str, t.Any]) -> tuple[bool, t.Optional[str]]: - """Loops through the conditionals set on this object, returning - False if any of them evaluate as such as well as the condition - that was false. - """ - for conditional in self.when: - if conditional is None or conditional == "": - res = True - elif isinstance(conditional, bool): - res = conditional - else: - try: - res = self._check_conditional(conditional, templar, all_vars) - except AnsibleError as e: - raise AnsibleError( - "The conditional check '%s' failed. The error was: %s" % (to_native(conditional), to_native(e)), - obj=getattr(self, '_ds', None) - ) - - display.debug("Evaluated conditional (%s): %s" % (conditional, res)) - if not res: - return res, conditional - - return True, None - - def _check_conditional(self, conditional: str, templar: Templar, all_vars: dict[str, t.Any]) -> bool: - original = conditional - templar.available_variables = all_vars - try: - if templar.is_template(conditional): - display.warning( - "conditional statements should not include jinja2 " - "templating delimiters such as {{ }} or {%% %%}. " - "Found: %s" % conditional - ) - conditional = templar.template(conditional) - if isinstance(conditional, bool): - return conditional - elif conditional == "": - return False - - # If the result of the first-pass template render (to resolve inline templates) is marked unsafe, - # explicitly disable lookups on the final pass to prevent evaluation of untrusted content in the - # constructed template. - disable_lookups = hasattr(conditional, '__UNSAFE__') - - # NOTE The spaces around True and False are intentional to short-circuit literal_eval for - # jinja2_native=False and avoid its expensive calls. 
- return templar.template( - "{%% if %s %%} True {%% else %%} False {%% endif %%}" % conditional, - disable_lookups=disable_lookups).strip() == "True" - except AnsibleUndefinedVariable as e: - raise AnsibleUndefinedVariable("error while evaluating conditional (%s): %s" % (original, e)) diff --git a/lib/ansible/playbook/helpers.py b/lib/ansible/playbook/helpers.py index 6686d4f2423..f700bb2349a 100644 --- a/lib/ansible/playbook/helpers.py +++ b/lib/ansible/playbook/helpers.py @@ -21,9 +21,9 @@ import os from ansible import constants as C from ansible.errors import AnsibleParserError, AnsibleUndefinedVariable, AnsibleAssertionError -from ansible.module_utils.common.text.converters import to_native from ansible.parsing.mod_args import ModuleArgsParser from ansible.utils.display import Display +from ansible._internal._templating._engine import TemplateEngine display = Display() @@ -92,7 +92,6 @@ def load_list_of_tasks(ds, play, block=None, role=None, task_include=None, use_h from ansible.playbook.task_include import TaskInclude from ansible.playbook.role_include import IncludeRole from ansible.playbook.handler_task_include import HandlerTaskInclude - from ansible.template import Templar if not isinstance(ds, list): raise AnsibleAssertionError('The ds (%s) should be a list but was a %s' % (ds, type(ds))) @@ -105,7 +104,7 @@ def load_list_of_tasks(ds, play, block=None, role=None, task_include=None, use_h if 'block' in task_ds: if use_handlers: raise AnsibleParserError("Using a block as a handler is not supported.", obj=task_ds) - t = Block.load( + task = Block.load( task_ds, play=play, parent_block=block, @@ -115,18 +114,20 @@ def load_list_of_tasks(ds, play, block=None, role=None, task_include=None, use_h variable_manager=variable_manager, loader=loader, ) - task_list.append(t) + task_list.append(task) else: args_parser = ModuleArgsParser(task_ds) try: (action, args, delegate_to) = args_parser.parse(skip_action_validation=True) - except AnsibleParserError as e: + except AnsibleParserError as ex: # if the raises exception was created with obj=ds args, then it includes the detail # so we dont need to add it so we can just re raise. - if e.obj: + if ex.obj: raise # But if it wasn't, we can add the yaml object now to get more detail - raise AnsibleParserError(to_native(e), obj=task_ds, orig_exc=e) + # DTFIX-FUTURE: this *should* be unnecessary- check code coverage. + # Will definitely be unnecessary once we have proper contexts to consult. + raise AnsibleParserError("Error loading tasks.", obj=task_ds) from ex if action in C._ACTION_ALL_INCLUDE_IMPORT_TASKS: @@ -135,7 +136,7 @@ def load_list_of_tasks(ds, play, block=None, role=None, task_include=None, use_h else: include_class = TaskInclude - t = include_class.load( + task = include_class.load( task_ds, block=block, role=role, @@ -144,16 +145,16 @@ def load_list_of_tasks(ds, play, block=None, role=None, task_include=None, use_h loader=loader ) - all_vars = variable_manager.get_vars(play=play, task=t) - templar = Templar(loader=loader, variables=all_vars) + all_vars = variable_manager.get_vars(play=play, task=task) + templar = TemplateEngine(loader=loader, variables=all_vars) # check to see if this include is dynamic or static: if action in C._ACTION_IMPORT_TASKS: - if t.loop is not None: + if task.loop is not None: raise AnsibleParserError("You cannot use loops on 'import_tasks' statements. 
You should use 'include_tasks' instead.", obj=task_ds) # we set a flag to indicate this include was static - t.statically_loaded = True + task.statically_loaded = True # handle relative includes by walking up the list of parent include # tasks and checking the relative result to see if it exists @@ -168,26 +169,14 @@ def load_list_of_tasks(ds, play, block=None, role=None, task_include=None, use_h if not isinstance(parent_include, TaskInclude): parent_include = parent_include._parent continue - try: - parent_include_dir = os.path.dirname(templar.template(parent_include.args.get('_raw_params'))) - except AnsibleUndefinedVariable as e: - if not parent_include.statically_loaded: - raise AnsibleParserError( - "Error when evaluating variable in dynamic parent include path: %s. " - "When using static imports, the parent dynamic include cannot utilize host facts " - "or variables from inventory" % parent_include.args.get('_raw_params'), - obj=task_ds, - suppress_extended_error=True, - orig_exc=e - ) - raise + parent_include_dir = os.path.dirname(parent_include.args.get('_raw_params')) if cumulative_path is None: cumulative_path = parent_include_dir elif not os.path.isabs(cumulative_path): cumulative_path = os.path.join(parent_include_dir, cumulative_path) - include_target = templar.template(t.args['_raw_params']) - if t._role: - new_basedir = os.path.join(t._role._role_path, subdir, cumulative_path) + include_target = templar.template(task.args['_raw_params']) + if task._role: + new_basedir = os.path.join(task._role._role_path, subdir, cumulative_path) include_file = loader.path_dwim_relative(new_basedir, subdir, include_target) else: include_file = loader.path_dwim_relative(loader.get_basedir(), cumulative_path, include_target) @@ -200,22 +189,21 @@ def load_list_of_tasks(ds, play, block=None, role=None, task_include=None, use_h if not found: try: - include_target = templar.template(t.args['_raw_params']) - except AnsibleUndefinedVariable as e: + include_target = templar.template(task.args['_raw_params']) + except AnsibleUndefinedVariable as ex: raise AnsibleParserError( - "Error when evaluating variable in import path: %s.\n\n" - "When using static imports, ensure that any variables used in their names are defined in vars/vars_files\n" + message=f"Error when evaluating variable in import path {task.args['_raw_params']!r}.", + help_text="When using static imports, ensure that any variables used in their names are defined in vars/vars_files\n" "or extra-vars passed in from the command line. Static imports cannot use variables from facts or inventory\n" - "sources like group or host vars." 
% t.args['_raw_params'], + "sources like group or host vars.", obj=task_ds, - suppress_extended_error=True, - orig_exc=e) - if t._role: - include_file = loader.path_dwim_relative(t._role._role_path, subdir, include_target) + ) from ex + if task._role: + include_file = loader.path_dwim_relative(task._role._role_path, subdir, include_target) else: include_file = loader.path_dwim(include_target) - data = loader.load_from_file(include_file) + data = loader.load_from_file(include_file, trusted_as_template=True) if not data: display.warning('file %s is empty and had no tasks to include' % include_file) continue @@ -228,7 +216,7 @@ def load_list_of_tasks(ds, play, block=None, role=None, task_include=None, use_h # nested includes, and we want the include order printed correctly display.vv("statically imported: %s" % include_file) - ti_copy = t.copy(exclude_parent=True) + ti_copy = task.copy(exclude_parent=True) ti_copy._parent = block included_blocks = load_list_of_blocks( data, @@ -246,7 +234,7 @@ def load_list_of_tasks(ds, play, block=None, role=None, task_include=None, use_h # now we extend the tags on each of the included blocks for b in included_blocks: b.tags = list(set(b.tags).union(tags)) - # END FIXME + # FIXME - END # FIXME: handlers shouldn't need this special handling, but do # right now because they don't iterate blocks correctly @@ -256,7 +244,7 @@ def load_list_of_tasks(ds, play, block=None, role=None, task_include=None, use_h else: task_list.extend(included_blocks) else: - task_list.append(t) + task_list.append(task) elif action in C._ACTION_ALL_PROPER_INCLUDE_IMPORT_ROLES: if use_handlers: @@ -280,7 +268,7 @@ def load_list_of_tasks(ds, play, block=None, role=None, task_include=None, use_h # template the role name now, if needed all_vars = variable_manager.get_vars(play=play, task=ir) - templar = Templar(loader=loader, variables=all_vars) + templar = TemplateEngine(loader=loader, variables=all_vars) ir.post_validate(templar=templar) ir._role_name = templar.template(ir._role_name) @@ -292,15 +280,15 @@ def load_list_of_tasks(ds, play, block=None, role=None, task_include=None, use_h task_list.append(ir) else: if use_handlers: - t = Handler.load(task_ds, block=block, role=role, task_include=task_include, variable_manager=variable_manager, loader=loader) - if t.action in C._ACTION_META and t.args.get('_raw_params') == "end_role": - raise AnsibleParserError("Cannot execute 'end_role' from a handler") + task = Handler.load(task_ds, block=block, role=role, task_include=task_include, variable_manager=variable_manager, loader=loader) + if task._get_meta() == "end_role": + raise AnsibleParserError("Cannot execute 'end_role' from a handler", obj=task) else: - t = Task.load(task_ds, block=block, role=role, task_include=task_include, variable_manager=variable_manager, loader=loader) - if t.action in C._ACTION_META and t.args.get('_raw_params') == "end_role" and role is None: - raise AnsibleParserError("Cannot execute 'end_role' from outside of a role") + task = Task.load(task_ds, block=block, role=role, task_include=task_include, variable_manager=variable_manager, loader=loader) + if task._get_meta() == "end_role" and role is None: + raise AnsibleParserError("Cannot execute 'end_role' from outside of a role", obj=task) - task_list.append(t) + task_list.append(task) return task_list diff --git a/lib/ansible/playbook/included_file.py b/lib/ansible/playbook/included_file.py index d2fdb76364d..673f5cfd71f 100644 --- a/lib/ansible/playbook/included_file.py +++ b/lib/ansible/playbook/included_file.py 
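# Note on the included_file.py hunks below: include targets and role names are no
# longer re-templated at this layer, apparently because task args now arrive already
# templated (with first-class `omit` handling) from the upstream args finalization
# shown elsewhere in this diff. A minimal sketch of the resulting flow, assuming an
# `include_result` dict as produced by the executor (names are from this diff):
#
#   include_target = include_result['include']        # already templated upstream
#   include_file = loader.path_dwim(include_target)   # path resolution only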
@@ -21,12 +21,11 @@ import os from ansible import constants as C from ansible.errors import AnsibleError -from ansible.executor.task_executor import remove_omit from ansible.module_utils.common.text.converters import to_text from ansible.playbook.handler import Handler from ansible.playbook.task_include import TaskInclude from ansible.playbook.role_include import IncludeRole -from ansible.template import Templar +from ansible._internal._templating._engine import TemplateEngine from ansible.utils.display import Display display = Display() @@ -114,7 +113,7 @@ class IncludedFile: if loader.get_basedir() not in task_vars['ansible_search_path']: task_vars['ansible_search_path'].append(loader.get_basedir()) - templar = Templar(loader=loader, variables=task_vars) + templar = TemplateEngine(loader=loader, variables=task_vars) if original_task.action in C._ACTION_INCLUDE_TASKS: include_file = None @@ -132,7 +131,7 @@ class IncludedFile: parent_include_dir = parent_include._role_path else: try: - parent_include_dir = os.path.dirname(templar.template(parent_include.args.get('_raw_params'))) + parent_include_dir = os.path.dirname(parent_include.args.get('_raw_params')) except AnsibleError as e: parent_include_dir = '' display.warning( @@ -144,7 +143,7 @@ class IncludedFile: cumulative_path = os.path.join(parent_include_dir, cumulative_path) else: cumulative_path = parent_include_dir - include_target = templar.template(include_result['include']) + include_target = include_result['include'] if original_task._role: dirname = 'handlers' if isinstance(original_task, Handler) else 'tasks' new_basedir = os.path.join(original_task._role._role_path, dirname, cumulative_path) @@ -170,7 +169,7 @@ class IncludedFile: if include_file is None: if original_task._role: - include_target = templar.template(include_result['include']) + include_target = include_result['include'] include_file = loader.path_dwim_relative( original_task._role._role_path, 'handlers' if isinstance(original_task, Handler) else 'tasks', @@ -179,25 +178,17 @@ class IncludedFile: else: include_file = loader.path_dwim(include_result['include']) - include_file = templar.template(include_file) inc_file = IncludedFile(include_file, include_args, special_vars, original_task) else: # template the included role's name here role_name = include_args.pop('name', include_args.pop('role', None)) - if role_name is not None: - role_name = templar.template(role_name) - new_task = original_task.copy() new_task.post_validate(templar=templar) new_task._role_name = role_name for from_arg in new_task.FROM_ARGS: if from_arg in include_args: from_key = from_arg.removesuffix('_from') - new_task._from_files[from_key] = templar.template(include_args.pop(from_arg)) - - omit_token = task_vars.get('omit') - if omit_token: - new_task._from_files = remove_omit(new_task._from_files, omit_token) + new_task._from_files[from_key] = include_args.pop(from_arg) inc_file = IncludedFile(role_name, include_args, special_vars, new_task, is_role=True) diff --git a/lib/ansible/playbook/play.py b/lib/ansible/playbook/play.py index 831e0280214..461a0a39258 100644 --- a/lib/ansible/playbook/play.py +++ b/lib/ansible/playbook/play.py @@ -19,8 +19,8 @@ from __future__ import annotations from ansible import constants as C from ansible import context -from ansible.errors import AnsibleParserError, AnsibleAssertionError, AnsibleError -from ansible.module_utils.common.text.converters import to_native +from ansible.errors import AnsibleError +from ansible.errors import AnsibleParserError, 
AnsibleAssertionError from ansible.module_utils.common.collections import is_sequence from ansible.module_utils.six import binary_type, string_types, text_type from ansible.playbook.attribute import NonInheritableFieldAttribute @@ -31,6 +31,7 @@ from ansible.playbook.helpers import load_list_of_blocks, load_list_of_roles from ansible.playbook.role import Role from ansible.playbook.task import Task from ansible.playbook.taggable import Taggable +from ansible.parsing.vault import EncryptedString from ansible.utils.display import Display display = Display() @@ -122,7 +123,7 @@ class Play(Base, Taggable, CollectionSearch): elif not isinstance(entry, (binary_type, text_type)): raise AnsibleParserError("Hosts list contains an invalid host value: '{host!s}'".format(host=entry)) - elif not isinstance(value, (binary_type, text_type)): + elif not isinstance(value, (binary_type, text_type, EncryptedString)): raise AnsibleParserError("Hosts list must be a sequence or string. Please check your playbook.") def get_name(self): @@ -167,6 +168,8 @@ class Play(Base, Taggable, CollectionSearch): return super(Play, self).preprocess_data(ds) + # DTFIX-FUTURE: these do nothing but augment the exception message; DRY and nuke + def _load_tasks(self, attr, ds): """ Loads a list of blocks from a list which may be mixed tasks/blocks. @@ -174,8 +177,8 @@ class Play(Base, Taggable, CollectionSearch): """ try: return load_list_of_blocks(ds=ds, play=self, variable_manager=self._variable_manager, loader=self._loader) - except AssertionError as e: - raise AnsibleParserError("A malformed block was encountered while loading tasks: %s" % to_native(e), obj=self._ds, orig_exc=e) + except AssertionError as ex: + raise AnsibleParserError("A malformed block was encountered while loading tasks.", obj=self._ds) from ex def _load_pre_tasks(self, attr, ds): """ @@ -184,8 +187,8 @@ class Play(Base, Taggable, CollectionSearch): """ try: return load_list_of_blocks(ds=ds, play=self, variable_manager=self._variable_manager, loader=self._loader) - except AssertionError as e: - raise AnsibleParserError("A malformed block was encountered while loading pre_tasks", obj=self._ds, orig_exc=e) + except AssertionError as ex: + raise AnsibleParserError("A malformed block was encountered while loading pre_tasks.", obj=self._ds) from ex def _load_post_tasks(self, attr, ds): """ @@ -194,8 +197,8 @@ class Play(Base, Taggable, CollectionSearch): """ try: return load_list_of_blocks(ds=ds, play=self, variable_manager=self._variable_manager, loader=self._loader) - except AssertionError as e: - raise AnsibleParserError("A malformed block was encountered while loading post_tasks", obj=self._ds, orig_exc=e) + except AssertionError as ex: + raise AnsibleParserError("A malformed block was encountered while loading post_tasks.", obj=self._ds) from ex def _load_handlers(self, attr, ds): """ @@ -208,8 +211,8 @@ class Play(Base, Taggable, CollectionSearch): load_list_of_blocks(ds=ds, play=self, use_handlers=True, variable_manager=self._variable_manager, loader=self._loader), prepend=True ) - except AssertionError as e: - raise AnsibleParserError("A malformed block was encountered while loading handlers", obj=self._ds, orig_exc=e) + except AssertionError as ex: + raise AnsibleParserError("A malformed block was encountered while loading handlers.", obj=self._ds) from ex def _load_roles(self, attr, ds): """ @@ -223,8 +226,8 @@ class Play(Base, Taggable, CollectionSearch): try: role_includes = load_list_of_roles(ds, play=self, variable_manager=self._variable_manager, 
loader=self._loader, collection_search_list=self.collections) - except AssertionError as e: - raise AnsibleParserError("A malformed role declaration was encountered.", obj=self._ds, orig_exc=e) + except AssertionError as ex: + raise AnsibleParserError("A malformed role declaration was encountered.", obj=self._ds) from ex roles = [] for ri in role_includes: diff --git a/lib/ansible/playbook/play_context.py b/lib/ansible/playbook/play_context.py index e384ce0fb2f..699331626d8 100644 --- a/lib/ansible/playbook/play_context.py +++ b/lib/ansible/playbook/play_context.py @@ -88,7 +88,7 @@ class PlayContext(Base): # networking modules network_os = FieldAttribute(isa='string') - # docker FIXME: remove these + # FIXME: docker - remove these docker_extra_args = FieldAttribute(isa='string') # ??? @@ -103,10 +103,6 @@ class PlayContext(Base): become_flags = FieldAttribute(isa='string', default=C.DEFAULT_BECOME_FLAGS) prompt = FieldAttribute(isa='string') - # general flags - only_tags = FieldAttribute(isa='set', default=set) - skip_tags = FieldAttribute(isa='set', default=set) - start_at_task = FieldAttribute(isa='string') step = FieldAttribute(isa='bool', default=False) @@ -201,8 +197,7 @@ class PlayContext(Base): # In the case of a loop, the delegated_to host may have been # templated based on the loop variable, so we try and locate # the host name in the delegated variable dictionary here - delegated_host_name = templar.template(task.delegate_to) - delegated_vars = variables.get('ansible_delegated_vars', dict()).get(delegated_host_name, dict()) + delegated_vars = variables.get('ansible_delegated_vars', dict()).get(task.delegate_to, dict()) delegated_transport = C.DEFAULT_TRANSPORT for transport_var in C.MAGIC_VARIABLE_MAPPING.get('connection'): @@ -218,8 +213,8 @@ class PlayContext(Base): if address_var in delegated_vars: break else: - display.debug("no remote address found for delegated host %s\nusing its name, so success depends on DNS resolution" % delegated_host_name) - delegated_vars['ansible_host'] = delegated_host_name + display.debug("no remote address found for delegated host %s\nusing its name, so success depends on DNS resolution" % task.delegate_to) + delegated_vars['ansible_host'] = task.delegate_to # reset the port back to the default if none was specified, to prevent # the delegated host from inheriting the original host's setting diff --git a/lib/ansible/playbook/playbook_include.py b/lib/ansible/playbook/playbook_include.py index 8e7c6c05082..e7fdad0e7df 100644 --- a/lib/ansible/playbook/playbook_include.py +++ b/lib/ansible/playbook/playbook_include.py @@ -22,16 +22,16 @@ import os import ansible.constants as C from ansible.errors import AnsibleParserError, AnsibleAssertionError from ansible.module_utils.common.text.converters import to_bytes +from ansible.module_utils._internal._datatag import AnsibleTagHelper from ansible.module_utils.six import string_types from ansible.parsing.splitter import split_args -from ansible.parsing.yaml.objects import AnsibleBaseYAMLObject, AnsibleMapping from ansible.playbook.attribute import NonInheritableFieldAttribute from ansible.playbook.base import Base from ansible.playbook.conditional import Conditional from ansible.playbook.taggable import Taggable from ansible.utils.collection_loader import AnsibleCollectionConfig from ansible.utils.collection_loader._collection_finder import _get_collection_name_from_path, _get_collection_playbook_path -from ansible.template import Templar +from ansible._internal._templating._engine import TemplateEngine 
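# The Templar -> TemplateEngine swap above recurs throughout this change. A minimal
# usage sketch of the internal engine, assuming the same construction arguments
# shown in these hunks (`loader` and `all_vars` are illustrative names):
#
#   templar = TemplateEngine(loader=loader, variables=all_vars)
#   result = templar.template(expression)  # raises AnsibleValueOmittedError when
#                                          # the entire result resolves to `omit`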
from ansible.utils.display import Display display = Display() @@ -65,7 +65,7 @@ class PlaybookInclude(Base, Conditional, Taggable): if variable_manager: all_vars |= variable_manager.get_vars() - templar = Templar(loader=loader, variables=all_vars) + templar = TemplateEngine(loader=loader, variables=all_vars) # then we use the object to load a Playbook pb = Playbook(loader=loader) @@ -130,11 +130,9 @@ class PlaybookInclude(Base, Conditional, Taggable): if not isinstance(ds, dict): raise AnsibleAssertionError('ds (%s) should be a dict but was a %s' % (ds, type(ds))) - # the new, cleaned datastructure, which will have legacy - # items reduced to a standard structure - new_ds = AnsibleMapping() - if isinstance(ds, AnsibleBaseYAMLObject): - new_ds.ansible_pos = ds.ansible_pos + # the new, cleaned datastructure, which will have legacy items reduced to a standard structure suitable for the + # attributes of the task class; copy any tagged data to preserve things like origin + new_ds = AnsibleTagHelper.tag_copy(ds, {}) for (k, v) in ds.items(): if k in C._ACTION_IMPORT_PLAYBOOK: @@ -166,4 +164,5 @@ class PlaybookInclude(Base, Conditional, Taggable): if len(items) == 0: raise AnsibleParserError("import_playbook statements must specify the file name to import", obj=ds) - new_ds['import_playbook'] = items[0].strip() + # DTFIX-RELEASE: investigate this as a possible "problematic strip" + new_ds['import_playbook'] = AnsibleTagHelper.tag_copy(v, items[0].strip()) diff --git a/lib/ansible/playbook/role/__init__.py b/lib/ansible/playbook/role/__init__.py index 0887a77d7ab..1a7e882e051 100644 --- a/lib/ansible/playbook/role/__init__.py +++ b/lib/ansible/playbook/role/__init__.py @@ -27,7 +27,6 @@ from ansible.errors import AnsibleError, AnsibleParserError, AnsibleAssertionErr from ansible.module_utils.common.sentinel import Sentinel from ansible.module_utils.common.text.converters import to_text from ansible.module_utils.six import binary_type, text_type -from ansible.playbook.attribute import FieldAttribute from ansible.playbook.base import Base from ansible.playbook.collectionsearch import CollectionSearch from ansible.playbook.conditional import Conditional @@ -200,9 +199,9 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable): return r - except RuntimeError: + except RecursionError as ex: raise AnsibleError("A recursion loop was detected with the roles specified. 
Make sure child roles do not have dependencies on parent roles", - obj=role_include._ds) + obj=role_include._ds) from ex def _load_role_data(self, role_include, parent_role=None): self._role_name = role_include.role @@ -274,18 +273,17 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable): if task_data: try: self._task_blocks = load_list_of_blocks(task_data, play=self._play, role=self, loader=self._loader, variable_manager=self._variable_manager) - except AssertionError as e: - raise AnsibleParserError("The tasks/main.yml file for role '%s' must contain a list of tasks" % self._role_name, - obj=task_data, orig_exc=e) + except AssertionError as ex: + raise AnsibleParserError(f"The tasks/main.yml file for role {self._role_name!r} must contain a list of tasks.", obj=task_data) from ex handler_data = self._load_role_yaml('handlers', main=self._from_files.get('handlers')) if handler_data: try: self._handler_blocks = load_list_of_blocks(handler_data, play=self._play, role=self, use_handlers=True, loader=self._loader, variable_manager=self._variable_manager) - except AssertionError as e: - raise AnsibleParserError("The handlers/main.yml file for role '%s' must contain a list of tasks" % self._role_name, - obj=handler_data, orig_exc=e) + except AssertionError as ex: + raise AnsibleParserError(f"The handlers/main.yml file for role {self._role_name!r} must contain a list of tasks.", + obj=handler_data) from ex def _get_role_argspecs(self): """Get the role argument spec data. @@ -412,7 +410,7 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable): raise AnsibleParserError("Failed loading '%s' for role (%s) as it is not inside the expected role path: '%s'" % (to_text(found), self._role_name, to_text(file_path))) - new_data = self._loader.load_from_file(found) + new_data = self._loader.load_from_file(found, trusted_as_template=True) if new_data: if data is not None and isinstance(new_data, Mapping): data = combine_vars(data, new_data) diff --git a/lib/ansible/playbook/role/definition.py b/lib/ansible/playbook/role/definition.py index 50758869b3b..670a4e101ca 100644 --- a/lib/ansible/playbook/role/definition.py +++ b/lib/ansible/playbook/role/definition.py @@ -21,14 +21,14 @@ import os from ansible import constants as C from ansible.errors import AnsibleError, AnsibleAssertionError +from ansible.module_utils._internal._datatag import AnsibleTagHelper from ansible.module_utils.six import string_types -from ansible.parsing.yaml.objects import AnsibleBaseYAMLObject, AnsibleMapping from ansible.playbook.attribute import NonInheritableFieldAttribute from ansible.playbook.base import Base from ansible.playbook.collectionsearch import CollectionSearch from ansible.playbook.conditional import Conditional from ansible.playbook.taggable import Taggable -from ansible.template import Templar +from ansible._internal._templating._engine import TemplateEngine from ansible.utils.collection_loader import AnsibleCollectionRef from ansible.utils.collection_loader._collection_finder import _get_collection_role_path from ansible.utils.path import unfrackpath @@ -70,7 +70,7 @@ class RoleDefinition(Base, Conditional, Taggable, CollectionSearch): if isinstance(ds, int): ds = "%s" % ds - if not isinstance(ds, dict) and not isinstance(ds, string_types) and not isinstance(ds, AnsibleBaseYAMLObject): + if not isinstance(ds, dict) and not isinstance(ds, string_types): raise AnsibleAssertionError() if isinstance(ds, dict): @@ -79,12 +79,9 @@ class RoleDefinition(Base, Conditional, Taggable, 
CollectionSearch): # save the original ds for use later self._ds = ds - # we create a new data structure here, using the same - # object used internally by the YAML parsing code so we - # can preserve file:line:column information if it exists - new_ds = AnsibleMapping() - if isinstance(ds, AnsibleBaseYAMLObject): - new_ds.ansible_pos = ds.ansible_pos + # the new, cleaned datastructure, which will have legacy items reduced to a standard structure suitable for the + # attributes of the task class; copy any tagged data to preserve things like origin + new_ds = AnsibleTagHelper.tag_copy(ds, {}) # first we pull the role name out of the data structure, # and then use that to determine the role path (which may @@ -127,7 +124,7 @@ class RoleDefinition(Base, Conditional, Taggable, CollectionSearch): # contains a variable, try and template it now if self._variable_manager: all_vars = self._variable_manager.get_vars(play=self._play) - templar = Templar(loader=self._loader, variables=all_vars) + templar = TemplateEngine(loader=self._loader, variables=all_vars) role_name = templar.template(role_name) return role_name @@ -147,7 +144,7 @@ class RoleDefinition(Base, Conditional, Taggable, CollectionSearch): else: all_vars = dict() - templar = Templar(loader=self._loader, variables=all_vars) + templar = TemplateEngine(loader=self._loader, variables=all_vars) role_name = templar.template(role_name) role_tuple = None @@ -198,6 +195,7 @@ class RoleDefinition(Base, Conditional, Taggable, CollectionSearch): return (role_name, role_path) searches = (self._collection_list or []) + role_search_paths + raise AnsibleError("the role '%s' was not found in %s" % (role_name, ":".join(searches)), obj=self._ds) def _split_role_params(self, ds): diff --git a/lib/ansible/playbook/role/include.py b/lib/ansible/playbook/role/include.py index 934b53ce9b4..3ab3d153a39 100644 --- a/lib/ansible/playbook/role/include.py +++ b/lib/ansible/playbook/role/include.py @@ -19,10 +19,8 @@ from __future__ import annotations from ansible.errors import AnsibleError, AnsibleParserError from ansible.module_utils.six import string_types -from ansible.parsing.yaml.objects import AnsibleBaseYAMLObject from ansible.playbook.delegatable import Delegatable from ansible.playbook.role.definition import RoleDefinition -from ansible.module_utils.common.text.converters import to_native __all__ = ['RoleInclude'] @@ -42,8 +40,8 @@ class RoleInclude(RoleDefinition, Delegatable): @staticmethod def load(data, play, current_role_path=None, parent_role=None, variable_manager=None, loader=None, collection_list=None): - if not (isinstance(data, string_types) or isinstance(data, dict) or isinstance(data, AnsibleBaseYAMLObject)): - raise AnsibleParserError("Invalid role definition: %s" % to_native(data)) + if not (isinstance(data, string_types) or isinstance(data, dict)): + raise AnsibleParserError("Invalid role definition.", obj=data) if isinstance(data, string_types) and ',' in data: raise AnsibleError("Invalid old style role requirement: %s" % data) diff --git a/lib/ansible/playbook/role/metadata.py b/lib/ansible/playbook/role/metadata.py index 6606d862c9f..0125ae2e084 100644 --- a/lib/ansible/playbook/role/metadata.py +++ b/lib/ansible/playbook/role/metadata.py @@ -20,7 +20,6 @@ from __future__ import annotations import os from ansible.errors import AnsibleParserError, AnsibleError -from ansible.module_utils.common.text.converters import to_native from ansible.module_utils.six import string_types from ansible.playbook.attribute import 
NonInheritableFieldAttribute from ansible.playbook.base import Base @@ -80,8 +79,8 @@ class RoleMetadata(Base, CollectionSearch): if def_parsed.get('name'): role_def['name'] = def_parsed['name'] roles.append(role_def) - except AnsibleError as exc: - raise AnsibleParserError(to_native(exc), obj=role_def, orig_exc=exc) + except AnsibleError as ex: + raise AnsibleParserError("Error parsing role dependencies.", obj=role_def) from ex current_role_path = None collection_search_list = None @@ -105,8 +104,8 @@ class RoleMetadata(Base, CollectionSearch): return load_list_of_roles(roles, play=self._owner._play, current_role_path=current_role_path, variable_manager=self._variable_manager, loader=self._loader, collection_search_list=collection_search_list) - except AssertionError as e: - raise AnsibleParserError("A malformed list of role dependencies was encountered.", obj=self._ds, orig_exc=e) + except AssertionError as ex: + raise AnsibleParserError("A malformed list of role dependencies was encountered.", obj=self._ds) from ex def serialize(self): return dict( diff --git a/lib/ansible/playbook/role_include.py b/lib/ansible/playbook/role_include.py index 1894d6df8f9..48003db7dff 100644 --- a/lib/ansible/playbook/role_include.py +++ b/lib/ansible/playbook/role_include.py @@ -24,7 +24,7 @@ from ansible.playbook.role import Role from ansible.playbook.role.include import RoleInclude from ansible.utils.display import Display from ansible.module_utils.six import string_types -from ansible.template import Templar +from ansible._internal._templating._engine import TemplateEngine __all__ = ['IncludeRole'] @@ -79,7 +79,7 @@ class IncludeRole(TaskInclude): available_variables = variable_manager.get_vars(play=myplay, task=self) else: available_variables = {} - templar = Templar(loader=loader, variables=available_variables) + templar = TemplateEngine(loader=loader, variables=available_variables) from_files = templar.template(self._from_files) # build role diff --git a/lib/ansible/playbook/taggable.py b/lib/ansible/playbook/taggable.py index 79810a41eaf..163e3380018 100644 --- a/lib/ansible/playbook/taggable.py +++ b/lib/ansible/playbook/taggable.py @@ -20,8 +20,9 @@ from __future__ import annotations from ansible.errors import AnsibleError from ansible.module_utils.six import string_types from ansible.module_utils.common.sentinel import Sentinel +from ansible.module_utils._internal._datatag import AnsibleTagHelper from ansible.playbook.attribute import FieldAttribute -from ansible.template import Templar +from ansible._internal._templating._engine import TemplateEngine def _flatten_tags(tags: list) -> list: @@ -42,16 +43,20 @@ class Taggable: def _load_tags(self, attr, ds): if isinstance(ds, list): return ds - elif isinstance(ds, string_types): - return [x.strip() for x in ds.split(',')] - else: - raise AnsibleError('tags must be specified as a list', obj=ds) + + if isinstance(ds, str): + # DTFIX-RELEASE: this allows each individual tag to be templated, but prevents the use of commas in templates, is that what we want? + # DTFIX-RELEASE: this can return empty tags (including a list of nothing but empty tags), is that correct? + # DTFIX-RELEASE: the original code seemed to attempt to preserve `ds` if there were no commas, but it never ran, what should it actually do? 
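        # Illustrative behavior of the return below (not part of this change):
        # given ds = "setup, deploy", it yields ["setup", "deploy"], with each
        # element tagged like ds (preserving origin) via AnsibleTagHelper.tag_copy.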
+ return [AnsibleTagHelper.tag_copy(ds, item.strip()) for item in ds.split(',')] + + raise AnsibleError('tags must be specified as a list', obj=ds) def evaluate_tags(self, only_tags, skip_tags, all_vars): """ this checks if the current item should be executed depending on tag options """ if self.tags: - templar = Templar(loader=self._loader, variables=all_vars) + templar = TemplateEngine(loader=self._loader, variables=all_vars) obj = self while obj is not None: if (_tags := getattr(obj, "_tags", Sentinel)) is not Sentinel: diff --git a/lib/ansible/playbook/task.py b/lib/ansible/playbook/task.py index 3f43bfbe7ca..6579922624e 100644 --- a/lib/ansible/playbook/task.py +++ b/lib/ansible/playbook/task.py @@ -17,14 +17,18 @@ from __future__ import annotations +import typing as t + from ansible import constants as C -from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleAssertionError from ansible.module_utils.common.sentinel import Sentinel +from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleAssertionError, AnsibleValueOmittedError +from ansible.executor.module_common import _get_action_arg_defaults from ansible.module_utils.common.text.converters import to_native +from ansible.module_utils._internal._datatag import AnsibleTagHelper from ansible.module_utils.six import string_types -from ansible.parsing.mod_args import ModuleArgsParser -from ansible.parsing.yaml.objects import AnsibleBaseYAMLObject, AnsibleMapping -from ansible.plugins.loader import lookup_loader +from ansible.parsing.mod_args import ModuleArgsParser, RAW_PARAM_MODULES +from ansible.plugins.action import ActionBase +from ansible.plugins.loader import action_loader, module_loader, lookup_loader from ansible.playbook.attribute import NonInheritableFieldAttribute from ansible.playbook.base import Base from ansible.playbook.block import Block @@ -35,10 +39,14 @@ from ansible.playbook.loop_control import LoopControl from ansible.playbook.notifiable import Notifiable from ansible.playbook.role import Role from ansible.playbook.taggable import Taggable +from ansible._internal import _task +from ansible._internal._templating import _marker_behaviors +from ansible._internal._templating._jinja_bits import is_possibly_all_template +from ansible._internal._templating._engine import TemplateEngine, TemplateOptions from ansible.utils.collection_loader import AnsibleCollectionConfig from ansible.utils.display import Display -from ansible.utils.vars import isidentifier +from ansible.utils.vars import validate_variable_name __all__ = ['Task'] @@ -68,8 +76,8 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl # inheritance is only triggered if the 'current value' is Sentinel, # default can be set at play/top level object and inheritance will take it's course. 
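# The attribute declarations below wrap each descriptor in t.cast() so static type
# checkers see the instance-level value type (dict, str) while runtime behavior
# still comes from the FieldAttribute descriptor. A minimal sketch of the pattern,
# using a hypothetical attribute:
#
#   retries = t.cast(int, NonInheritableFieldAttribute(isa='int', default=0))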
- args = NonInheritableFieldAttribute(isa='dict', default=dict) - action = NonInheritableFieldAttribute(isa='string') + args = t.cast(dict, NonInheritableFieldAttribute(isa='dict', default=dict)) + action = t.cast(str, NonInheritableFieldAttribute(isa='string')) async_val = NonInheritableFieldAttribute(isa='int', default=0, alias='async') changed_when = NonInheritableFieldAttribute(isa='list', default=list) @@ -85,13 +93,13 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl # deprecated, used to be loop and loop_args but loop has been repurposed loop_with = NonInheritableFieldAttribute(isa='string', private=True) - def __init__(self, block=None, role=None, task_include=None): + def __init__(self, block=None, role=None, task_include=None) -> None: """ constructors a task, without the Task.load classmethod, it will be pretty blank """ self._role = role self._parent = None self.implicit = False - self.resolved_action = None + self.resolved_action: str | None = None if task_include: self._parent = task_include @@ -132,13 +140,80 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl @staticmethod def load(data, block=None, role=None, task_include=None, variable_manager=None, loader=None): - t = Task(block=block, role=role, task_include=task_include) - return t.load_data(data, variable_manager=variable_manager, loader=loader) + task = Task(block=block, role=role, task_include=task_include) + return task.load_data(data, variable_manager=variable_manager, loader=loader) + + def _post_validate_module_defaults(self, attr: str, value: t.Any, templar: TemplateEngine) -> t.Any: + """Override module_defaults post validation to disable templating, which is handled by args post validation.""" + return value + + def _post_validate_args(self, attr: str, value: t.Any, templar: TemplateEngine) -> dict[str, t.Any]: + try: + self.action = templar.template(self.action) + except AnsibleValueOmittedError: + # some strategies may trigger this error when templating task.action, but backstop here if not + raise AnsibleParserError("Omit is not valid for the `action` keyword.", obj=self.action) from None + + action_context = action_loader.get_with_context(self.action, collection_list=self.collections, class_only=True) + + if not action_context.plugin_load_context.resolved: + module_or_action_context = module_loader.find_plugin_with_context(self.action, collection_list=self.collections) + + if not module_or_action_context.resolved: + raise AnsibleError(f"Cannot resolve {self.action!r} to an action or module.", obj=self.action) + + action_context = action_loader.get_with_context('ansible.legacy.normal', collection_list=self.collections, class_only=True) + else: + module_or_action_context = action_context.plugin_load_context + + self.resolved_action = module_or_action_context.resolved_fqcn + + action_type: type[ActionBase] = action_context.object + + vp = value.pop('_variable_params', None) + + supports_raw_params = action_type.supports_raw_params or module_or_action_context.resolved_fqcn in RAW_PARAM_MODULES + + if supports_raw_params: + raw_params_to_finalize = None + else: + raw_params_to_finalize = value.pop('_raw_params', None) # always str or None + + # TaskArgsFinalizer performs more thorough type checking, but this provides a friendlier error message for a subset of detected cases. 
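        # For example (illustrative): raw params of "{{ all_args }}" might still
        # template to a mapping, so they are deferred to TaskArgsFinalizer, while a
        # plain literal string for an action without raw param support fails fast
        # in the check below.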
+ if raw_params_to_finalize and not is_possibly_all_template(raw_params_to_finalize): + raise AnsibleError(f'Action {module_or_action_context.resolved_fqcn!r} does not support raw params.', obj=self.action) + + args_finalizer = _task.TaskArgsFinalizer( + _get_action_arg_defaults(module_or_action_context.resolved_fqcn, self, templar), + vp, + raw_params_to_finalize, + value, + templar=templar, + ) + + try: + with action_type.get_finalize_task_args_context() as finalize_context: + args = args_finalizer.finalize(action_type.finalize_task_arg, context=finalize_context) + except Exception as ex: + raise AnsibleError(f'Finalization of task args for {module_or_action_context.resolved_fqcn!r} failed.', obj=self.action) from ex + + if self._origin: + args = self._origin.tag(args) + + return args + + def _get_meta(self) -> str | None: + # FUTURE: validate meta and return an enum instead of a str + # meta currently does not support being templated, so we can cheat + if self.action in C._ACTION_META: + return self.args.get('_raw_params') + + return None def __repr__(self): """ returns a human-readable representation of the task """ - if self.action in C._ACTION_META: - return "TASK: meta (%s)" % self.args['_raw_params'] + if meta := self._get_meta(): + return f"TASK: meta ({meta})" else: return "TASK: %s" % self.get_name() @@ -164,12 +239,9 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl if not isinstance(ds, dict): raise AnsibleAssertionError('ds (%s) should be a dict but was a %s' % (ds, type(ds))) - # the new, cleaned datastructure, which will have legacy - # items reduced to a standard structure suitable for the - # attributes of the task class - new_ds = AnsibleMapping() - if isinstance(ds, AnsibleBaseYAMLObject): - new_ds.ansible_pos = ds.ansible_pos + # the new, cleaned datastructure, which will have legacy items reduced to a standard structure suitable for the + # attributes of the task class; copy any tagged data to preserve things like origin + new_ds = AnsibleTagHelper.tag_copy(ds, {}) # since this affects the task action parsing, we have to resolve in preprocess instead of in typical validator default_collection = AnsibleCollectionConfig.default_collection @@ -202,26 +274,13 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl args_parser = ModuleArgsParser(task_ds=ds, collection_list=collections_list) try: (action, args, delegate_to) = args_parser.parse() - except AnsibleParserError as e: + except AnsibleParserError as ex: # if the raises exception was created with obj=ds args, then it includes the detail # so we dont need to add it so we can just re raise. - if e.obj: + if ex.obj: raise # But if it wasn't, we can add the yaml object now to get more detail - raise AnsibleParserError(to_native(e), obj=ds, orig_exc=e) - else: - # Set the resolved action plugin (or if it does not exist, module) for callbacks. - self.resolved_action = args_parser.resolved_action - - # the command/shell/script modules used to support the `cmd` arg, - # which corresponds to what we now call _raw_params, so move that - # value over to _raw_params (assuming it is empty) - if action in C._ACTION_HAS_CMD: - if 'cmd' in args: - if args.get('_raw_params', '') != '': - raise AnsibleError("The 'cmd' argument cannot be used when other raw parameters are specified." 
- " Please put everything in one or the other place.", obj=ds) - args['_raw_params'] = args.pop('cmd') + raise AnsibleParserError("Error parsing task arguments.", obj=ds) from ex new_ds['action'] = action new_ds['args'] = args @@ -277,8 +336,11 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl setattr(self, name, [value]) def _validate_register(self, attr, name, value): - if value is not None and not isidentifier(value): - raise AnsibleParserError(f"Invalid variable name in 'register' specified: '{value}'") + if value is not None: + try: + validate_variable_name(value) + except Exception as ex: + raise AnsibleParserError("Invalid 'register' specified.", obj=value) from ex def post_validate(self, templar): """ @@ -289,9 +351,6 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl if self._parent: self._parent.post_validate(templar) - if AnsibleCollectionConfig.default_collection: - pass - super(Task, self).post_validate(templar) def _post_validate_loop(self, attr, value, templar): @@ -301,44 +360,53 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl """ return value + def _post_validate_name(self, attr, value, templar): + """ + Override post-validation behavior for `name` to be best-effort for the vars available. + Direct access via `post_validate_attribute` writes the value back to provide a stable value. + This value is individually post-validated early by strategies for the benefit of callbacks. + """ + with _marker_behaviors.ReplacingMarkerBehavior.warning_context() as replacing_behavior: + self.name = templar.extend(marker_behavior=replacing_behavior).template(value, options=TemplateOptions(value_for_omit=None)) + + return self.name + def _post_validate_environment(self, attr, value, templar): """ Override post validation of vars on the play, as we don't want to template these too early. 
""" env = {} - if value is not None: - def _parse_env_kv(k, v): - try: - env[k] = templar.template(v, convert_bare=False) - except AnsibleUndefinedVariable as e: - error = to_native(e) - if self.action in C._ACTION_FACT_GATHERING and 'ansible_facts.env' in error or 'ansible_env' in error: - # ignore as fact gathering is required for 'env' facts - return - raise - - if isinstance(value, list): - for env_item in value: - if isinstance(env_item, dict): - for k in env_item: - _parse_env_kv(k, env_item[k]) - else: - isdict = templar.template(env_item, convert_bare=False) - if isinstance(isdict, dict): - env |= isdict - else: - display.warning("could not parse environment value, skipping: %s" % value) - - elif isinstance(value, dict): - # should not really happen - env = dict() - for env_item in value: - _parse_env_kv(env_item, value[env_item]) + # FUTURE: kill this with fire + def _parse_env_kv(k, v): + try: + env[k] = templar.template(v) + except AnsibleValueOmittedError: + # skip this value + return + except AnsibleUndefinedVariable as e: + error = to_native(e) + if self.action in C._ACTION_FACT_GATHERING and 'ansible_facts.env' in error or 'ansible_env' in error: + # ignore as fact gathering is required for 'env' facts + return + raise + + # NB: the environment FieldAttribute definition ensures that value is always a list + for env_item in value: + if isinstance(env_item, dict): + for k in env_item: + _parse_env_kv(k, env_item[k]) else: - # at this point it should be a simple string, also should not happen - env = templar.template(value, convert_bare=False) + try: + isdict = templar.template(env_item) + except AnsibleValueOmittedError: + continue + + if isinstance(isdict, dict): + env |= isdict + else: + display.warning("could not parse environment value, skipping: %s" % value) return env @@ -385,7 +453,7 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl all_vars |= self.vars return all_vars - def copy(self, exclude_parent=False, exclude_tasks=False): + def copy(self, exclude_parent: bool = False, exclude_tasks: bool = False) -> Task: new_me = super(Task, self).copy() new_me._parent = None @@ -519,3 +587,28 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl while not isinstance(parent, Block): parent = parent._parent return parent._play + + def dump_attrs(self): + """Override to smuggle important non-FieldAttribute values back to the controller.""" + attrs = super().dump_attrs() + attrs.update(resolved_action=self.resolved_action) + return attrs + + def _resolve_conditional( + self, + conditional: list[str | bool], + variables: dict[str, t.Any], + *, + result_context: dict[str, t.Any] | None = None, + ) -> bool: + """Loops through the conditionals set on this object, returning False if any of them evaluate as such, as well as the condition that was False.""" + engine = TemplateEngine(self._loader, variables=variables) + + for item in conditional: + if not engine.evaluate_conditional(item): + if result_context is not None: + result_context.update(false_condition=item) + + return False + + return True diff --git a/lib/ansible/plugins/__init__.py b/lib/ansible/plugins/__init__.py index 44112597aa7..5e597da2f9e 100644 --- a/lib/ansible/plugins/__init__.py +++ b/lib/ansible/plugins/__init__.py @@ -19,17 +19,16 @@ from __future__ import annotations -from abc import ABC - +import abc import types import typing as t from ansible import constants as C from ansible.errors import AnsibleError -from 
ansible.module_utils.common.text.converters import to_native -from ansible.module_utils.six import string_types from ansible.utils.display import Display +from ansible.module_utils._internal import _plugin_exec_context + display = Display() if t.TYPE_CHECKING: @@ -42,13 +41,32 @@ PLUGIN_PATH_CACHE = {} # type: dict[str, dict[str, dict[str, PluginPathContext] def get_plugin_class(obj): - if isinstance(obj, string_types): + if isinstance(obj, str): return obj.lower().replace('module', '') else: return obj.__class__.__name__.lower().replace('module', '') -class AnsiblePlugin(ABC): +class _ConfigurablePlugin(t.Protocol): + """Protocol to provide type-safe access to config for plugin-related mixins.""" + + def get_option(self, option: str, hostvars: dict[str, object] | None = None) -> object: ... + + +class _AnsiblePluginInfoMixin(_plugin_exec_context.HasPluginInfo): + """Mixin to provide type annotations and default values for existing PluginLoader-set load-time attrs.""" + _original_path: str | None = None + _load_name: str | None = None + _redirected_names: list[str] | None = None + ansible_aliases: list[str] | None = None + ansible_name: str | None = None + + @property + def plugin_type(self) -> str: + return self.__class__.__name__.lower().replace('module', '') + + +class AnsiblePlugin(_AnsiblePluginInfoMixin, _ConfigurablePlugin, metaclass=abc.ABCMeta): # Set by plugin loader _load_name: str @@ -81,7 +99,7 @@ class AnsiblePlugin(ABC): try: option_value, origin = C.config.get_config_value_and_origin(option, plugin_type=self.plugin_type, plugin_name=self._load_name, variables=hostvars) except AnsibleError as e: - raise KeyError(to_native(e)) + raise KeyError(str(e)) return option_value, origin def get_option(self, option, hostvars=None): @@ -123,10 +141,6 @@ class AnsiblePlugin(ABC): self.set_options() return option in self._options - @property - def plugin_type(self): - return self.__class__.__name__.lower().replace('module', '') - @property def option_definitions(self): if (not hasattr(self, "_defs")) or self._defs is None: @@ -137,23 +151,56 @@ class AnsiblePlugin(ABC): # FIXME: standardize required check based on config pass + def __repr__(self): + ansible_name = getattr(self, 'ansible_name', '(unknown)') + load_name = getattr(self, '_load_name', '(unknown)') + return f'{type(self).__name__}(plugin_type={self.plugin_type!r}, {ansible_name=!r}, {load_name=!r})' -class AnsibleJinja2Plugin(AnsiblePlugin): - - def __init__(self, function): +class AnsibleJinja2Plugin(AnsiblePlugin, metaclass=abc.ABCMeta): + def __init__(self, function: t.Callable) -> None: super(AnsibleJinja2Plugin, self).__init__() self._function = function + # Declare support for markers. Plugins with `False` here will never be invoked with markers for top-level arguments. + self.accept_args_markers = getattr(self._function, 'accept_args_markers', False) + self.accept_lazy_markers = getattr(self._function, 'accept_lazy_markers', False) + @property - def plugin_type(self): - return self.__class__.__name__.lower().replace('ansiblejinja2', '') + @abc.abstractmethod + def plugin_type(self) -> str: + ... 
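# A minimal sketch of a Jinja plugin opting in to Marker-aware argument handling
# via the accept_args_markers decorator defined at the end of this module (the
# filter itself is hypothetical; see the decorator docstrings for the semantics):
#
#   from ansible.plugins import accept_args_markers
#
#   @accept_args_markers
#   def coalesce(*values):
#       ...  # the plugin body must handle Marker arguments itself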
- def _no_options(self, *args, **kwargs): + def _no_options(self, *args, **kwargs) -> t.NoReturn: raise NotImplementedError() has_option = get_option = get_options = option_definitions = set_option = set_options = _no_options @property - def j2_function(self): + def j2_function(self) -> t.Callable: return self._function + + +_TCallable = t.TypeVar('_TCallable', bound=t.Callable) + + +def accept_args_markers(plugin: _TCallable) -> _TCallable: + """ + A decorator to mark a Jinja plugin as capable of handling `Marker` values for its top-level arguments. + Non-decorated plugin invocation is skipped when a top-level argument is a `Marker`, with the first such value substituted as the plugin result. + This ensures that only plugins which understand `Marker` instances for top-level arguments will encounter them. + """ + plugin.accept_args_markers = True + + return plugin + + +def accept_lazy_markers(plugin: _TCallable) -> _TCallable: + """ + A decorator to mark a Jinja plugin as capable of handling `Marker` values retrieved from lazy containers. + Non-decorated plugins will trigger a `MarkerError` exception when attempting to retrieve a `Marker` from a lazy container. + This ensures that only plugins which understand lazy retrieval of `Marker` instances will encounter them. + """ + plugin.accept_lazy_markers = True + + return plugin diff --git a/lib/ansible/plugins/action/__init__.py b/lib/ansible/plugins/action/__init__.py index a4ff8a37385..64a16775e54 100644 --- a/lib/ansible/plugins/action/__init__.py +++ b/lib/ansible/plugins/action/__init__.py @@ -6,6 +6,7 @@ from __future__ import annotations import base64 +import contextlib import json import os import re @@ -13,29 +14,44 @@ import secrets import shlex import stat import tempfile +import typing as t from abc import ABC, abstractmethod from collections.abc import Sequence from ansible import constants as C +from ansible._internal._errors import _captured from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleActionSkip, AnsibleActionFail, AnsibleAuthenticationFailure -from ansible.executor.module_common import modify_module +from ansible._internal._errors import _utils +from ansible.executor.module_common import modify_module, _BuiltModule from ansible.executor.interpreter_discovery import discover_interpreter, InterpreterDiscoveryRequiredError +from ansible.module_utils._internal import _traceback from ansible.module_utils.common.arg_spec import ArgumentSpecValidator from ansible.module_utils.errors import UnsupportedError from ansible.module_utils.json_utils import _filter_non_json_lines +from ansible.module_utils.common.json import Direction, get_module_encoder, get_module_decoder from ansible.module_utils.six import binary_type, string_types, text_type from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text -from ansible.parsing.utils.jsonify import jsonify from ansible.release import __version__ from ansible.utils.collection_loader import resource_from_fqcr from ansible.utils.display import Display -from ansible.utils.unsafe_proxy import wrap_var, AnsibleUnsafeText from ansible.vars.clean import remove_internal_keys from ansible.utils.plugin_docs import get_versioned_doclink +from ansible import _internal +from ansible._internal._templating import _engine + +from .. 
import _AnsiblePluginInfoMixin +from ...module_utils.common.messages import PluginInfo display = Display() +if t.TYPE_CHECKING: + from ansible.parsing.dataloader import DataLoader + from ansible.playbook.play_context import PlayContext + from ansible.playbook.task import Task + from ansible.plugins.connection import ConnectionBase + from ansible.template import Templar + def _validate_utf8_json(d): if isinstance(d, text_type): @@ -49,8 +65,7 @@ def _validate_utf8_json(d): _validate_utf8_json(o) -class ActionBase(ABC): - +class ActionBase(ABC, _AnsiblePluginInfoMixin): """ This class is the base class for all action plugins, and defines code common to all actions. The base class handles the connection @@ -67,22 +82,24 @@ class ActionBase(ABC): _requires_connection = True _supports_check_mode = True _supports_async = False + supports_raw_params = False - def __init__(self, task, connection, play_context, loader, templar, shared_loader_obj): + def __init__(self, task: Task, connection: ConnectionBase, play_context: PlayContext, loader: DataLoader, templar: Templar, shared_loader_obj=None): self._task = task self._connection = connection self._play_context = play_context self._loader = loader self._templar = templar - self._shared_loader_obj = shared_loader_obj + + from ansible.plugins import loader as plugin_loaders # avoid circular global import since PluginLoader needs ActionBase + + self._shared_loader_obj = plugin_loaders # shared_loader_obj was just a ref to `ansible.plugins.loader` anyway; this lets us inherit its type self._cleanup_remote_tmp = False # interpreter discovery state - self._discovered_interpreter_key = None + self._discovered_interpreter_key: str | None = None self._discovered_interpreter = False - self._discovery_deprecation_warnings = [] - self._discovery_warnings = [] - self._used_interpreter = None + self._used_interpreter: str | None = None # Backwards compat: self._display isn't really needed, just import the global display and use that. self._display = display @@ -109,9 +126,9 @@ class ActionBase(ABC): result = {} if tmp is not None: - result['warning'] = ['ActionModule.run() no longer honors the tmp parameter. Action' - ' plugins should set self._connection._shell.tmpdir to share' - ' the tmpdir'] + display.warning('ActionModule.run() no longer honors the tmp parameter. Action' + ' plugins should set self._connection._shell.tmpdir to share' + ' the tmpdir.') del tmp if self._task.async_val and not self._supports_async: @@ -177,7 +194,7 @@ class ActionBase(ABC): if isinstance(error, UnsupportedError): msg = f"Unsupported parameters for ({self._load_name}) module: {msg}" - raise AnsibleActionFail(msg) + raise AnsibleActionFail(msg, obj=self._task.args) return validation_result, new_module_args @@ -193,6 +210,28 @@ class ActionBase(ABC): if force or not self._task.async_val: self._remove_tmp_path(self._connection._shell.tmpdir) + @classmethod + @contextlib.contextmanager + @_internal.experimental + def get_finalize_task_args_context(cls) -> t.Any: + """ + EXPERIMENTAL: Unstable API subject to change at any time without notice. + Wraps task arg finalization with (optional) stateful context. + The context manager is entered during `Task.post_validate_args`, and may yield a single value to be passed + as `context` to `finalize_task_arg` for each task arg. 
+ """ + yield None + + @classmethod + @_internal.experimental + def finalize_task_arg(cls, name: str, value: t.Any, templar: _engine.TemplateEngine, context: t.Any) -> t.Any: + """ + EXPERIMENTAL: Unstable API subject to change at any time without notice. + Called for each task arg to allow for custom templating. + The optional `context` value is sourced from `Task.get_finalize_task_args_context`. + """ + return templar.template(value) + def get_plugin_option(self, plugin, option, default=None): """Helper to get an option from a plugin without having to use the try/except dance everywhere to set a default @@ -218,7 +257,7 @@ class ActionBase(ABC): return True return False - def _configure_module(self, module_name, module_args, task_vars): + def _configure_module(self, module_name, module_args, task_vars) -> tuple[_BuiltModule, str]: """ Handles the loading and templating of the module code through the modify_module() function. @@ -276,27 +315,37 @@ class ActionBase(ABC): raise AnsibleError("The module %s was not found in configured module paths" % (module_name)) # insert shared code and arguments into the module - final_environment = dict() + final_environment: dict[str, t.Any] = {} self._compute_environment_string(final_environment) + # `modify_module` adapts PluginInfo to allow target-side use of `PluginExecContext` since modules aren't plugins + plugin = PluginInfo( + requested_name=module_name, + resolved_name=result.resolved_fqcn, + type='module', + ) + # modify_module will exit early if interpreter discovery is required; re-run after if necessary - for dummy in (1, 2): + for _dummy in (1, 2): try: - (module_data, module_style, module_shebang) = modify_module(module_name, module_path, module_args, self._templar, - task_vars=use_vars, - module_compression=C.config.get_config_value('DEFAULT_MODULE_COMPRESSION', - variables=task_vars), - async_timeout=self._task.async_val, - environment=final_environment, - remote_is_local=bool(getattr(self._connection, '_remote_is_local', False)), - become_plugin=self._connection.become) + module_bits = modify_module( + module_name=module_name, + module_path=module_path, + module_args=module_args, + templar=self._templar, + task_vars=use_vars, + module_compression=C.config.get_config_value('DEFAULT_MODULE_COMPRESSION', variables=task_vars), + async_timeout=self._task.async_val, + environment=final_environment, + remote_is_local=bool(getattr(self._connection, '_remote_is_local', False)), + plugin=plugin, + become_plugin=self._connection.become, + ) + break except InterpreterDiscoveryRequiredError as idre: - self._discovered_interpreter = AnsibleUnsafeText(discover_interpreter( - action=self, - interpreter_name=idre.interpreter_name, - discovery_mode=idre.discovery_mode, - task_vars=use_vars)) + self._discovered_interpreter = discover_interpreter(action=self, interpreter_name=idre.interpreter_name, + discovery_mode=idre.discovery_mode, task_vars=use_vars) # update the local task_vars with the discovered interpreter (which might be None); # we'll propagate back to the controller in the task result @@ -316,7 +365,7 @@ class ActionBase(ABC): else: task_vars['ansible_delegated_vars'][self._task.delegate_to]['ansible_facts'][discovered_key] = self._discovered_interpreter - return (module_style, module_shebang, module_data, module_path) + return module_bits, module_path def _compute_environment_string(self, raw_environment_out=None): """ @@ -521,18 +570,19 @@ class ActionBase(ABC): self._connection.put_file(local_path, remote_path) return remote_path - def 
_transfer_data(self, remote_path, data): + def _transfer_data(self, remote_path: str | bytes, data: str | bytes) -> str | bytes: """ Copies the module data out to the temporary module path. """ - if isinstance(data, dict): - data = jsonify(data) + if isinstance(data, str): + data = data.encode(errors='surrogateescape') + elif not isinstance(data, bytes): + raise TypeError('data must be either a string or bytes') afd, afile = tempfile.mkstemp(dir=C.DEFAULT_LOCAL_TMP) afo = os.fdopen(afd, 'wb') try: - data = to_bytes(data, errors='surrogate_or_strict') afo.write(data) except Exception as e: raise AnsibleError("failure writing module data to temporary file for transfer: %s" % to_native(e)) @@ -963,6 +1013,8 @@ class ActionBase(ABC): # allow the user to insert a string to add context to remote logging module_args['_ansible_target_log_info'] = C.config.get_config_value('TARGET_LOG_INFO', variables=task_vars) + module_args['_ansible_tracebacks_for'] = _traceback.traceback_for() + def _execute_module(self, module_name=None, module_args=None, tmp=None, task_vars=None, persist_files=False, delete_remote_tmp=None, wrap_async=False, ignore_unknown_opts: bool = False): """ @@ -1009,7 +1061,8 @@ class ActionBase(ABC): self._task.environment.append({"ANSIBLE_ASYNC_DIR": async_dir}) # FUTURE: refactor this along with module build process to better encapsulate "smart wrapper" functionality - (module_style, shebang, module_data, module_path) = self._configure_module(module_name=module_name, module_args=module_args, task_vars=task_vars) + module_bits, module_path = self._configure_module(module_name=module_name, module_args=module_args, task_vars=task_vars) + (module_style, shebang, module_data) = (module_bits.module_style, module_bits.shebang, module_bits.b_module_data) display.vvv("Using module file %s" % module_path) if not shebang and module_style != 'binary': raise AnsibleError("module (%s) is missing interpreter line" % module_name) @@ -1045,7 +1098,8 @@ class ActionBase(ABC): args_data += '%s=%s ' % (k, shlex.quote(text_type(v))) self._transfer_data(args_file_path, args_data) elif module_style in ('non_native_want_json', 'binary'): - self._transfer_data(args_file_path, json.dumps(module_args)) + profile_encoder = get_module_encoder(module_bits.serialization_profile, Direction.CONTROLLER_TO_MODULE) + self._transfer_data(args_file_path, json.dumps(module_args, cls=profile_encoder)) display.debug("done transferring module to remote") environment_string = self._compute_environment_string() @@ -1068,8 +1122,8 @@ class ActionBase(ABC): if wrap_async and not self._connection.always_pipeline_modules: # configure, upload, and chmod the async_wrapper module - (async_module_style, shebang, async_module_data, async_module_path) = self._configure_module( - module_name='ansible.legacy.async_wrapper', module_args=dict(), task_vars=task_vars) + (async_module_bits, async_module_path) = self._configure_module(module_name='ansible.legacy.async_wrapper', module_args=dict(), task_vars=task_vars) + (async_module_style, shebang, async_module_data) = (async_module_bits.module_style, async_module_bits.shebang, async_module_bits.b_module_data) async_module_remote_filename = self._connection._shell.get_remote_filename(async_module_path) remote_async_module_path = self._connection._shell.join_path(tmpdir, async_module_remote_filename) self._transfer_data(remote_async_module_path, async_module_data) @@ -1118,7 +1172,7 @@ class ActionBase(ABC): res = self._low_level_execute_command(cmd, sudoable=sudoable, in_data=in_data) # parse the 
main result - data = self._parse_returned_data(res) + data = self._parse_returned_data(res, module_bits.serialization_profile) # NOTE: INTERNAL KEYS ONLY ACCESSIBLE HERE # get internal info before cleaning @@ -1159,71 +1213,66 @@ class ActionBase(ABC): data['ansible_facts'][self._discovered_interpreter_key] = self._discovered_interpreter - if self._discovery_warnings: - if data.get('warnings') is None: - data['warnings'] = [] - data['warnings'].extend(self._discovery_warnings) - - if self._discovery_deprecation_warnings: - if data.get('deprecations') is None: - data['deprecations'] = [] - data['deprecations'].extend(self._discovery_deprecation_warnings) - - # mark the entire module results untrusted as a template right here, since the current action could - # possibly template one of these values. - data = wrap_var(data) - display.debug("done with _execute_module (%s, %s)" % (module_name, module_args)) return data - def _parse_returned_data(self, res): + def _parse_returned_data(self, res: dict[str, t.Any], profile: str) -> dict[str, t.Any]: try: - filtered_output, warnings = _filter_non_json_lines(res.get('stdout', u''), objects_only=True) + filtered_output, warnings = _filter_non_json_lines(res.get('stdout', ''), objects_only=True) + for w in warnings: display.warning(w) - data = json.loads(filtered_output) - - if C.MODULE_STRICT_UTF8_RESPONSE and not data.pop('_ansible_trusted_utf8', None): - try: - _validate_utf8_json(data) - except UnicodeEncodeError: - # When removing this, also remove the loop and latin-1 from ansible.module_utils.common.text.converters.jsonify - display.deprecated( - f'Module "{self._task.resolved_action or self._task.action}" returned non UTF-8 data in ' - 'the JSON response. This will become an error in the future', - version='2.18', - ) - - data['_ansible_parsed'] = True - except ValueError: - # not valid json, lets try to capture error - data = dict(failed=True, _ansible_parsed=False) - data['module_stdout'] = res.get('stdout', u'') - if 'stderr' in res: - data['module_stderr'] = res['stderr'] - if res['stderr'].startswith(u'Traceback'): - data['exception'] = res['stderr'] - - # in some cases a traceback will arrive on stdout instead of stderr, such as when using ssh with -tt - if 'exception' not in data and data['module_stdout'].startswith(u'Traceback'): - data['exception'] = data['module_stdout'] - - # The default - data['msg'] = "MODULE FAILURE" - - # try to figure out if we are missing interpreter + decoder = get_module_decoder(profile, Direction.MODULE_TO_CONTROLLER) + + data = json.loads(filtered_output, cls=decoder) + + _captured.AnsibleModuleCapturedError.normalize_result_exception(data) + + data.update(_ansible_parsed=True) # this must occur after normalize_result_exception, since it checks the type of data to ensure it's a dict + except ValueError as ex: + message = "Module result deserialization failed." + help_text = "" + include_cause_message = True + if self._used_interpreter is not None: - interpreter = re.escape(self._used_interpreter.lstrip('!#')) - match = re.compile('%s: (?:No such file or directory|not found)' % interpreter) - if match.search(data['module_stderr']) or match.search(data['module_stdout']): - data['msg'] = "The module failed to execute correctly, you probably need to set the interpreter." 
+ interpreter = self._used_interpreter.lstrip('!#') + # "not found" case is currently not tested; it was once reproducible + # see: https://github.com/ansible/ansible/pull/53534 + not_found_err_re = re.compile(rf'{re.escape(interpreter)}: (?:No such file or directory|not found|command not found)') + + if not_found_err_re.search(res.get('stderr', '')) or not_found_err_re.search(res.get('stdout', '')): + message = f"The module interpreter {interpreter!r} was not found." + help_text = 'Consider overriding the configured interpreter path for this host. ' + include_cause_message = False # cause context *might* be useful in the traceback, but the JSON deserialization failure message is not + + try: + # Because the underlying action API is built on result dicts instead of exceptions (for all but the most catastrophic failures), + # we're using a tweaked version of the module exception handler to get new ErrorDetail-backed errors from this part of the code. + # Ideally this would raise immediately on failure, but this would likely break actions that assume `ActionBase._execute_module()` + # does not raise on module failure. + + error = AnsibleError( + message=message, + help_text=help_text + "See stdout/stderr for the returned output.", + ) + + error._include_cause_message = include_cause_message - # always append hint - data['msg'] += '\nSee stdout/stderr for the exact error' + raise error from ex + except AnsibleError as ansible_ex: + sentinel = object() + + data = self.result_dict_from_exception(ansible_ex) + data.update( + _ansible_parsed=False, + module_stdout=res.get('stdout', ''), + module_stderr=res.get('stderr', sentinel), + rc=res.get('rc', sentinel), + ) + + data = {k: v for k, v in data.items() if v is not sentinel} - if 'rc' in res: - data['rc'] = res['rc'] return data # FIXME: move to connection base @@ -1395,3 +1444,23 @@ class ActionBase(ABC): # if missing it will return a file not found exception return self._loader.path_dwim_relative_stack(path_stack, dirname, needle) + + @staticmethod + def result_dict_from_exception(exception: BaseException) -> dict[str, t.Any]: + """Return a failed task result dict from the given exception.""" + if ansible_remoted_error := _captured.AnsibleResultCapturedError.find_first_remoted_error(exception): + result = ansible_remoted_error._result.copy() + else: + result = {} + + error_summary = _utils._create_error_summary(exception, _traceback.TracebackEvent.ERROR) + + result.update( + failed=True, + exception=error_summary, + ) + + if 'msg' not in result: + result.update(msg=_utils._dedupe_and_concat_message_chain([md.msg for md in error_summary.details])) + + return result diff --git a/lib/ansible/plugins/action/assert.py b/lib/ansible/plugins/action/assert.py index 5e18749af04..55df3873ab8 100644 --- a/lib/ansible/plugins/action/assert.py +++ b/lib/ansible/plugins/action/assert.py @@ -16,19 +16,41 @@ # along with Ansible. If not, see . 
from __future__ import annotations -from ansible.errors import AnsibleError -from ansible.playbook.conditional import Conditional +import typing as t + +from ansible._internal._templating import _jinja_bits +from ansible.errors import AnsibleTemplateError +from ansible.module_utils.common.validation import _check_type_list_strict from ansible.plugins.action import ActionBase -from ansible.module_utils.six import string_types -from ansible.module_utils.parsing.convert_bool import boolean +from ansible._internal._templating._engine import TemplateEngine class ActionModule(ActionBase): - """ Fail with custom message """ + """Assert that one or more conditional expressions evaluate to true.""" _requires_connection = False - _VALID_ARGS = frozenset(('fail_msg', 'msg', 'quiet', 'success_msg', 'that')) + @classmethod + def finalize_task_arg(cls, name: str, value: t.Any, templar: TemplateEngine, context: t.Any) -> t.Any: + if name != 'that': + # `that` is the only key requiring special handling; delegate to base handling otherwise + return super().finalize_task_arg(name, value, templar, context) + + if not isinstance(value, str): + # if `that` is not a string, we don't need to attempt to resolve it as a template before validation (which will also listify it) + return value + + # if `that` is entirely a string template, we only want to resolve to the container and avoid templating the container contents + if _jinja_bits.is_possibly_all_template(value): + try: + templated_that = templar.resolve_to_container(value) + except AnsibleTemplateError: + pass + else: + if isinstance(templated_that, list): # only use `templated_that` if it is a list + return templated_that + + return value def run(self, tmp=None, task_vars=None): if task_vars is None: @@ -37,49 +59,26 @@ class ActionModule(ActionBase): result = super(ActionModule, self).run(tmp, task_vars) del tmp # tmp no longer has any effect - if 'that' not in self._task.args: - raise AnsibleError('conditional required in "that" string') - - fail_msg = None - success_msg = None - - fail_msg = self._task.args.get('fail_msg', self._task.args.get('msg')) - if fail_msg is None: - fail_msg = 'Assertion failed' - elif isinstance(fail_msg, list): - if not all(isinstance(x, string_types) for x in fail_msg): - raise AnsibleError('Type of one of the elements in fail_msg or msg list is not string type') - elif not isinstance(fail_msg, (string_types, list)): - raise AnsibleError('Incorrect type for fail_msg or msg, expected a string or list and got %s' % type(fail_msg)) - - success_msg = self._task.args.get('success_msg') - if success_msg is None: - success_msg = 'All assertions passed' - elif isinstance(success_msg, list): - if not all(isinstance(x, string_types) for x in success_msg): - raise AnsibleError('Type of one of the elements in success_msg list is not string type') - elif not isinstance(success_msg, (string_types, list)): - raise AnsibleError('Incorrect type for success_msg, expected a string or list and got %s' % type(success_msg)) - - quiet = boolean(self._task.args.get('quiet', False), strict=False) - - # make sure the 'that' items are a list - thats = self._task.args['that'] - if not isinstance(thats, list): - thats = [thats] - - # Now we iterate over the that items, temporarily assigning them - # to the task's when value so we can evaluate the conditional using - # the built in evaluate function. 
The when has already been evaluated - by this point, and is not used again, so we don't care about mangling - that value now - cond = Conditional(loader=self._loader) + validation_result, new_module_args = self.validate_argument_spec( + argument_spec=dict( + fail_msg=dict(type=str_or_list_of_str, aliases=['msg'], default='Assertion failed'), + success_msg=dict(type=str_or_list_of_str, default='All assertions passed'), + quiet=dict(type='bool', default=False), + # explicitly not validating element types via `elements` here, to let the type rules for conditionals apply + that=dict(type=_check_type_list_strict, required=True), + ), + ) + + fail_msg = new_module_args['fail_msg'] + success_msg = new_module_args['success_msg'] + quiet = new_module_args['quiet'] + thats = new_module_args['that'] + if not quiet: result['_ansible_verbose_always'] = True for that in thats: - cond.when = [that] - test_result = cond.evaluate_conditional(templar=self._templar, all_vars=task_vars) + test_result = self._templar.evaluate_conditional(conditional=that) if not test_result: result['failed'] = True result['evaluated_to'] = test_result @@ -92,3 +91,13 @@ class ActionModule(ActionBase): result['changed'] = False result['msg'] = success_msg return result + + +def str_or_list_of_str(value: t.Any) -> str | list[str]: + if isinstance(value, str): + return value + + if not isinstance(value, list) or any(not isinstance(item, str) for item in value): + raise TypeError("a string or list of strings is required") + + return value diff --git a/lib/ansible/plugins/action/copy.py b/lib/ansible/plugins/action/copy.py index a6de4b05d32..b8c01ef6b04 100644 --- a/lib/ansible/plugins/action/copy.py +++ b/lib/ansible/plugins/action/copy.py @@ -23,7 +23,6 @@ import os import os.path import stat import tempfile -import traceback from ansible import constants as C from ansible.errors import AnsibleError, AnsibleActionFail, AnsibleFileNotFound @@ -470,10 +469,9 @@ class ActionModule(ActionBase): try: # find in expected paths source = self._find_needle('files', source) - except AnsibleError as e: - result['failed'] = True - result['msg'] = to_text(e) - result['exception'] = traceback.format_exc() + except AnsibleError as ex: + result.update(self.result_dict_from_exception(ex)) + return self._ensure_invocation(result) if trailing_slash != source.endswith(os.path.sep): diff --git a/lib/ansible/plugins/action/debug.py b/lib/ansible/plugins/action/debug.py index eefc2b74a33..55016e5b0b5 100644 --- a/lib/ansible/plugins/action/debug.py +++ b/lib/ansible/plugins/action/debug.py @@ -17,29 +17,32 @@ # along with Ansible. If not, see . from __future__ import annotations -from ansible.errors import AnsibleUndefinedVariable -from ansible.module_utils.six import string_types -from ansible.module_utils.common.text.converters import to_text +from ansible.errors import AnsibleValueOmittedError, AnsibleError +from ansible.module_utils.common.validation import _check_type_str_no_conversion from ansible.plugins.action import ActionBase +from ansible._internal._templating._jinja_common import UndefinedMarker, TruncationMarker +from ansible._internal._templating._utils import Omit +from ansible._internal._templating._marker_behaviors import ReplacingMarkerBehavior, RoutingMarkerBehavior +from ansible.utils.display import Display + +display = Display() class ActionModule(ActionBase): - """ Print statements during execution """ + """ + Emits informational messages, with special diagnostic handling of some templating failures. 
+ """ TRANSFERS_FILES = False - _VALID_ARGS = frozenset(('msg', 'var', 'verbosity')) _requires_connection = False def run(self, tmp=None, task_vars=None): - if task_vars is None: - task_vars = dict() - validation_result, new_module_args = self.validate_argument_spec( - argument_spec={ - 'msg': {'type': 'raw', 'default': 'Hello world!'}, - 'var': {'type': 'raw'}, - 'verbosity': {'type': 'int', 'default': 0}, - }, + argument_spec=dict( + msg=dict(type='raw', default='Hello world!'), + var=dict(type=_check_type_str_no_conversion), + verbosity=dict(type='int', default=0), + ), mutually_exclusive=( ('msg', 'var'), ), @@ -51,31 +54,34 @@ class ActionModule(ActionBase): # get task verbosity verbosity = new_module_args['verbosity'] + replacing_behavior = ReplacingMarkerBehavior() + + var_behavior = RoutingMarkerBehavior({ + UndefinedMarker: replacing_behavior, + TruncationMarker: replacing_behavior, + }) + if verbosity <= self._display.verbosity: - if new_module_args['var']: + if raw_var_arg := new_module_args['var']: + # If var name is same as result, try to template it try: - results = self._templar.template(new_module_args['var'], convert_bare=True, fail_on_undefined=True) - if results == new_module_args['var']: - # if results is not str/unicode type, raise an exception - if not isinstance(results, string_types): - raise AnsibleUndefinedVariable - # If var name is same as result, try to template it - results = self._templar.template("{{" + results + "}}", convert_bare=True, fail_on_undefined=True) - except AnsibleUndefinedVariable as e: - results = u"VARIABLE IS NOT DEFINED!" - if self._display.verbosity > 0: - results += u": %s" % to_text(e) - - if isinstance(new_module_args['var'], (list, dict)): - # If var is a list or dict, use the type as key to display - result[to_text(type(new_module_args['var']))] = results - else: - result[new_module_args['var']] = results + results = self._templar._engine.extend(marker_behavior=var_behavior).evaluate_expression(raw_var_arg) + except AnsibleValueOmittedError as ex: + results = repr(Omit) + display.warning("The result of the `var` expression could not be omitted; a placeholder was used instead.", obj=ex.obj) + except Exception as ex: + raise AnsibleError('Error while resolving `var` expression.', obj=raw_var_arg) from ex + + result[raw_var_arg] = results else: result['msg'] = new_module_args['msg'] # force flag to make debug output module always verbose result['_ansible_verbose_always'] = True + + # propagate any warnings in the task result unless we're skipping the task + replacing_behavior.emit_warnings() + else: result['skipped_reason'] = "Verbosity threshold not met." 
result['skipped'] = True diff --git a/lib/ansible/plugins/action/dnf.py b/lib/ansible/plugins/action/dnf.py index 137fb13086c..3d36ae2e34e 100644 --- a/lib/ansible/plugins/action/dnf.py +++ b/lib/ansible/plugins/action/dnf.py @@ -30,10 +30,9 @@ class ActionModule(ActionBase): if module in {'yum', 'auto'}: try: - if self._task.delegate_to: # if we delegate, we should use delegated host's facts - module = self._templar.template("{{hostvars['%s']['ansible_facts']['pkg_mgr']}}" % self._task.delegate_to) - else: - module = self._templar.template("{{ansible_facts.pkg_mgr}}") + # if we delegate, we should use delegated host's facts + expr = "hostvars[delegate_to].ansible_facts.pkg_mgr" if self._task.delegate_to else "ansible_facts.pkg_mgr" + module = self._templar.resolve_variable_expression(expr, local_variables=dict(delegate_to=self._task.delegate_to)) except Exception: pass # could not get it from template! diff --git a/lib/ansible/plugins/action/fetch.py b/lib/ansible/plugins/action/fetch.py index 533cab93ec8..133d3315eeb 100644 --- a/lib/ansible/plugins/action/fetch.py +++ b/lib/ansible/plugins/action/fetch.py @@ -51,7 +51,7 @@ class ActionModule(ActionBase): validate_checksum = boolean(self._task.args.get('validate_checksum', True), strict=False) msg = '' - # validate source and dest are strings FIXME: use basic.py and module specs + # FIXME: validate source and dest are strings; use basic.py and module specs if not isinstance(source, string_types): msg = "Invalid type supplied for source option, it must be a string" diff --git a/lib/ansible/plugins/action/gather_facts.py b/lib/ansible/plugins/action/gather_facts.py index 28479cd4deb..11ef07c2380 100644 --- a/lib/ansible/plugins/action/gather_facts.py +++ b/lib/ansible/plugins/action/gather_facts.py @@ -9,7 +9,7 @@ import typing as t from ansible import constants as C from ansible.errors import AnsibleActionFail -from ansible.executor.module_common import get_action_args_with_defaults +from ansible.executor.module_common import _apply_action_arg_defaults from ansible.module_utils.parsing.convert_bool import boolean from ansible.plugins.action import ActionBase from ansible.utils.vars import merge_hash @@ -54,10 +54,7 @@ class ActionModule(ActionBase): fact_module, collection_list=self._task.collections ).resolved_fqcn - mod_args = get_action_args_with_defaults( - resolved_fact_module, mod_args, self._task.module_defaults, self._templar, - action_groups=self._task._parent._play._action_groups - ) + mod_args = _apply_action_arg_defaults(resolved_fact_module, self._task, mod_args, self._templar) return mod_args @@ -132,6 +129,8 @@ class ActionModule(ActionBase): # TODO: use gather_timeout to cut module execution if module itself does not support gather_timeout res = self._execute_module(module_name=fact_module, module_args=mod_args, task_vars=task_vars, wrap_async=False) if res.get('failed', False): + # DTFIX-RELEASE: this trashes the individual failure details and does not work with the new error handling; need to do something to + # invoke per-item error handling- perhaps returning this as a synthetic loop result? 
failed[fact_module] = res elif res.get('skipped', False): skipped[fact_module] = res @@ -164,6 +163,8 @@ class ActionModule(ActionBase): res = self._execute_module(module_name='ansible.legacy.async_status', module_args=poll_args, task_vars=task_vars, wrap_async=False) if res.get('finished', 0) == 1: if res.get('failed', False): + # DTFIX-RELEASE: this trashes the individual failure details and does not work with the new error handling; need to do something to + # invoke per-item error handling- perhaps returning this as a synthetic loop result? failed[module] = res elif res.get('skipped', False): skipped[module] = res diff --git a/lib/ansible/plugins/action/include_vars.py b/lib/ansible/plugins/action/include_vars.py index 38fe4a9f8e6..3eeef2d9c8d 100644 --- a/lib/ansible/plugins/action/include_vars.py +++ b/lib/ansible/plugins/action/include_vars.py @@ -9,8 +9,9 @@ import pathlib import ansible.constants as C from ansible.errors import AnsibleError +from ansible._internal._datatag._tags import SourceWasEncrypted from ansible.module_utils.six import string_types -from ansible.module_utils.common.text.converters import to_native, to_text +from ansible.module_utils.common.text.converters import to_native from ansible.plugins.action import ActionBase from ansible.utils.vars import combine_vars @@ -167,9 +168,9 @@ class ActionModule(ActionBase): ) self.source_dir = path_to_use else: - if hasattr(self._task._ds, '_data_source'): + if (origin := self._task._origin) and origin.path: # origin.path is not present for ad-hoc tasks current_dir = ( - "/".join(self._task._ds._data_source.split('/')[:-1]) + "/".join(origin.path.split('/')[:-1]) ) self.source_dir = path.join(current_dir, self.source_dir) @@ -233,14 +234,13 @@ class ActionModule(ActionBase): failed = True err_msg = ('{0} does not have a valid extension: {1}'.format(to_native(filename), ', '.join(self.valid_extensions))) else: - b_data, show_content = self._loader._get_file_contents(filename) - data = to_text(b_data, errors='surrogate_or_strict') + data = self._loader.load_from_file(filename, cache='none', trusted_as_template=True) - self.show_content &= show_content # mask all results if any file was encrypted + self.show_content &= not SourceWasEncrypted.is_tagged_on(data) - data = self._loader.load(data, file_name=filename, show_content=show_content) - if not data: + if data is None: # support empty files, but not falsey values data = dict() + if not isinstance(data, dict): failed = True err_msg = ('{0} must be stored as a dictionary/hash'.format(to_native(filename))) diff --git a/lib/ansible/plugins/action/package.py b/lib/ansible/plugins/action/package.py index 13b2cdf7766..97c95115547 100644 --- a/lib/ansible/plugins/action/package.py +++ b/lib/ansible/plugins/action/package.py @@ -17,7 +17,7 @@ from __future__ import annotations from ansible.errors import AnsibleAction, AnsibleActionFail -from ansible.executor.module_common import get_action_args_with_defaults +from ansible.executor.module_common import _apply_action_arg_defaults from ansible.module_utils.facts.system.pkg_mgr import PKG_MGRS from ansible.plugins.action import ActionBase from ansible.utils.display import Display @@ -92,10 +92,7 @@ class ActionModule(ActionBase): # get defaults for specific module context = self._shared_loader_obj.module_loader.find_plugin_with_context(module, collection_list=self._task.collections) - new_module_args = get_action_args_with_defaults( - context.resolved_fqcn, new_module_args, self._task.module_defaults, self._templar, - 
action_groups=self._task._parent._play._action_groups - ) + new_module_args = _apply_action_arg_defaults(context.resolved_fqcn, self._task, new_module_args, self._templar) if module in self.BUILTIN_PKG_MGR_MODULES: # prefix with ansible.legacy to eliminate external collisions while still allowing library/ override diff --git a/lib/ansible/plugins/action/script.py b/lib/ansible/plugins/action/script.py index b3463d9060b..bb68076c5db 100644 --- a/lib/ansible/plugins/action/script.py +++ b/lib/ansible/plugins/action/script.py @@ -49,9 +49,8 @@ class ActionModule(ActionBase): 'chdir': {'type': 'str'}, 'executable': {'type': 'str'}, }, - required_one_of=[ - ['_raw_params', 'cmd'] - ] + required_one_of=[['_raw_params', 'cmd']], + mutually_exclusive=[['_raw_params', 'cmd']], ) result = super(ActionModule, self).run(tmp, task_vars) @@ -89,7 +88,7 @@ class ActionModule(ActionBase): # Split out the script as the first item in raw_params using # shlex.split() in order to support paths and files with spaces in the name. # Any arguments passed to the script will be added back later. - raw_params = to_native(new_module_args.get('_raw_params', ''), errors='surrogate_or_strict') + raw_params = new_module_args['_raw_params'] or new_module_args['cmd'] parts = [to_text(s, errors='surrogate_or_strict') for s in shlex.split(raw_params.strip())] source = parts[0] @@ -162,6 +161,7 @@ class ActionModule(ActionBase): become_plugin=self._connection.become, substyle="script", task_vars=task_vars, + profile='legacy', # the profile doesn't really matter since the module args dict is empty ) # build the necessary exec wrapper command # FUTURE: this still doesn't let script work on Windows with non-pipelined connections or diff --git a/lib/ansible/plugins/action/service.py b/lib/ansible/plugins/action/service.py index 2b00d10b9d3..30fe897b040 100644 --- a/lib/ansible/plugins/action/service.py +++ b/lib/ansible/plugins/action/service.py @@ -16,9 +16,8 @@ # along with Ansible. If not, see . from __future__ import annotations - from ansible.errors import AnsibleAction, AnsibleActionFail -from ansible.executor.module_common import get_action_args_with_defaults +from ansible.executor.module_common import _apply_action_arg_defaults from ansible.plugins.action import ActionBase @@ -47,10 +46,9 @@ class ActionModule(ActionBase): if module == 'auto': try: - if self._task.delegate_to: # if we delegate, we should use delegated host's facts - module = self._templar.template("{{hostvars['%s']['ansible_facts']['service_mgr']}}" % self._task.delegate_to) - else: - module = self._templar.template('{{ansible_facts.service_mgr}}') + # if we delegate, we should use delegated host's facts + expr = "hostvars[delegate_to].ansible_facts.service_mgr" if self._task.delegate_to else "ansible_facts.service_mgr" + module = self._templar.resolve_variable_expression(expr, local_variables=dict(delegate_to=self._task.delegate_to)) except Exception: pass # could not get it from template! 
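[review note] The dnf and service hunks above apply the same fix: the fact lookup is now a static expression, with host-specific values passed through `local_variables` rather than %-formatted into the template string (which risked template injection via hostnames). A sketch of the generalized pattern, assuming `resolve_variable_expression` accepts arbitrary names via `local_variables` as both call sites suggest; `_delegated_fact` is a hypothetical helper, not part of this change:

    # hypothetical helper distilling the pattern used by the dnf/service actions above
    def _delegated_fact(self, fact_name: str) -> object:
        # keep the expression static; dynamic values travel via local_variables
        expr = "hostvars[delegate_to].ansible_facts[fact]" if self._task.delegate_to else "ansible_facts[fact]"
        return self._templar.resolve_variable_expression(
            expr,
            local_variables=dict(delegate_to=self._task.delegate_to, fact=fact_name),
        )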
@@ -79,10 +77,7 @@ class ActionModule(ActionBase): # get defaults for specific module context = self._shared_loader_obj.module_loader.find_plugin_with_context(module, collection_list=self._task.collections) - new_module_args = get_action_args_with_defaults( - context.resolved_fqcn, new_module_args, self._task.module_defaults, self._templar, - action_groups=self._task._parent._play._action_groups - ) + new_module_args = _apply_action_arg_defaults(context.resolved_fqcn, self._task, new_module_args, self._templar) # collection prefix known internal modules to avoid collisions from collections search, while still allowing library/ overrides if module in self.BUILTIN_SVC_MGR_MODULES: diff --git a/lib/ansible/plugins/action/set_fact.py b/lib/ansible/plugins/action/set_fact.py index b95ec4940f9..62921aed676 100644 --- a/lib/ansible/plugins/action/set_fact.py +++ b/lib/ansible/plugins/action/set_fact.py @@ -18,12 +18,9 @@ from __future__ import annotations from ansible.errors import AnsibleActionFail -from ansible.module_utils.six import string_types from ansible.module_utils.parsing.convert_bool import boolean from ansible.plugins.action import ActionBase -from ansible.utils.vars import isidentifier - -import ansible.constants as C +from ansible.utils.vars import validate_variable_name class ActionModule(ActionBase): @@ -43,16 +40,10 @@ class ActionModule(ActionBase): if self._task.args: for (k, v) in self._task.args.items(): - k = self._templar.template(k) + k = self._templar.template(k) # a rare case where key templating is allowed; backward-compatibility for dynamic storage - if not isidentifier(k): - raise AnsibleActionFail("The variable name '%s' is not valid. Variables must start with a letter or underscore character, " - "and contain only letters, numbers and underscores." % k) + validate_variable_name(k) - # NOTE: this should really use BOOLEANS from convert_bool, but only in the k=v case, - # right now it converts matching explicit YAML strings also when 'jinja2_native' is disabled. - if not C.DEFAULT_JINJA2_NATIVE and isinstance(v, string_types) and v.lower() in ('true', 'false', 'yes', 'no'): - v = boolean(v, strict=False) facts[k] = v else: raise AnsibleActionFail('No key/value pairs provided, at least one is required for this action to succeed') diff --git a/lib/ansible/plugins/action/set_stats.py b/lib/ansible/plugins/action/set_stats.py index 309180f7a3d..bb312000ec3 100644 --- a/lib/ansible/plugins/action/set_stats.py +++ b/lib/ansible/plugins/action/set_stats.py @@ -19,7 +19,7 @@ from __future__ import annotations from ansible.module_utils.parsing.convert_bool import boolean from ansible.plugins.action import ActionBase -from ansible.utils.vars import isidentifier +from ansible.utils.vars import validate_variable_name class ActionModule(ActionBase): @@ -42,7 +42,7 @@ class ActionModule(ActionBase): data = self._task.args.get('data', {}) if not isinstance(data, dict): - data = self._templar.template(data, convert_bare=False, fail_on_undefined=True) + data = self._templar.template(data) if not isinstance(data, dict): result['failed'] = True @@ -59,14 +59,9 @@ class ActionModule(ActionBase): stats[opt] = val for (k, v) in data.items(): - k = self._templar.template(k) - if not isidentifier(k): - result['failed'] = True - result['msg'] = ("The variable name '%s' is not valid. Variables must start with a letter or underscore character, and contain only " - "letters, numbers and underscores." 
% k) - return result + validate_variable_name(k) stats['data'][k] = self._templar.template(v) diff --git a/lib/ansible/plugins/action/template.py b/lib/ansible/plugins/action/template.py index f83522dd70d..8a306d235c4 100644 --- a/lib/ansible/plugins/action/template.py +++ b/lib/ansible/plugins/action/template.py @@ -20,12 +20,12 @@ from jinja2.defaults import ( from ansible import constants as C from ansible.config.manager import ensure_type -from ansible.errors import AnsibleError, AnsibleFileNotFound, AnsibleAction, AnsibleActionFail +from ansible.errors import AnsibleError, AnsibleAction, AnsibleActionFail from ansible.module_utils.common.text.converters import to_bytes, to_text, to_native from ansible.module_utils.parsing.convert_bool import boolean from ansible.module_utils.six import string_types from ansible.plugins.action import ActionBase -from ansible.template import generate_ansible_template_vars, AnsibleEnvironment +from ansible.template import generate_ansible_template_vars, trust_as_template class ActionModule(ActionBase): @@ -98,63 +98,39 @@ class ActionModule(ActionBase): if mode == 'preserve': mode = '0%03o' % stat.S_IMODE(os.stat(source).st_mode) - # Get vault decrypted tmp file - try: - tmp_source = self._loader.get_real_file(source) - except AnsibleFileNotFound as e: - raise AnsibleActionFail("could not find src=%s, %s" % (source, to_text(e))) - b_tmp_source = to_bytes(tmp_source, errors='surrogate_or_strict') - # template the source data locally & get ready to transfer - try: - with open(b_tmp_source, 'rb') as f: - try: - template_data = to_text(f.read(), errors='surrogate_or_strict') - except UnicodeError: - raise AnsibleActionFail("Template source files must be utf-8 encoded") - - # set jinja2 internal search path for includes - searchpath = task_vars.get('ansible_search_path', []) - searchpath.extend([self._loader._basedir, os.path.dirname(source)]) - - # We want to search into the 'templates' subdir of each search path in - # addition to our original search paths. 
- newsearchpath = [] - for p in searchpath: - newsearchpath.append(os.path.join(p, 'templates')) - newsearchpath.append(p) - searchpath = newsearchpath - - # add ansible 'template' vars - temp_vars = task_vars.copy() - # NOTE in the case of ANSIBLE_DEBUG=1 task_vars is VarsWithSources(MutableMapping) - # so | operator cannot be used as it can be used only on dicts - # https://peps.python.org/pep-0584/#what-about-mapping-and-mutablemapping - temp_vars.update(generate_ansible_template_vars(self._task.args.get('src', None), source, dest)) - - # force templar to use AnsibleEnvironment to prevent issues with native types - # https://github.com/ansible/ansible/issues/46169 - templar = self._templar.copy_with_new_env(environment_class=AnsibleEnvironment, - searchpath=searchpath, - newline_sequence=newline_sequence, - available_variables=temp_vars) - overrides = dict( - block_start_string=block_start_string, - block_end_string=block_end_string, - variable_start_string=variable_start_string, - variable_end_string=variable_end_string, - comment_start_string=comment_start_string, - comment_end_string=comment_end_string, - trim_blocks=trim_blocks, - lstrip_blocks=lstrip_blocks - ) - resultant = templar.do_template(template_data, preserve_trailing_newlines=True, escape_backslashes=False, overrides=overrides) - except AnsibleAction: - raise - except Exception as e: - raise AnsibleActionFail("%s: %s" % (type(e).__name__, to_text(e))) - finally: - self._loader.cleanup_tmp_file(b_tmp_source) + template_data = trust_as_template(self._loader.get_text_file_contents(source)) + + # set jinja2 internal search path for includes + searchpath = task_vars.get('ansible_search_path', []) + searchpath.extend([self._loader._basedir, os.path.dirname(source)]) + + # We want to search into the 'templates' subdir of each search path in + # addition to our original search paths. 
+ newsearchpath = [] + for p in searchpath: + newsearchpath.append(os.path.join(p, 'templates')) + newsearchpath.append(p) + searchpath = newsearchpath + + # add ansible 'template' vars + temp_vars = task_vars.copy() + temp_vars.update(generate_ansible_template_vars(self._task.args.get('src', None), fullpath=source, dest_path=dest)) + + overrides = dict( + block_start_string=block_start_string, + block_end_string=block_end_string, + variable_start_string=variable_start_string, + variable_end_string=variable_end_string, + comment_start_string=comment_start_string, + comment_end_string=comment_end_string, + trim_blocks=trim_blocks, + lstrip_blocks=lstrip_blocks, + newline_sequence=newline_sequence, + ) + + data_templar = self._templar.copy_with_new_env(searchpath=searchpath, available_variables=temp_vars) + resultant = data_templar.template(template_data, escape_backslashes=False, overrides=overrides) new_task = self._task.copy() # mode is either the mode from task.args or the mode of the source file if the task.args diff --git a/lib/ansible/plugins/cache/__init__.py b/lib/ansible/plugins/cache/__init__.py index 3bc5a16f303..40518d84c7a 100644 --- a/lib/ansible/plugins/cache/__init__.py +++ b/lib/ansible/plugins/cache/__init__.py @@ -22,14 +22,15 @@ import errno import os import tempfile import time +import typing as t from abc import abstractmethod -from collections.abc import MutableMapping +from collections import abc as c from ansible import constants as C from ansible.errors import AnsibleError from ansible.module_utils.common.file import S_IRWU_RG_RO -from ansible.module_utils.common.text.converters import to_bytes, to_text +from ansible.module_utils.common.text.converters import to_bytes from ansible.plugins import AnsiblePlugin from ansible.plugins.loader import cache_loader from ansible.utils.collection_loader import resource_from_fqcr @@ -42,37 +43,36 @@ class BaseCacheModule(AnsiblePlugin): # Backwards compat only. 
Just import the global display instead _display = display + _persistent = True + """Plugins that do not persist data between runs can set this to False to bypass schema-version key munging and the JSON serialization wrapper.""" - def __init__(self, *args, **kwargs): - super(BaseCacheModule, self).__init__() - self.set_options(var_options=args, direct=kwargs) + def __init__(self, *args, **kwargs) -> None: + super().__init__() - @abstractmethod - def get(self, key): - pass + self.set_options(var_options=args, direct=kwargs) @abstractmethod - def set(self, key, value): + def get(self, key: str) -> dict[str, object]: pass @abstractmethod - def keys(self): + def set(self, key: str, value: dict[str, object]) -> None: pass @abstractmethod - def contains(self, key): + def keys(self) -> t.Sequence[str]: pass @abstractmethod - def delete(self, key): + def contains(self, key: object) -> bool: pass @abstractmethod - def flush(self): + def delete(self, key: str) -> None: pass @abstractmethod - def copy(self): + def flush(self) -> None: pass @@ -116,7 +116,7 @@ class BaseFileCacheModule(BaseCacheModule): raise AnsibleError("error in '%s' cache, configured path (%s) does not have necessary permissions (rwx), disabling plugin" % ( self.plugin_name, self._cache_dir)) - def _get_cache_file_name(self, key): + def _get_cache_file_name(self, key: str) -> str: prefix = self.get_option('_prefix') if prefix: cachefile = "%s/%s%s" % (self._cache_dir, prefix, key) @@ -144,11 +144,10 @@ class BaseFileCacheModule(BaseCacheModule): self.delete(key) raise AnsibleError("The cache file %s was corrupt, or did not otherwise contain valid data. " "It has been removed, so you can re-run your command now." % cachefile) - except (OSError, IOError) as e: - display.warning("error in '%s' cache plugin while trying to read %s : %s" % (self.plugin_name, cachefile, to_bytes(e))) + except FileNotFoundError: raise KeyError - except Exception as e: - raise AnsibleError("Error while decoding the cache file %s: %s" % (cachefile, to_bytes(e))) + except Exception as ex: + raise AnsibleError(f"Error while accessing the cache file {cachefile!r}.") from ex return self._cache.get(key) @@ -245,14 +244,8 @@ class BaseFileCacheModule(BaseCacheModule): for key in self.keys(): self.delete(key) - def copy(self): - ret = dict() - for key in self.keys(): - ret[key] = self.get(key) - return ret - @abstractmethod - def _load(self, filepath): + def _load(self, filepath: str) -> object: """ Read data from a filepath and return it as a value @@ -271,7 +264,7 @@ class BaseFileCacheModule(BaseCacheModule): pass @abstractmethod - def _dump(self, value, filepath): + def _dump(self, value: object, filepath: str) -> None: """ Write data to a filepath @@ -281,19 +274,13 @@ class BaseFileCacheModule(BaseCacheModule): pass -class CachePluginAdjudicator(MutableMapping): - """ - Intermediary between a cache dictionary and a CacheModule - """ +class CachePluginAdjudicator(c.MutableMapping): + """Batch update wrapper around a cache plugin.""" + def __init__(self, plugin_name='memory', **kwargs): self._cache = {} self._retrieved = {} self._plugin = cache_loader.get(plugin_name, **kwargs) if not self._plugin: raise AnsibleError('Unable to load the cache plugin (%s).' 
% plugin_name) - - self._plugin_name = plugin_name def update_cache_if_changed(self): if self._retrieved != self._cache: @@ -302,6 +289,7 @@ class CachePluginAdjudicator(MutableMapping): def set_cache(self): for top_level_cache_key in self._cache.keys(): self._plugin.set(top_level_cache_key, self._cache[top_level_cache_key]) + self._retrieved = copy.deepcopy(self._cache) def load_whole_cache(self): @@ -309,7 +297,7 @@ class CachePluginAdjudicator(MutableMapping): self._cache[key] = self._plugin.get(key) def __repr__(self): - return to_text(self._cache) + return repr(self._cache) def __iter__(self): return iter(self.keys()) @@ -319,13 +307,10 @@ class CachePluginAdjudicator(MutableMapping): def _do_load_key(self, key): load = False - if all([ - key not in self._cache, - key not in self._retrieved, - self._plugin_name != 'memory', - self._plugin.contains(key), - ]): + + if key not in self._cache and key not in self._retrieved and self._plugin._persistent and self._plugin.contains(key): load = True + return load def __getitem__(self, key): @@ -336,16 +321,18 @@ class CachePluginAdjudicator(MutableMapping): pass else: self._retrieved[key] = self._cache[key] + return self._cache[key] def get(self, key, default=None): if self._do_load_key(key): try: self._cache[key] = self._plugin.get(key) - except KeyError as e: + except KeyError: pass else: self._retrieved[key] = self._cache[key] + return self._cache.get(key, default) def items(self): @@ -360,6 +347,7 @@ class CachePluginAdjudicator(MutableMapping): def pop(self, key, *args): if args: return self._cache.pop(key, args[0]) + return self._cache.pop(key) def __delitem__(self, key): @@ -368,6 +356,9 @@ class CachePluginAdjudicator(MutableMapping): def __setitem__(self, key, value): self._cache[key] = value + def clear(self): + self.flush() + def flush(self): self._plugin.flush() self._cache = {} diff --git a/lib/ansible/plugins/cache/base.py b/lib/ansible/plugins/cache/base.py index a7c7468b820..837365d9b4a 100644 --- a/lib/ansible/plugins/cache/base.py +++ b/lib/ansible/plugins/cache/base.py @@ -18,3 +18,11 @@ from __future__ import annotations # moved actual classes to __init__ kept here for backward compat with 3rd parties from ansible.plugins.cache import BaseCacheModule, BaseFileCacheModule # pylint: disable=unused-import + +from ansible.utils.display import Display as _Display + +_Display().deprecated( + msg="The `ansible.plugins.cache.base` Python module is deprecated.", + help_text="Import from `ansible.plugins.cache` instead.", + version="2.23", +) diff --git a/lib/ansible/plugins/cache/jsonfile.py b/lib/ansible/plugins/cache/jsonfile.py index 6184947b6c9..00ead7c77c6 100644 --- a/lib/ansible/plugins/cache/jsonfile.py +++ b/lib/ansible/plugins/cache/jsonfile.py @@ -40,23 +40,17 @@ DOCUMENTATION = """ type: integer """ -import codecs import json +import pathlib -from ansible.parsing.ajson import AnsibleJSONEncoder, AnsibleJSONDecoder from ansible.plugins.cache import BaseFileCacheModule class CacheModule(BaseFileCacheModule): - """ - A caching module backed by json files. - """ - - def _load(self, filepath): - # Valid JSON is always UTF-8 encoded. 
- with codecs.open(filepath, 'r', encoding='utf-8') as f: - return json.load(f, cls=AnsibleJSONDecoder) - - def _dump(self, value, filepath): - with codecs.open(filepath, 'w', encoding='utf-8') as f: - f.write(json.dumps(value, cls=AnsibleJSONEncoder, sort_keys=True, indent=4)) + """A caching module backed by json files.""" + + def _load(self, filepath: str) -> object: + return json.loads(pathlib.Path(filepath).read_text()) + + def _dump(self, value: object, filepath: str) -> None: + pathlib.Path(filepath).write_text(json.dumps(value)) diff --git a/lib/ansible/plugins/cache/memory.py b/lib/ansible/plugins/cache/memory.py index 780a643f151..055860da6ef 100644 --- a/lib/ansible/plugins/cache/memory.py +++ b/lib/ansible/plugins/cache/memory.py @@ -20,12 +20,15 @@ from ansible.plugins.cache import BaseCacheModule class CacheModule(BaseCacheModule): + _persistent = False # prevent unnecessary JSON serialization and key munging def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._cache = {} def get(self, key): - return self._cache.get(key) + return self._cache[key] def set(self, key, value): self._cache[key] = value @@ -41,12 +44,3 @@ class CacheModule(BaseCacheModule): def flush(self): self._cache = {} - - def copy(self): - return self._cache.copy() - - def __getstate__(self): - return self.copy() - - def __setstate__(self, data): - self._cache = data diff --git a/lib/ansible/plugins/callback/__init__.py b/lib/ansible/plugins/callback/__init__.py index 8dd839fdc8f..f88055a4daa 100644 --- a/lib/ansible/plugins/callback/__init__.py +++ b/lib/ansible/plugins/callback/__init__.py @@ -18,27 +18,28 @@ from __future__ import annotations import difflib +import functools import json import re import sys import textwrap +import typing as t + from typing import TYPE_CHECKING -from collections import OrderedDict from collections.abc import MutableMapping from copy import deepcopy from ansible import constants as C -from ansible.module_utils.common.text.converters import to_text -from ansible.module_utils.six import text_type -from ansible.parsing.ajson import AnsibleJSONEncoder -from ansible.parsing.yaml.dumper import AnsibleDumper -from ansible.parsing.yaml.objects import AnsibleUnicode +from ansible.module_utils._internal import _datatag +from ansible.module_utils.common.messages import ErrorSummary +from ansible._internal._yaml import _dumper from ansible.plugins import AnsiblePlugin from ansible.utils.color import stringc from ansible.utils.display import Display -from ansible.utils.unsafe_proxy import AnsibleUnsafeText, NativeJinjaUnsafeText from ansible.vars.clean import strip_internal_keys, module_response_deepcopy +from ansible.module_utils._internal._json._profiles import _fallback_to_str +from ansible._internal._templating import _engine import yaml @@ -52,23 +53,41 @@ __all__ = ["CallbackBase"] _DEBUG_ALLOWED_KEYS = frozenset(('msg', 'exception', 'warnings', 'deprecations')) -_YAML_TEXT_TYPES = (text_type, AnsibleUnicode, AnsibleUnsafeText, NativeJinjaUnsafeText) # Characters that libyaml/pyyaml consider breaks _YAML_BREAK_CHARS = '\n\x85\u2028\u2029' # NL, NEL, LS, PS # regex representation of libyaml/pyyaml of a space followed by a break character _SPACE_BREAK_RE = re.compile(fr' +([{_YAML_BREAK_CHARS}])') -class _AnsibleCallbackDumper(AnsibleDumper): - def __init__(self, lossy=False): +class _AnsibleCallbackDumper(_dumper.AnsibleDumper): + def __init__(self, *args, lossy: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self._lossy = lossy - def 
__call__(self, *args, **kwargs): - # pyyaml expects that we are passing an object that can be instantiated, but to - # smuggle the ``lossy`` configuration, we do that in ``__init__`` and then - # define this ``__call__`` that will mimic the ability for pyyaml to instantiate class - super().__init__(*args, **kwargs) - return self + def _pretty_represent_str(self, data): + """Uses block style for multi-line strings""" + data = _datatag.AnsibleTagHelper.as_native_type(data) + + if _should_use_block(data): + style = '|' + if self._lossy: + data = _munge_data_for_lossy_yaml(data) + else: + style = self.default_style + + node = yaml.representer.ScalarNode('tag:yaml.org,2002:str', data, style=style) + + if self.alias_key is not None: + self.represented_objects[self.alias_key] = node + + return node + + @classmethod + def _register_representers(cls) -> None: + super()._register_representers() + + cls.add_multi_representer(str, cls._pretty_represent_str) def _should_use_block(scalar): @@ -77,6 +96,7 @@ def _should_use_block(scalar): for ch in _YAML_BREAK_CHARS: if ch in scalar: return True + return False @@ -95,12 +115,12 @@ class _SpecialCharacterTranslator: return None -def _filter_yaml_special(scalar): +def _filter_yaml_special(scalar: str) -> str: """Filter a string removing any character that libyaml/pyyaml declare as special""" return scalar.translate(_SpecialCharacterTranslator()) -def _munge_data_for_lossy_yaml(scalar): +def _munge_data_for_lossy_yaml(scalar: str) -> str: """Modify a string so that analyze_scalar in libyaml/pyyaml will allow block formatting""" # we care more about readability than accuracy, so... # ...libyaml/pyyaml does not permit trailing spaces for block scalars @@ -113,31 +133,7 @@ def _munge_data_for_lossy_yaml(scalar): return _SPACE_BREAK_RE.sub(r'\1', scalar) -def _pretty_represent_str(self, data): - """Uses block style for multi-line strings""" - data = text_type(data) - if _should_use_block(data): - style = '|' - if self._lossy: - data = _munge_data_for_lossy_yaml(data) - else: - style = self.default_style - - node = yaml.representer.ScalarNode('tag:yaml.org,2002:str', data, style=style) - if self.alias_key is not None: - self.represented_objects[self.alias_key] = node - return node - - -for data_type in _YAML_TEXT_TYPES: - _AnsibleCallbackDumper.add_representer( - data_type, - _pretty_represent_str - ) - - class CallbackBase(AnsiblePlugin): - """ This is a base ansible callback class that does nothing. 
New callbacks should use this class as a base and override any callback methods they wish to execute @@ -244,9 +240,12 @@ class CallbackBase(AnsiblePlugin): if self._display.verbosity < 3 and 'diff' in result: del abridged_result['diff'] - # remove exception from screen output - if 'exception' in abridged_result: - del abridged_result['exception'] + # remove error/warning values; the stdout callback should have already handled them + abridged_result.pop('exception', None) + abridged_result.pop('warnings', None) + abridged_result.pop('deprecations', None) + + abridged_result = _engine.TemplateEngine().transform(abridged_result) # ensure the dumped view matches the transformed view a playbook sees if not serialize: # Just return ``abridged_result`` without going through serialization @@ -255,17 +254,8 @@ class CallbackBase(AnsiblePlugin): return abridged_result if result_format == 'json': - try: - return json.dumps(abridged_result, cls=AnsibleJSONEncoder, indent=indent, ensure_ascii=False, sort_keys=sort_keys) - except TypeError: - # Python3 bug: throws an exception when keys are non-homogenous types: - # https://bugs.python.org/issue25457 - # sort into an OrderedDict and then json.dumps() that instead - if not OrderedDict: - raise - return json.dumps(OrderedDict(sorted(abridged_result.items(), key=to_text)), - cls=AnsibleJSONEncoder, indent=indent, - ensure_ascii=False, sort_keys=False) + return json.dumps(abridged_result, cls=_fallback_to_str.Encoder, indent=indent, ensure_ascii=False, sort_keys=sort_keys) + elif result_format == 'yaml': # None is a sentinel in this case that indicates default behavior # default behavior for yaml is to prettify results @@ -283,7 +273,7 @@ class CallbackBase(AnsiblePlugin): yaml.dump( abridged_result, allow_unicode=True, - Dumper=_AnsibleCallbackDumper(lossy=lossy), + Dumper=functools.partial(_AnsibleCallbackDumper, lossy=lossy), default_flow_style=False, indent=indent, # sort_keys=sort_keys # This requires PyYAML>=5.1 @@ -291,32 +281,31 @@ class CallbackBase(AnsiblePlugin): ' ' * (indent or 4) ) - def _handle_warnings(self, res): - """ display warnings, if enabled and any exist in the result """ - if C.ACTION_WARNINGS: - if 'warnings' in res and res['warnings']: - for warning in res['warnings']: - self._display.warning(warning) - del res['warnings'] - if 'deprecations' in res and res['deprecations']: - for warning in res['deprecations']: - self._display.deprecated(**warning) - del res['deprecations'] - - def _handle_exception(self, result, use_stderr=False): - - if 'exception' in result: - msg = "An exception occurred during task execution. " - exception_str = to_text(result['exception']) - if self._display.verbosity < 3: - # extract just the actual error message from the exception text - error = exception_str.strip().split('\n')[-1] - msg += "To see the full traceback, use -vvv. The error was: %s" % error - else: - msg = "The full traceback is:\n" + exception_str - del result['exception'] + def _handle_warnings(self, res: dict[str, t.Any]) -> None: + """Display warnings and deprecation warnings sourced by task execution.""" + for warning in res.pop('warnings', []): + # DTFIX-RELEASE: what to do about propagating wrap_text from the original display.warning call? 
+ self._display._warning(warning, wrap_text=False) + + for warning in res.pop('deprecations', []): + self._display._deprecated(warning) + + def _handle_exception(self, result: dict[str, t.Any], use_stderr: bool = False) -> None: + error_summary: ErrorSummary | None + + if error_summary := result.pop('exception', None): + self._display._error(error_summary, wrap_text=False, stderr=use_stderr) + + def _handle_warnings_and_exception(self, result: TaskResult) -> None: + """Standardized handling of warnings/deprecations and exceptions from a task/item result.""" + # DTFIX-RELEASE: make/doc/porting-guide a public version of this method? + try: + use_stderr = self.get_option('display_failed_stderr') + except KeyError: + use_stderr = False - self._display.display(msg, color=C.COLOR_ERROR, stderr=use_stderr) + self._handle_warnings(result._result) + self._handle_exception(result._result, use_stderr=use_stderr) def _serialize_diff(self, diff): try: @@ -341,7 +330,7 @@ class CallbackBase(AnsiblePlugin): yaml.dump( diff, allow_unicode=True, - Dumper=_AnsibleCallbackDumper(lossy=lossy), + Dumper=functools.partial(_AnsibleCallbackDumper, lossy=lossy), default_flow_style=False, indent=4, # sort_keys=sort_keys # This requires PyYAML>=5.1 @@ -425,6 +414,7 @@ class CallbackBase(AnsiblePlugin): """ removes data from results for display """ # mostly controls that debug only outputs what it was meant to + # FIXME: this is a terrible heuristic to format debug's output- it masks exception detail if task_name in C._ACTION_DEBUG: if 'msg' in result: # msg should be alone @@ -659,13 +649,13 @@ class CallbackBase(AnsiblePlugin): def v2_playbook_on_include(self, included_file): pass # no v1 correspondence - def v2_runner_item_on_ok(self, result): + def v2_runner_item_on_ok(self, result: TaskResult) -> None: pass - def v2_runner_item_on_failed(self, result): + def v2_runner_item_on_failed(self, result: TaskResult) -> None: pass - def v2_runner_item_on_skipped(self, result): + def v2_runner_item_on_skipped(self, result: TaskResult) -> None: pass def v2_runner_retry(self, result): diff --git a/lib/ansible/plugins/callback/default.py b/lib/ansible/plugins/callback/default.py index 39bd5a45f39..2237c73a759 100644 --- a/lib/ansible/plugins/callback/default.py +++ b/lib/ansible/plugins/callback/default.py @@ -21,6 +21,7 @@ DOCUMENTATION = """ from ansible import constants as C from ansible import context +from ansible.executor.task_result import TaskResult from ansible.playbook.task_include import TaskInclude from ansible.plugins.callback import CallbackBase from ansible.utils.color import colorize, hostcolor @@ -46,20 +47,20 @@ class CallbackModule(CallbackBase): self._task_type_cache = {} super(CallbackModule, self).__init__() - def v2_runner_on_failed(self, result, ignore_errors=False): - + def v2_runner_on_failed(self, result: TaskResult, ignore_errors: bool = False) -> None: host_label = self.host_label(result) - self._clean_results(result._result, result._task.action) if self._last_task_banner != result._task._uuid: self._print_task_banner(result._task) - self._handle_exception(result._result, use_stderr=self.get_option('display_failed_stderr')) - self._handle_warnings(result._result) + self._handle_warnings_and_exception(result) + + # FIXME: this method should not exist, delegate "suggested keys to display" to the plugin or something... As-is, the placement of this + # call obliterates `results`, which causes a task summary to be printed on loop failures, which we don't do anywhere else. 
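The consolidated `_handle_warnings_and_exception` helper above gives stdout callbacks a single entry point for surfacing warnings, deprecations, and the exception from a task or item result. A sketch of a third-party callback adopting it (plugin name and output format are hypothetical); the helper pops the relevant result keys exactly once, so the subsequent `_dump_results()` call no longer re-serializes them:

from ansible import constants as C
from ansible.plugins.callback import CallbackBase


class CallbackModule(CallbackBase):  # hypothetical third-party stdout plugin
    CALLBACK_VERSION = 2.0
    CALLBACK_TYPE = 'stdout'
    CALLBACK_NAME = 'terse'

    def v2_runner_on_failed(self, result, ignore_errors=False):
        # replaces the old _handle_exception(...) + _handle_warnings(...) pair
        self._handle_warnings_and_exception(result)
        self._display.display(
            '%s | FAILED => %s' % (result._host.get_name(), self._dump_results(result._result)),
            color=C.COLOR_ERROR,
        )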
+ self._clean_results(result._result, result._task.action) if result._task.loop and 'results' in result._result: self._process_items(result) - else: if self._display.verbosity < 2 and self.get_option('show_task_path_on_failure'): self._print_task_path(result._task) @@ -69,8 +70,7 @@ class CallbackModule(CallbackBase): if ignore_errors: self._display.display("...ignoring", color=C.COLOR_SKIP) - def v2_runner_on_ok(self, result): - + def v2_runner_on_ok(self, result: TaskResult) -> None: host_label = self.host_label(result) if isinstance(result._task, TaskInclude): @@ -93,7 +93,7 @@ class CallbackModule(CallbackBase): msg = "ok: [%s]" % (host_label,) color = C.COLOR_OK - self._handle_warnings(result._result) + self._handle_warnings_and_exception(result) if result._task.loop and 'results' in result._result: self._process_items(result) @@ -104,8 +104,7 @@ class CallbackModule(CallbackBase): msg += " => %s" % (self._dump_results(result._result),) self._display.display(msg, color=color) - def v2_runner_on_skipped(self, result): - + def v2_runner_on_skipped(self, result: TaskResult) -> None: if self.get_option('display_skipped_hosts'): self._clean_results(result._result, result._task.action) @@ -113,6 +112,8 @@ class CallbackModule(CallbackBase): if self._last_task_banner != result._task._uuid: self._print_task_banner(result._task) + self._handle_warnings_and_exception(result) + if result._task.loop is not None and 'results' in result._result: self._process_items(result) @@ -121,10 +122,12 @@ class CallbackModule(CallbackBase): msg += " => %s" % self._dump_results(result._result) self._display.display(msg, color=C.COLOR_SKIP) - def v2_runner_on_unreachable(self, result): + def v2_runner_on_unreachable(self, result: TaskResult) -> None: if self._last_task_banner != result._task._uuid: self._print_task_banner(result._task) + self._handle_warnings_and_exception(result) + host_label = self.host_label(result) msg = "fatal: [%s]: UNREACHABLE! => %s" % (host_label, self._dump_results(result._result)) self._display.display(msg, color=C.COLOR_UNREACHABLE, stderr=self.get_option('display_failed_stderr')) @@ -171,6 +174,7 @@ class CallbackModule(CallbackBase): # that they can secure this if they feel that their stdout is insecure # (shoulder surfing, logging stdout straight to a file, etc). 
args = '' + # FIXME: the no_log value is not templated at this point, so any template will be considered truthy if not task.no_log and C.DISPLAY_ARGS_TO_STDOUT: args = u', '.join(u'%s=%s' % a for a in task.args.items()) args = u' %s' % args @@ -234,8 +238,7 @@ class CallbackModule(CallbackBase): self._print_task_banner(result._task) self._display.display(diff) - def v2_runner_item_on_ok(self, result): - + def v2_runner_item_on_ok(self, result: TaskResult) -> None: host_label = self.host_label(result) if isinstance(result._task, TaskInclude): return @@ -255,33 +258,37 @@ class CallbackModule(CallbackBase): msg = 'ok' color = C.COLOR_OK + self._handle_warnings_and_exception(result) + msg = "%s: [%s] => (item=%s)" % (msg, host_label, self._get_item_label(result._result)) self._clean_results(result._result, result._task.action) if self._run_is_verbose(result): msg += " => %s" % self._dump_results(result._result) self._display.display(msg, color=color) - def v2_runner_item_on_failed(self, result): + def v2_runner_item_on_failed(self, result: TaskResult) -> None: if self._last_task_banner != result._task._uuid: self._print_task_banner(result._task) + self._handle_warnings_and_exception(result) + host_label = self.host_label(result) - self._clean_results(result._result, result._task.action) - self._handle_exception(result._result, use_stderr=self.get_option('display_failed_stderr')) msg = "failed: [%s]" % (host_label,) - self._handle_warnings(result._result) + self._clean_results(result._result, result._task.action) self._display.display( msg + " (item=%s) => %s" % (self._get_item_label(result._result), self._dump_results(result._result)), color=C.COLOR_ERROR, stderr=self.get_option('display_failed_stderr') ) - def v2_runner_item_on_skipped(self, result): + def v2_runner_item_on_skipped(self, result: TaskResult) -> None: if self.get_option('display_skipped_hosts'): if self._last_task_banner != result._task._uuid: self._print_task_banner(result._task) + self._handle_warnings_and_exception(result) + self._clean_results(result._result, result._task.action) msg = "skipping: [%s] => (item=%s) " % (result._host.get_name(), self._get_item_label(result._result)) if self._run_is_verbose(result): diff --git a/lib/ansible/plugins/callback/junit.py b/lib/ansible/plugins/callback/junit.py index e164902474f..dc56ac5d1b4 100644 --- a/lib/ansible/plugins/callback/junit.py +++ b/lib/ansible/plugins/callback/junit.py @@ -82,12 +82,15 @@ DOCUMENTATION = """ - enable in configuration """ +import decimal import os import time import re -from ansible import constants as C +from ansible import constants +from ansible.module_utils.common.messages import ErrorSummary from ansible.module_utils.common.text.converters import to_bytes, to_text +from ansible.playbook.task import Task from ansible.plugins.callback import CallbackBase from ansible.utils._junit_xml import ( TestCase, @@ -126,7 +129,7 @@ class CallbackModule(CallbackBase): Default: True JUNIT_HIDE_TASK_ARGUMENTS (optional): Hide the arguments for a task Default: False - JUNIT_TEST_CASE_PREFIX (optional): Consider a task only as test case if it has this value as prefix. Additionally failing tasks are recorded as failed + JUNIT_TEST_CASE_PREFIX (optional): Consider a task only as test case if it has this value as prefix. Additionally, failing tasks are recorded as failed test cases. 
Default: """ @@ -136,7 +139,7 @@ class CallbackModule(CallbackBase): CALLBACK_NAME = 'junit' CALLBACK_NEEDS_ENABLED = True - def __init__(self): + def __init__(self) -> None: super(CallbackModule, self).__init__() self._output_dir = os.getenv('JUNIT_OUTPUT_DIR', os.path.expanduser('~/.ansible.log')) @@ -150,20 +153,18 @@ class CallbackModule(CallbackBase): self._replace_out_of_tree_path = os.getenv('JUNIT_REPLACE_OUT_OF_TREE_PATH', None) self._playbook_path = None self._playbook_name = None - self._play_name = None - self._task_data = None + self._play_name: str | None = None + self._task_data: dict[str, TaskData] = {} self.disabled = False - self._task_data = {} - if self._replace_out_of_tree_path is not None: self._replace_out_of_tree_path = to_text(self._replace_out_of_tree_path) if not os.path.exists(self._output_dir): os.makedirs(self._output_dir) - def _start_task(self, task): + def _start_task(self, task: Task) -> None: """ record the start of a task for one or more hosts """ uuid = task._uuid @@ -212,11 +213,11 @@ class CallbackModule(CallbackBase): if task_data.name.startswith(self._test_case_prefix) or status == 'failed': task_data.add_host(HostData(host_uuid, host_name, status, result)) - def _build_test_case(self, task_data, host_data): + def _build_test_case(self, task_data: TaskData, host_data: HostData) -> TestCase: """ build a TestCase from the given TaskData and HostData """ name = '[%s] %s: %s' % (host_data.name, task_data.play, task_data.name) - duration = host_data.finish - task_data.start + duration = decimal.Decimal(host_data.finish - task_data.start) if self._task_relative_path and task_data.path: junit_classname = to_text(os.path.relpath(to_bytes(task_data.path), to_bytes(self._task_relative_path))) @@ -242,10 +243,12 @@ class CallbackModule(CallbackBase): test_case = TestCase(name=name, classname=junit_classname, time=duration) + error_summary: ErrorSummary + if host_data.status == 'failed': - if 'exception' in res: - message = res['exception'].strip().split('\n')[-1] - output = res['exception'] + if error_summary := res.get('exception'): + message = error_summary._format() + output = error_summary.formatted_traceback test_case.errors.append(TestError(message=message, output=output)) elif 'msg' in res: message = res['msg'] @@ -261,7 +264,8 @@ class CallbackModule(CallbackBase): return test_case - def _cleanse_string(self, value): + @staticmethod + def _cleanse_string(value): """ convert surrogate escapes to the unicode replacement character to avoid XML encoding errors """ return to_text(to_bytes(value, errors='surrogateescape'), errors='replace') @@ -271,7 +275,7 @@ class CallbackModule(CallbackBase): test_cases = [] for task_uuid, task_data in self._task_data.items(): - if task_data.action in C._ACTION_SETUP and self._include_setup_tasks_in_report == 'false': + if task_data.action in constants._ACTION_SETUP and self._include_setup_tasks_in_report == 'false': continue for host_uuid, host_data in task_data.host_data.items(): @@ -293,16 +297,16 @@ class CallbackModule(CallbackBase): def v2_playbook_on_play_start(self, play): self._play_name = play.get_name() - def v2_runner_on_no_hosts(self, task): + def v2_runner_on_no_hosts(self, task: Task) -> None: self._start_task(task) - def v2_playbook_on_task_start(self, task, is_conditional): + def v2_playbook_on_task_start(self, task: Task, is_conditional: bool) -> None: self._start_task(task) - def v2_playbook_on_cleanup_task_start(self, task): + def v2_playbook_on_cleanup_task_start(self, task: Task) -> None: 
self._start_task(task) - def v2_playbook_on_handler_task_start(self, task): + def v2_playbook_on_handler_task_start(self, task: Task) -> None: self._start_task(task) def v2_runner_on_failed(self, result, ignore_errors=False): @@ -329,17 +333,17 @@ class TaskData: Data about an individual task. """ - def __init__(self, uuid, name, path, play, action): + def __init__(self, uuid: str, name: str, path: str, play: str, action: str) -> None: self.uuid = uuid self.name = name self.path = path self.play = play self.start = None - self.host_data = {} + self.host_data: dict[str, HostData] = {} self.start = time.time() self.action = action - def add_host(self, host): + def add_host(self, host: HostData) -> None: if host.uuid in self.host_data: if host.status == 'included': # concatenate task include output from multiple items diff --git a/lib/ansible/plugins/callback/minimal.py b/lib/ansible/plugins/callback/minimal.py index 181e90eba9a..3459a5bc5b5 100644 --- a/lib/ansible/plugins/callback/minimal.py +++ b/lib/ansible/plugins/callback/minimal.py @@ -15,6 +15,7 @@ DOCUMENTATION = """ - result_format_callback """ +from ansible.executor.task_result import TaskResult from ansible.plugins.callback import CallbackBase from ansible import constants as C @@ -40,20 +41,18 @@ class CallbackModule(CallbackBase): return buf + "\n" - def v2_runner_on_failed(self, result, ignore_errors=False): - - self._handle_exception(result._result) - self._handle_warnings(result._result) + def v2_runner_on_failed(self, result: TaskResult, ignore_errors: bool = False) -> None: + self._handle_warnings_and_exception(result) if result._task.action in C.MODULE_NO_JSON and 'module_stderr' not in result._result: self._display.display(self._command_generic_msg(result._host.get_name(), result._result, "FAILED"), color=C.COLOR_ERROR) else: self._display.display("%s | FAILED! => %s" % (result._host.get_name(), self._dump_results(result._result, indent=4)), color=C.COLOR_ERROR) - def v2_runner_on_ok(self, result): - self._clean_results(result._result, result._task.action) + def v2_runner_on_ok(self, result: TaskResult) -> None: + self._handle_warnings_and_exception(result) - self._handle_warnings(result._result) + self._clean_results(result._result, result._task.action) if result._result.get('changed', False): color = C.COLOR_CHANGED @@ -67,10 +66,14 @@ class CallbackModule(CallbackBase): else: self._display.display("%s | %s => %s" % (result._host.get_name(), state, self._dump_results(result._result, indent=4)), color=color) - def v2_runner_on_skipped(self, result): + def v2_runner_on_skipped(self, result: TaskResult) -> None: + self._handle_warnings_and_exception(result) + self._display.display("%s | SKIPPED" % (result._host.get_name()), color=C.COLOR_SKIP) - def v2_runner_on_unreachable(self, result): + def v2_runner_on_unreachable(self, result: TaskResult) -> None: + self._handle_warnings_and_exception(result) + self._display.display("%s | UNREACHABLE! => %s" % (result._host.get_name(), self._dump_results(result._result, indent=4)), color=C.COLOR_UNREACHABLE) def v2_on_file_diff(self, result): diff --git a/lib/ansible/plugins/callback/oneline.py b/lib/ansible/plugins/callback/oneline.py index 4ac74d61629..f5292bae859 100644 --- a/lib/ansible/plugins/callback/oneline.py +++ b/lib/ansible/plugins/callback/oneline.py @@ -13,8 +13,9 @@ DOCUMENTATION = """ - This is the output callback used by the C(-o)/C(--one-line) command line option. 
""" -from ansible.plugins.callback import CallbackBase from ansible import constants as C +from ansible.plugins.callback import CallbackBase +from ansible.template import Templar class CallbackModule(CallbackBase): @@ -28,6 +29,10 @@ class CallbackModule(CallbackBase): CALLBACK_TYPE = 'stdout' CALLBACK_NAME = 'oneline' + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._display.deprecated('The oneline callback plugin is deprecated.', version='2.23') + def _command_generic_msg(self, hostname, result, caption): stdout = result.get('stdout', '').replace('\n', '\\n').replace('\r', '\\r') if 'stderr' in result and result['stderr']: @@ -38,12 +43,13 @@ class CallbackModule(CallbackBase): def v2_runner_on_failed(self, result, ignore_errors=False): if 'exception' in result._result: + error_text = Templar().template(result._result['exception']) # transform to a string if self._display.verbosity < 3: # extract just the actual error message from the exception text - error = result._result['exception'].strip().split('\n')[-1] + error = error_text.strip().split('\n')[-1] msg = "An exception occurred during task execution. To see the full traceback, use -vvv. The error was: %s" % error else: - msg = "An exception occurred during task execution. The full traceback is:\n" + result._result['exception'].replace('\n', '') + msg = "An exception occurred during task execution. The full traceback is:\n" + error_text.replace('\n', '') if result._task.action in C.MODULE_NO_JSON and 'module_stderr' not in result._result: self._display.display(self._command_generic_msg(result._host.get_name(), result._result, 'FAILED'), color=C.COLOR_ERROR) diff --git a/lib/ansible/plugins/callback/tree.py b/lib/ansible/plugins/callback/tree.py index 9618f8ec8c7..c67d6cbb817 100644 --- a/lib/ansible/plugins/callback/tree.py +++ b/lib/ansible/plugins/callback/tree.py @@ -45,6 +45,10 @@ class CallbackModule(CallbackBase): CALLBACK_NAME = 'tree' CALLBACK_NEEDS_ENABLED = True + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._display.deprecated('The tree callback plugin is deprecated.', version='2.23') + def set_options(self, task_keys=None, var_options=None, direct=None): """ override to set self.tree """ diff --git a/lib/ansible/plugins/connection/__init__.py b/lib/ansible/plugins/connection/__init__.py index 61596a48e41..553235884fd 100644 --- a/lib/ansible/plugins/connection/__init__.py +++ b/lib/ansible/plugins/connection/__init__.py @@ -15,6 +15,7 @@ from abc import abstractmethod from functools import wraps from ansible import constants as C +from ansible.errors import AnsibleValueOmittedError from ansible.module_utils.common.text.converters import to_bytes, to_text from ansible.playbook.play_context import PlayContext from ansible.plugins import AnsiblePlugin @@ -286,13 +287,19 @@ class ConnectionBase(AnsiblePlugin): } for var_name in C.config.get_plugin_vars('connection', self._load_name): if var_name in variables: - var_options[var_name] = templar.template(variables[var_name]) + try: + var_options[var_name] = templar.template(variables[var_name]) + except AnsibleValueOmittedError: + pass # add extras if plugin supports them if getattr(self, 'allow_extras', False): for var_name in variables: if var_name.startswith(f'ansible_{self.extras_prefix}_') and var_name not in var_options: - var_options['_extras'][var_name] = templar.template(variables[var_name]) + try: + var_options['_extras'][var_name] = templar.template(variables[var_name]) + except 
AnsibleValueOmittedError: + pass return var_options diff --git a/lib/ansible/plugins/connection/paramiko_ssh.py b/lib/ansible/plugins/connection/paramiko_ssh.py index 971202e2c0b..04117d7de7d 100644 --- a/lib/ansible/plugins/connection/paramiko_ssh.py +++ b/lib/ansible/plugins/connection/paramiko_ssh.py @@ -248,7 +248,7 @@ from ansible.errors import ( AnsibleError, AnsibleFileNotFound, ) -from ansible.module_utils.compat.paramiko import PARAMIKO_IMPORT_ERR, paramiko +from ansible.module_utils.compat.paramiko import _PARAMIKO_IMPORT_ERR as PARAMIKO_IMPORT_ERR, _paramiko as paramiko from ansible.plugins.connection import ConnectionBase from ansible.utils.display import Display from ansible.utils.path import makedirs_safe @@ -327,8 +327,8 @@ class Connection(ConnectionBase): _log_channel: str | None = None def __init__(self, *args, **kwargs): + display.deprecated('The paramiko connection plugin is deprecated.', version='2.21') super().__init__(*args, **kwargs) - display.deprecated('The paramiko connection plugin is deprecated', version='2.21') def _cache_key(self) -> str: return "%s__%s__" % (self.get_option('remote_addr'), self.get_option('remote_user')) @@ -448,19 +448,18 @@ class Connection(ConnectionBase): ) except paramiko.ssh_exception.BadHostKeyException as e: raise AnsibleConnectionFailure('host key mismatch for %s' % e.hostname) - except paramiko.ssh_exception.AuthenticationException as e: - msg = 'Failed to authenticate: {0}'.format(to_text(e)) - raise AnsibleAuthenticationFailure(msg) - except Exception as e: - msg = to_text(e) + except paramiko.ssh_exception.AuthenticationException as ex: + raise AnsibleAuthenticationFailure() from ex + except Exception as ex: + msg = str(ex) if u"PID check failed" in msg: - raise AnsibleError("paramiko version issue, please upgrade paramiko on the machine running ansible") + raise AnsibleError("paramiko version issue, please upgrade paramiko on the machine running ansible") from ex elif u"Private key file is encrypted" in msg: msg = 'ssh %s@%s:%s : %s\nTo connect as a different user, use -u <username>.' % ( self.get_option('remote_user'), self.get_options('remote_addr'), port, msg) - raise AnsibleConnectionFailure(msg) + raise AnsibleConnectionFailure(msg) from ex else: - raise AnsibleConnectionFailure(msg) + raise AnsibleConnectionFailure(msg) from ex return ssh diff --git a/lib/ansible/plugins/connection/ssh.py b/lib/ansible/plugins/connection/ssh.py index 172cd5e6721..3e854a612b5 100644 --- a/lib/ansible/plugins/connection/ssh.py +++ b/lib/ansible/plugins/connection/ssh.py @@ -969,16 +969,13 @@ class Connection(ConnectionBase): try: fh.write(to_bytes(in_data)) fh.close() - except (OSError, IOError) as e: + except (OSError, IOError) as ex: # The ssh connection may have already terminated at this point, with a more useful error # Only raise AnsibleConnectionFailure if the ssh process is still alive time.sleep(0.001) ssh_process.poll() if getattr(ssh_process, 'returncode', None) is None: - raise AnsibleConnectionFailure( - 'Data could not be sent to remote host "%s". Make sure this host can be reached ' - 'over ssh: %s' % (self.host, to_native(e)), orig_exc=e - ) + raise AnsibleConnectionFailure(f'Data could not be sent to remote host {self.host!r}. 
Make sure this host can be reached over SSH.') from ex display.debug(u'Sent initial data (%d bytes)' % len(in_data)) diff --git a/lib/ansible/plugins/connection/winrm.py b/lib/ansible/plugins/connection/winrm.py index 1754a0b2dd9..ffa9b6279eb 100644 --- a/lib/ansible/plugins/connection/winrm.py +++ b/lib/ansible/plugins/connection/winrm.py @@ -604,9 +604,7 @@ class Connection(ConnectionBase): self._winrm_write_stdin(command_id, stdin_iterator) except Exception as ex: - display.warning("ERROR DURING WINRM SEND INPUT - attempting to recover: %s %s" - % (type(ex).__name__, to_text(ex))) - display.debug(traceback.format_exc()) + display.error_as_warning("ERROR DURING WINRM SEND INPUT. Attempting to recover.", ex) stdin_push_failed = True # Even on a failure above we try at least once to get the output diff --git a/lib/ansible/plugins/filter/__init__.py b/lib/ansible/plugins/filter/__init__.py index 003711f8b58..c28f8056c9f 100644 --- a/lib/ansible/plugins/filter/__init__.py +++ b/lib/ansible/plugins/filter/__init__.py @@ -3,11 +3,15 @@ from __future__ import annotations -from ansible import constants as C +import typing as t + from ansible.plugins import AnsibleJinja2Plugin class AnsibleJinja2Filter(AnsibleJinja2Plugin): + @property + def plugin_type(self) -> str: + return "filter" - def _no_options(self, *args, **kwargs): + def _no_options(self, *args, **kwargs) -> t.NoReturn: raise NotImplementedError("Jinja2 filter plugins do not support option functions, they use direct arguments instead.") diff --git a/lib/ansible/plugins/filter/bool.yml b/lib/ansible/plugins/filter/bool.yml index beb8b8ddb1f..dcf21077af5 100644 --- a/lib/ansible/plugins/filter/bool.yml +++ b/lib/ansible/plugins/filter/bool.yml @@ -1,13 +1,20 @@ DOCUMENTATION: name: bool version_added: "historical" - short_description: cast into a boolean + short_description: coerce some well-known truthy/falsy values to a boolean description: - - Attempt to cast the input into a boolean (V(True) or V(False)) value. + - Attempt to convert the input value into a boolean (V(True) or V(False)) from a common set of well-known values. + - Valid true values are (V(True), 'yes', 'on', '1', 'true', 1). + - Valid false values are (V(False), 'no', 'off', '0', 'false', 0). + #- An error will result if an invalid value is supplied. + - A deprecation warning will result if an invalid value is supplied. + - For more permissive boolean conversion, consider the P(ansible.builtin.truthy#test) or P(ansible.builtin.falsy#test) tests. + - String comparisons are case-insensitive. + positional: _input options: _input: - description: Data to cast. + description: Data to convert. type: raw required: true @@ -24,5 +31,5 @@ EXAMPLES: | RETURN: _value: - description: The boolean resulting of casting the input expression into a V(True) or V(False) value. + description: The boolean result of coercing the input expression to a V(True) or V(False) value. 
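A behavioral sketch of the stricter coercion documented above, calling the filter function directly (import path taken from the core.py hunk that follows; running it requires this branch of ansible-core):

from ansible.plugins.filter.core import to_bool

assert to_bool('YES') is True   # string comparison is case-insensitive
assert to_bool(0) is False      # ints/bools are stringified before the set lookup
to_bool('enabled')              # in neither set: deprecation warning, legacy fallback yields False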
type: bool diff --git a/lib/ansible/plugins/filter/core.py b/lib/ansible/plugins/filter/core.py index 58c24e4a992..b5b7a145c2c 100644 --- a/lib/ansible/plugins/filter/core.py +++ b/lib/ansible/plugins/filter/core.py @@ -4,6 +4,7 @@ from __future__ import annotations import base64 +import functools import glob import hashlib import json @@ -11,26 +12,30 @@ import ntpath import os.path import re import shlex -import sys import time import uuid import yaml import datetime +import typing as t from collections.abc import Mapping from functools import partial from random import Random, SystemRandom, shuffle -from jinja2.filters import pass_environment +from jinja2.filters import do_map, do_select, do_selectattr, do_reject, do_rejectattr, pass_environment, sync_do_groupby +from jinja2.environment import Environment -from ansible.errors import AnsibleError, AnsibleFilterError, AnsibleFilterTypeError -from ansible.module_utils.six import string_types, integer_types, reraise, text_type +from ansible._internal._templating import _lazy_containers +from ansible.errors import AnsibleFilterError, AnsibleTypeError +from ansible.module_utils.datatag import native_type_name +from ansible.module_utils.common.json import get_encoder, get_decoder +from ansible.module_utils.six import string_types, integer_types, text_type from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text from ansible.module_utils.common.collections import is_sequence from ansible.module_utils.common.yaml import yaml_load, yaml_load_all -from ansible.parsing.ajson import AnsibleJSONEncoder from ansible.parsing.yaml.dumper import AnsibleDumper -from ansible.template import recursive_check_defined +from ansible.plugins import accept_args_markers, accept_lazy_markers +from ansible._internal._templating._jinja_common import MarkerError, UndefinedMarker, validate_arg_type from ansible.utils.display import Display from ansible.utils.encrypt import do_encrypt, PASSLIB_AVAILABLE from ansible.utils.hashing import md5s, checksum_s @@ -42,53 +47,77 @@ display = Display() UUID_NAMESPACE_ANSIBLE = uuid.UUID('361E6D51-FAEC-444A-9079-341386DA8E2E') -def to_yaml(a, *args, **kw): - """Make verbose, human-readable yaml""" - default_flow_style = kw.pop('default_flow_style', None) - try: - transformed = yaml.dump(a, Dumper=AnsibleDumper, allow_unicode=True, default_flow_style=default_flow_style, **kw) - except Exception as e: - raise AnsibleFilterError("to_yaml - %s" % to_native(e), orig_exc=e) - return to_text(transformed) +def to_yaml(a, *_args, default_flow_style: bool | None = None, dump_vault_tags: bool | None = None, **kwargs) -> str: + """Serialize input as terse flow-style YAML.""" + dumper = partial(AnsibleDumper, dump_vault_tags=dump_vault_tags) + return yaml.dump(a, Dumper=dumper, allow_unicode=True, default_flow_style=default_flow_style, **kwargs) -def to_nice_yaml(a, indent=4, *args, **kw): - """Make verbose, human-readable yaml""" - try: - transformed = yaml.dump(a, Dumper=AnsibleDumper, indent=indent, allow_unicode=True, default_flow_style=False, **kw) - except Exception as e: - raise AnsibleFilterError("to_nice_yaml - %s" % to_native(e), orig_exc=e) - return to_text(transformed) + +def to_nice_yaml(a, indent=4, *_args, default_flow_style=False, **kwargs) -> str: + """Serialize input as verbose multi-line YAML.""" + return to_yaml(a, indent=indent, default_flow_style=default_flow_style, **kwargs) + + +def from_json(a, profile: str | None = None, **kwargs) -> t.Any: + """Deserialize JSON with an optional decoder 
profile.""" + cls = get_decoder(profile or "tagless") + + return json.loads(a, cls=cls, **kwargs) -def to_json(a, *args, **kw): - """ Convert the value to JSON """ +def to_json(a, profile: str | None = None, vault_to_text: t.Any = ..., preprocess_unsafe: t.Any = ..., **kwargs) -> str: + """Serialize as JSON with an optional encoder profile.""" - # defaults for filters - if 'vault_to_text' not in kw: - kw['vault_to_text'] = True - if 'preprocess_unsafe' not in kw: - kw['preprocess_unsafe'] = False + if profile and vault_to_text is not ...: + raise ValueError("Only one of `vault_to_text` or `profile` can be specified.") - return json.dumps(a, cls=AnsibleJSONEncoder, *args, **kw) + if profile and preprocess_unsafe is not ...: + raise ValueError("Only one of `preprocess_unsafe` or `profile` can be specified.") + # deprecated: description='deprecate vault_to_text' core_version='2.23' + # deprecated: description='deprecate preprocess_unsafe' core_version='2.23' -def to_nice_json(a, indent=4, sort_keys=True, *args, **kw): - """Make verbose, human-readable JSON""" + cls = get_encoder(profile or "tagless") + + return json.dumps(a, cls=cls, **kwargs) + + +def to_nice_json(a, indent=4, sort_keys=True, **kwargs): + """Make verbose, human-readable JSON.""" # TODO separators can be potentially exposed to the user as well - kw.pop('separators', None) - return to_json(a, indent=indent, sort_keys=sort_keys, separators=(',', ': '), *args, **kw) + kwargs.pop('separators', None) + return to_json(a, indent=indent, sort_keys=sort_keys, separators=(',', ': '), **kwargs) + +# CAUTION: Do not put non-string values here since they can have unwanted logical equality, such as 1.0 (equal to 1 and True) or 0.0 (equal to 0 and False). +_valid_bool_true = {'yes', 'on', 'true', '1'} +_valid_bool_false = {'no', 'off', 'false', '0'} + + +def to_bool(value: object) -> bool: + """Convert well-known input values to a boolean value.""" + value_to_check: object + if isinstance(value, str): + value_to_check = value.lower() # accept mixed case variants + elif isinstance(value, int): # bool is also an int + value_to_check = str(value).lower() # accept int (0, 1) and bool (True, False) -- not just string versions + else: + value_to_check = value -def to_bool(a): - """ return a bool for the arg """ - if a is None or isinstance(a, bool): - return a - if isinstance(a, string_types): - a = a.lower() - if a in ('yes', 'on', '1', 'true', 1): + if value_to_check in _valid_bool_true: return True - return False + + if value_to_check in _valid_bool_false: + return False + + # if we're still here, the value is unsupported- always fire a deprecation warning + result = value_to_check == 1 # backwards compatibility with the old code which checked: value in ('yes', 'on', '1', 'true', 1) + + # NB: update the doc string to reflect reality once this fallback is removed + display.deprecated(f'The `bool` filter coerced invalid value {value!r} ({native_type_name(value)}) to {result!r}.', version='2.23') + + return result def to_datetime(string, format="%Y-%m-%d %H:%M:%S"): @@ -289,12 +318,7 @@ def get_encrypted_password(password, hashtype='sha512', salt=None, salt_size=Non if PASSLIB_AVAILABLE and hashtype not in passlib_mapping and hashtype not in passlib_mapping.values(): raise AnsibleFilterError(f"{hashtype} is not in the list of supported passlib algorithms: {', '.join(passlib_mapping)}") - try: - return do_encrypt(password, hashtype, salt=salt, salt_size=salt_size, rounds=rounds, ident=ident) - except AnsibleError as e: - 
reraise(AnsibleFilterError, AnsibleFilterError(to_native(e), orig_exc=e), sys.exc_info()[2]) - except Exception as e: - raise AnsibleFilterError(f"Failed to encrypt the password due to: {e}") + return do_encrypt(password, hashtype, salt=salt, salt_size=salt_size, rounds=rounds, ident=ident) def to_uuid(string, namespace=UUID_NAMESPACE_ANSIBLE): @@ -308,19 +332,21 @@ def to_uuid(string, namespace=UUID_NAMESPACE_ANSIBLE): return to_text(uuid.uuid5(uuid_namespace, to_native(string, errors='surrogate_or_strict'))) -def mandatory(a, msg=None): +@accept_args_markers +def mandatory(a: object, msg: str | None = None) -> object: """Make a variable mandatory.""" - from jinja2.runtime import Undefined + # DTFIX-RELEASE: deprecate this filter; there are much better ways via undef, etc... + # also remember to remove unit test checking for _undefined_name + if isinstance(a, UndefinedMarker): + if msg is not None: + raise AnsibleFilterError(to_text(msg)) - if isinstance(a, Undefined): if a._undefined_name is not None: - name = "'%s' " % to_text(a._undefined_name) + name = f'{to_text(a._undefined_name)!r} ' else: name = '' - if msg is not None: - raise AnsibleFilterError(to_native(msg)) - raise AnsibleFilterError("Mandatory variable %s not defined." % name) + raise AnsibleFilterError(f"Mandatory variable {name}not defined.") return a @@ -334,9 +360,6 @@ def combine(*terms, **kwargs): # allow the user to do `[dict1, dict2, ...] | combine` dictionaries = flatten(terms, levels=1) - # recursively check that every elements are defined (for jinja2) - recursive_check_defined(dictionaries) - if not dictionaries: return {} @@ -442,7 +465,7 @@ def comment(text, style='plain', **kw): @pass_environment -def extract(environment, item, container, morekeys=None): +def extract(environment: Environment, item, container, morekeys=None): if morekeys is None: keys = [item] elif isinstance(morekeys, list): @@ -451,8 +474,12 @@ def extract(environment, item, container, morekeys=None): keys = [item, morekeys] value = container + for key in keys: - value = environment.getitem(value, key) + try: + value = environment.getitem(value, key) + except MarkerError as ex: + value = ex.source return value @@ -513,7 +540,7 @@ def subelements(obj, subelements, skip_missing=False): elif isinstance(subelements, string_types): subelement_list = subelements.split('.') else: - raise AnsibleFilterTypeError('subelements must be a list or a string') + raise AnsibleTypeError('subelements must be a list or a string') results = [] @@ -527,10 +554,10 @@ def subelements(obj, subelements, skip_missing=False): values = [] break raise AnsibleFilterError("could not find %r key in iterated item %r" % (subelement, values)) - except TypeError: - raise AnsibleFilterTypeError("the key %s should point to a dictionary, got '%s'" % (subelement, values)) + except TypeError as ex: + raise AnsibleTypeError("the key %s should point to a dictionary, got '%s'" % (subelement, values)) from ex if not isinstance(values, list): - raise AnsibleFilterTypeError("the key %r should point to a list, got %r" % (subelement, values)) + raise AnsibleTypeError("the key %r should point to a list, got %r" % (subelement, values)) for value in values: results.append((element, value)) @@ -543,7 +570,7 @@ def dict_to_list_of_dict_key_value_elements(mydict, key_name='key', value_name=' with each having a 'key' and 'value' keys that correspond to the keys and values of the original """ if not isinstance(mydict, Mapping): - raise AnsibleFilterTypeError("dict2items requires a dictionary, got %s 
instead." % type(mydict)) + raise AnsibleTypeError("dict2items requires a dictionary, got %s instead." % type(mydict)) ret = [] for key in mydict: @@ -556,17 +583,17 @@ def list_of_dict_key_value_elements_to_dict(mylist, key_name='key', value_name=' effectively as the reverse of dict2items """ if not is_sequence(mylist): - raise AnsibleFilterTypeError("items2dict requires a list, got %s instead." % type(mylist)) + raise AnsibleTypeError("items2dict requires a list, got %s instead." % type(mylist)) try: return dict((item[key_name], item[value_name]) for item in mylist) except KeyError: - raise AnsibleFilterTypeError( + raise AnsibleTypeError( "items2dict requires each dictionary in the list to contain the keys '%s' and '%s', got %s instead." % (key_name, value_name, mylist) ) except TypeError: - raise AnsibleFilterTypeError("items2dict requires a list of dictionaries, got %s instead." % mylist) + raise AnsibleTypeError("items2dict requires a list of dictionaries, got %s instead." % mylist) def path_join(paths): @@ -576,7 +603,7 @@ def path_join(paths): return os.path.join(paths) if is_sequence(paths): return os.path.join(*paths) - raise AnsibleFilterTypeError("|path_join expects string or sequence, got %s instead." % type(paths)) + raise AnsibleTypeError("|path_join expects string or sequence, got %s instead." % type(paths)) def commonpath(paths): @@ -589,11 +616,90 @@ def commonpath(paths): :rtype: str """ if not is_sequence(paths): - raise AnsibleFilterTypeError("|commonpath expects sequence, got %s instead." % type(paths)) + raise AnsibleTypeError("|commonpath expects sequence, got %s instead." % type(paths)) return os.path.commonpath(paths) +class GroupTuple(t.NamedTuple): + """ + Custom named tuple for the groupby filter with a public interface; silently ignored by unknown type checks. + This matches the internal implementation of the _GroupTuple returned by Jinja's built-in groupby filter. 
+ """ + + grouper: t.Any + list: list[t.Any] + + def __repr__(self) -> str: + return tuple.__repr__(self) + + +_lazy_containers.register_known_types(GroupTuple) + + +@pass_environment +def _cleansed_groupby(*args, **kwargs): + res = sync_do_groupby(*args, **kwargs) + res = [GroupTuple(grouper=g.grouper, list=g.list) for g in res] + + return res + +# DTFIX-RELEASE: make these dumb wrappers more dynamic + + +@accept_args_markers +def ansible_default( + value: t.Any, + default_value: t.Any = '', + boolean: bool = False, +) -> t.Any: + """Updated `default` filter that only coalesces classic undefined objects; other Undefined-derived types (eg, ErrorMarker) pass through.""" + validate_arg_type('boolean', boolean, bool) + + if isinstance(value, UndefinedMarker): + return default_value + + if boolean and not value: + return default_value + + return value + + +@accept_lazy_markers +@functools.wraps(do_map) +def wrapped_map(*args, **kwargs) -> t.Any: + return do_map(*args, **kwargs) + + +@accept_lazy_markers +@functools.wraps(do_select) +def wrapped_select(*args, **kwargs) -> t.Any: + return do_select(*args, **kwargs) + + +@accept_lazy_markers +@functools.wraps(do_selectattr) +def wrapped_selectattr(*args, **kwargs) -> t.Any: + return do_selectattr(*args, **kwargs) + + +@accept_lazy_markers +@functools.wraps(do_reject) +def wrapped_reject(*args, **kwargs) -> t.Any: + return do_reject(*args, **kwargs) + + +@accept_lazy_markers +@functools.wraps(do_rejectattr) +def wrapped_rejectattr(*args, **kwargs) -> t.Any: + return do_rejectattr(*args, **kwargs) + + +@accept_args_markers +def type_debug(obj: object) -> str: + return native_type_name(obj) + + class FilterModule(object): """ Ansible core jinja2 filters """ @@ -609,7 +715,7 @@ class FilterModule(object): # json 'to_json': to_json, 'to_nice_json': to_nice_json, - 'from_json': json.loads, + 'from_json': from_json, # yaml 'to_yaml': to_yaml, @@ -676,7 +782,7 @@ class FilterModule(object): 'comment': comment, # debug - 'type_debug': lambda o: o.__class__.__name__, + 'type_debug': type_debug, # Data structures 'combine': combine, @@ -686,4 +792,18 @@ class FilterModule(object): 'items2dict': list_of_dict_key_value_elements_to_dict, 'subelements': subelements, 'split': partial(unicode_wrap, text_type.split), + # FDI038 - replace this with a standard type compat shim + 'groupby': _cleansed_groupby, + + # Jinja builtins that need special arg handling + # DTFIX-RELEASE: document these now that they're overridden, or hide them so they don't show up as undocumented + 'd': ansible_default, # replaces the implementation instead of wrapping it + 'default': ansible_default, # replaces the implementation instead of wrapping it + 'map': wrapped_map, + 'select': wrapped_select, + 'selectattr': wrapped_selectattr, + 'reject': wrapped_reject, + 'rejectattr': wrapped_rejectattr, } + +# DTFIX-RELEASE: document protomatter plugins, or hide them from ansible-doc/galaxy (not related to this code, but needed some place to put this comment) diff --git a/lib/ansible/plugins/filter/encryption.py b/lib/ansible/plugins/filter/encryption.py index 580e07bea20..42fcac3e0c2 100644 --- a/lib/ansible/plugins/filter/encryption.py +++ b/lib/ansible/plugins/filter/encryption.py @@ -2,80 +2,80 @@ from __future__ import annotations -from jinja2.runtime import Undefined -from jinja2.exceptions import UndefinedError - -from ansible.errors import AnsibleFilterError, AnsibleFilterTypeError +from ansible.errors import AnsibleError from ansible.module_utils.common.text.converters import 
to_native, to_bytes -from ansible.module_utils.six import string_types, binary_type -from ansible.parsing.yaml.objects import AnsibleVaultEncryptedUnicode -from ansible.parsing.vault import is_encrypted, VaultSecret, VaultLib +from ansible.plugins import accept_args_markers +from ansible._internal._templating._jinja_common import get_first_marker_arg, VaultExceptionMarker +from ansible._internal._datatag._tags import VaultedValue +from ansible.parsing.vault import is_encrypted, VaultSecret, VaultLib, VaultHelper from ansible.utils.display import Display display = Display() def do_vault(data, secret, salt=None, vault_id='filter_default', wrap_object=False, vaultid=None): + if not isinstance(secret, (str, bytes)): + raise TypeError(f"Secret passed is required to be a string, instead we got {type(secret)}.") - if not isinstance(secret, (string_types, binary_type, Undefined)): - raise AnsibleFilterTypeError("Secret passed is required to be a string, instead we got: %s" % type(secret)) - - if not isinstance(data, (string_types, binary_type, Undefined)): - raise AnsibleFilterTypeError("Can only vault strings, instead we got: %s" % type(data)) + if not isinstance(data, (str, bytes)): + raise TypeError(f"Can only vault strings, instead we got {type(data)}.") if vaultid is not None: display.deprecated("Use of undocumented 'vaultid', use 'vault_id' instead", version='2.20') + if vault_id == 'filter_default': vault_id = vaultid else: display.warning("Ignoring vaultid as vault_id is already set.") - vault = '' vs = VaultSecret(to_bytes(secret)) vl = VaultLib() try: vault = vl.encrypt(to_bytes(data), vs, vault_id, salt) - except UndefinedError: - raise - except Exception as e: - raise AnsibleFilterError("Unable to encrypt: %s" % to_native(e), orig_exc=e) + except Exception as ex: + raise AnsibleError("Unable to encrypt.") from ex if wrap_object: - vault = AnsibleVaultEncryptedUnicode(vault) + vault = VaultedValue(ciphertext=str(vault)).tag(data) else: vault = to_native(vault) return vault +@accept_args_markers def do_unvault(vault, secret, vault_id='filter_default', vaultid=None): + if isinstance(vault, VaultExceptionMarker): + vault = vault._disarm() + + if (first_marker := get_first_marker_arg((vault, secret, vault_id, vaultid), {})) is not None: + return first_marker - if not isinstance(secret, (string_types, binary_type, Undefined)): - raise AnsibleFilterTypeError("Secret passed is required to be as string, instead we got: %s" % type(secret)) + if not isinstance(secret, (str, bytes)): + raise TypeError(f"Secret passed is required to be a string, instead we got {type(secret)}.") - if not isinstance(vault, (string_types, binary_type, AnsibleVaultEncryptedUnicode, Undefined)): - raise AnsibleFilterTypeError("Vault should be in the form of a string, instead we got: %s" % type(vault)) + if not isinstance(vault, (str, bytes)): + raise TypeError(f"Vault should be in the form of a string, instead we got {type(vault)}.") if vaultid is not None: display.deprecated("Use of undocumented 'vaultid', use 'vault_id' instead", version='2.20') + if vault_id == 'filter_default': vault_id = vaultid else: display.warning("Ignoring vaultid as vault_id is already set.") - data = '' vs = VaultSecret(to_bytes(secret)) vl = VaultLib([(vault_id, vs)]) - if isinstance(vault, AnsibleVaultEncryptedUnicode): - vault.vault = vl - data = vault.data - elif is_encrypted(vault): + + if ciphertext := VaultHelper.get_ciphertext(vault, with_tags=True): + vault = ciphertext + + if is_encrypted(vault): try: data = vl.decrypt(vault) - 
except UndefinedError: - raise - except Exception as e: - raise AnsibleFilterError("Unable to decrypt: %s" % to_native(e), orig_exc=e) + except Exception as ex: + raise AnsibleError("Unable to decrypt.") from ex else: data = vault diff --git a/lib/ansible/plugins/filter/mathstuff.py b/lib/ansible/plugins/filter/mathstuff.py index d80eb3347c1..a9247a2c984 100644 --- a/lib/ansible/plugins/filter/mathstuff.py +++ b/lib/ansible/plugins/filter/mathstuff.py @@ -27,10 +27,9 @@ from collections.abc import Mapping, Iterable from jinja2.filters import pass_environment -from ansible.errors import AnsibleFilterError, AnsibleFilterTypeError +from ansible.errors import AnsibleError from ansible.module_utils.common.text import formatters from ansible.module_utils.six import binary_type, text_type -from ansible.module_utils.common.text.converters import to_native, to_text from ansible.utils.display import Display try: @@ -48,10 +47,11 @@ display = Display() # explicitly set and cannot be handle (by Jinja2 w/o 'unique' or fallback version) def unique(environment, a, case_sensitive=None, attribute=None): - def _do_fail(e): + def _do_fail(ex): if case_sensitive is False or attribute: - raise AnsibleFilterError("Jinja2's unique filter failed and we cannot fall back to Ansible's version " - "as it does not support the parameters supplied", orig_exc=e) + raise AnsibleError( + "Jinja2's unique filter failed and we cannot fall back to Ansible's version as it does not support the parameters supplied." + ) from ex error = e = None try: @@ -63,14 +63,14 @@ def unique(environment, a, case_sensitive=None, attribute=None): except Exception as e: error = e _do_fail(e) - display.warning('Falling back to Ansible unique filter as Jinja2 one failed: %s' % to_text(e)) + display.error_as_warning('Falling back to Ansible unique filter as Jinja2 one failed.', e) if not HAS_UNIQUE or error: # handle Jinja2 specific attributes when using Ansible's version if case_sensitive is False or attribute: - raise AnsibleFilterError("Ansible's unique filter does not support case_sensitive=False nor attribute parameters, " - "you need a newer version of Jinja2 that provides their version of the filter.") + raise AnsibleError("Ansible's unique filter does not support case_sensitive=False nor attribute parameters, " + "you need a newer version of Jinja2 that provides their version of the filter.") c = [] for x in a: @@ -123,15 +123,15 @@ def logarithm(x, base=math.e): return math.log10(x) else: return math.log(x, base) - except TypeError as e: - raise AnsibleFilterTypeError('log() can only be used on numbers: %s' % to_native(e)) + except TypeError as ex: + raise AnsibleError('log() can only be used on numbers') from ex def power(x, y): try: return math.pow(x, y) - except TypeError as e: - raise AnsibleFilterTypeError('pow() can only be used on numbers: %s' % to_native(e)) + except TypeError as ex: + raise AnsibleError('pow() can only be used on numbers') from ex def inversepower(x, base=2): @@ -140,28 +140,28 @@ def inversepower(x, base=2): return math.sqrt(x) else: return math.pow(x, 1.0 / float(base)) - except (ValueError, TypeError) as e: - raise AnsibleFilterTypeError('root() can only be used on numbers: %s' % to_native(e)) + except (ValueError, TypeError) as ex: + raise AnsibleError('root() can only be used on numbers') from ex def human_readable(size, isbits=False, unit=None): """ Return a human-readable string """ try: return formatters.bytes_to_human(size, isbits, unit) - except TypeError as e: - raise 
AnsibleFilterTypeError("human_readable() failed on bad input: %s" % to_native(e)) - except Exception: - raise AnsibleFilterError("human_readable() can't interpret following string: %s" % size) + except TypeError as ex: + raise AnsibleError("human_readable() failed on bad input") from ex + except Exception as ex: + raise AnsibleError("human_readable() can't interpret the input") from ex def human_to_bytes(size, default_unit=None, isbits=False): """ Return bytes count from a human-readable string """ try: return formatters.human_to_bytes(size, default_unit, isbits) - except TypeError as e: - raise AnsibleFilterTypeError("human_to_bytes() failed on bad input: %s" % to_native(e)) - except Exception: - raise AnsibleFilterError("human_to_bytes() can't interpret following string: %s" % size) + except TypeError as ex: + raise AnsibleError("human_to_bytes() failed on bad input") from ex + except Exception as ex: + raise AnsibleError("human_to_bytes() can't interpret the input") from ex def rekey_on_member(data, key, duplicates='error'): @@ -174,38 +174,31 @@ def rekey_on_member(data, key, duplicates='error'): value would be duplicated or to overwrite previous entries if that's the case. """ if duplicates not in ('error', 'overwrite'): - raise AnsibleFilterError("duplicates parameter to rekey_on_member has unknown value: {0}".format(duplicates)) + raise AnsibleError(f"duplicates parameter to rekey_on_member has unknown value {duplicates!r}") new_obj = {} - # Ensure the positional args are defined - raise jinja2.exceptions.UndefinedError if not - bool(data) and bool(key) - if isinstance(data, Mapping): iterate_over = data.values() elif isinstance(data, Iterable) and not isinstance(data, (text_type, binary_type)): iterate_over = data else: - raise AnsibleFilterTypeError("Type is not a valid list, set, or dict") + raise AnsibleError("Type is not a valid list, set, or dict") for item in iterate_over: if not isinstance(item, Mapping): - raise AnsibleFilterTypeError("List item is not a valid dict") + raise AnsibleError("List item is not a valid dict") try: key_elem = item[key] except KeyError: - raise AnsibleFilterError("Key {0} was not found".format(key)) - except TypeError as e: - raise AnsibleFilterTypeError(to_native(e)) - except Exception as e: - raise AnsibleFilterError(to_native(e)) + raise AnsibleError(f"Key {key!r} was not found.", obj=item) from None # Note: if new_obj[key_elem] exists it will always be a non-empty dict (it will at # minimum contain {key: key_elem} if new_obj.get(key_elem, None): if duplicates == 'error': - raise AnsibleFilterError("Key {0} is not unique, cannot correctly turn into dict".format(key_elem)) + raise AnsibleError(f"Key {key_elem!r} is not unique, cannot convert to dict.") elif duplicates == 'overwrite': new_obj[key_elem] = item else: diff --git a/lib/ansible/plugins/filter/regex_search.yml b/lib/ansible/plugins/filter/regex_search.yml index e0eda9ccc0d..16a06b8076f 100644 --- a/lib/ansible/plugins/filter/regex_search.yml +++ b/lib/ansible/plugins/filter/regex_search.yml @@ -8,9 +8,6 @@ DOCUMENTATION: - Maps to Python's C(re.search). - 'The substring matched by the group is accessible via the symbolic group name or the ``\{number}`` special sequence. See examples section.' - - The return for no match will be C(None) in most cases, depending on whether it is used with other filters/tests or not. - It also depends on the Jinja version used and whether native is enabled. 
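The caveat being removed here is obsolete: under the overhauled template engine the filter consistently yields None when nothing matches. A direct-call sketch (import location assumed from this branch):

from ansible.plugins.filter.core import regex_search

assert regex_search('server-042', r'\d+') == '042'
assert regex_search('server-abc', r'\d+') is None  # no match is always None, never ''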
- - "For a more complete explanation see U(https://docs.ansible.com/ansible-core/devel/reference_appendices/faq.html#why-does-the-regex-search-filter-return-none-instead-of-an-empty-string)." positional: _input, _regex options: _input: @@ -55,5 +52,5 @@ EXAMPLES: | RETURN: _value: - description: Matched string or if no match a C(None) or an empty string (see notes) + description: Matched string or C(None) if no match. type: str diff --git a/lib/ansible/plugins/filter/to_nice_yaml.yml b/lib/ansible/plugins/filter/to_nice_yaml.yml index faf4c837928..664d7ce58c0 100644 --- a/lib/ansible/plugins/filter/to_nice_yaml.yml +++ b/lib/ansible/plugins/filter/to_nice_yaml.yml @@ -20,10 +20,6 @@ DOCUMENTATION: description: Affects sorting of dictionary keys. default: True type: bool - #allow_unicode: - # description: - # type: bool - # default: true #default_style=None, canonical=None, width=None, line_break=None, encoding=None, explicit_start=None, explicit_end=None, version=None, tags=None notes: - More options may be available, see L(PyYAML documentation, https://pyyaml.org/wiki/PyYAMLDocumentation) for details. diff --git a/lib/ansible/plugins/filter/to_yaml.yml b/lib/ansible/plugins/filter/to_yaml.yml index 224cf129f31..ba71f7ae9c3 100644 --- a/lib/ansible/plugins/filter/to_yaml.yml +++ b/lib/ansible/plugins/filter/to_yaml.yml @@ -24,10 +24,6 @@ DOCUMENTATION: - More options may be available, see L(PyYAML documentation, https://pyyaml.org/wiki/PyYAMLDocumentation) for details. # TODO: find docs for these - #allow_unicode: - # description: - # type: bool - # default: true #default_flow_style #default_style #canonical=None, diff --git a/lib/ansible/plugins/filter/unvault.yml b/lib/ansible/plugins/filter/unvault.yml index 82747a6fce3..3512fb08692 100644 --- a/lib/ansible/plugins/filter/unvault.yml +++ b/lib/ansible/plugins/filter/unvault.yml @@ -8,7 +8,7 @@ DOCUMENTATION: positional: secret options: _input: - description: Vault string, or an C(AnsibleVaultEncryptedUnicode) string object. + description: Vault string. type: string required: true secret: diff --git a/lib/ansible/plugins/filter/urlsplit.py b/lib/ansible/plugins/filter/urlsplit.py index 3b1d35f6b59..8f777953a63 100644 --- a/lib/ansible/plugins/filter/urlsplit.py +++ b/lib/ansible/plugins/filter/urlsplit.py @@ -58,7 +58,6 @@ RETURN = r""" from urllib.parse import urlsplit -from ansible.errors import AnsibleFilterError from ansible.utils import helpers @@ -70,7 +69,7 @@ def split_url(value, query='', alias='urlsplit'): # If no option is supplied, return the entire dictionary. if query: if query not in results: - raise AnsibleFilterError(alias + ': unknown URL component: %s' % query) + raise ValueError(alias + ': unknown URL component: %s' % query) return results[query] else: return results diff --git a/lib/ansible/plugins/filter/vault.yml b/lib/ansible/plugins/filter/vault.yml index d5dbcf0f331..43e2801cf70 100644 --- a/lib/ansible/plugins/filter/vault.yml +++ b/lib/ansible/plugins/filter/vault.yml @@ -26,7 +26,7 @@ DOCUMENTATION: default: 'filter_default' wrap_object: description: - - This toggle can force the return of an C(AnsibleVaultEncryptedUnicode) string object, when V(False), you get a simple string. + - This toggle can force the return of a C(VaultedValue)-tagged string object, when V(False), you get a simple string. - Mostly useful when combining with the C(to_yaml) filter to output the 'inline vault' format. 
type: bool default: False @@ -49,5 +49,5 @@ EXAMPLES: | RETURN: _value: - description: The vault string that contains the secret data (or C(AnsibleVaultEncryptedUnicode) string object). + description: The vault string that contains the secret data (or C(VaultedValue)-tagged string object). type: string diff --git a/lib/ansible/plugins/inventory/__init__.py b/lib/ansible/plugins/inventory/__init__.py index 324234cb7ec..cdf1eb608be 100644 --- a/lib/ansible/plugins/inventory/__init__.py +++ b/lib/ansible/plugins/inventory/__init__.py @@ -17,24 +17,30 @@ from __future__ import annotations +import functools import hashlib import os import string +import typing as t from collections.abc import Mapping -from ansible.errors import AnsibleError, AnsibleParserError +from ansible import template as _template +from ansible.errors import AnsibleError, AnsibleParserError, AnsibleValueOmittedError from ansible.inventory.group import to_safe_group_name as original_safe +from ansible.module_utils._internal import _plugin_exec_context from ansible.parsing.utils.addresses import parse_address -from ansible.plugins import AnsiblePlugin -from ansible.plugins.cache import CachePluginAdjudicator as CacheObject +from ansible.parsing.dataloader import DataLoader +from ansible.plugins import AnsiblePlugin, _ConfigurablePlugin +from ansible.plugins.cache import CachePluginAdjudicator from ansible.module_utils.common.text.converters import to_bytes, to_native -from ansible.module_utils.parsing.convert_bool import boolean from ansible.module_utils.six import string_types -from ansible.template import Templar from ansible.utils.display import Display from ansible.utils.vars import combine_vars, load_extra_vars +if t.TYPE_CHECKING: + from ansible.inventory.data import InventoryData + display = Display() @@ -127,8 +133,11 @@ def expand_hostname_range(line=None): def get_cache_plugin(plugin_name, **kwargs): + if not plugin_name: + raise AnsibleError("A cache plugin must be configured to use inventory caching.") + try: - cache = CacheObject(plugin_name, **kwargs) + cache = CachePluginAdjudicator(plugin_name, **kwargs) except AnsibleError as e: if 'fact_caching_connection' in to_native(e): raise AnsibleError("error, '%s' inventory cache plugin requires the one of the following to be set " @@ -136,17 +145,22 @@ def get_cache_plugin(plugin_name, **kwargs): "[inventory]: cache_connection;\nEnvironment:\nANSIBLE_INVENTORY_CACHE_CONNECTION,\n" "ANSIBLE_CACHE_PLUGIN_CONNECTION." % plugin_name) else: - raise e + raise - if plugin_name != 'memory' and kwargs and not getattr(cache._plugin, '_options', None): + if cache._plugin.ansible_name != 'ansible.builtin.memory' and kwargs and not getattr(cache._plugin, '_options', None): raise AnsibleError('Unable to use cache plugin {0} for inventory. Cache options were provided but may not reconcile ' 'correctly unless set via set_options. Refer to the porting guide if the plugin derives user settings ' 'from ansible.constants.'.format(plugin_name)) return cache -class BaseInventoryPlugin(AnsiblePlugin): - """ Parses an Inventory Source""" +class _BaseInventoryPlugin(AnsiblePlugin): + """ + Internal base implementation for inventory plugins. + + Do not inherit from this directly, use one of its public subclasses instead. + Used to introduce an extra layer in the class hierarchy to allow Constructed to subclass this while remaining a mixin for existing inventory plugins. + """ TYPE = 'generator' @@ -156,16 +170,26 @@ class BaseInventoryPlugin(AnsiblePlugin): # it by default. 
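    # (Editorial note, not part of the upstream patch.) Based solely on the class
    # definitions visible in this diff, the split introduced here produces:
    #
    #   AnsiblePlugin
    #     └── _BaseInventoryPlugin            (shared state: loader, inventory, templar)
    #           ├── BaseInventoryPlugin       (public base for generic sources)
    #           ├── BaseFileInventoryPlugin   (public base for file-backed sources)
    #           └── Constructable             (base for constructed-vars features)
    #
    # Existing plugins that inherit the public bases keep working unchanged, while
    # Constructable gains direct access to the shared templar and loader attributes.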
_sanitize_group_name = staticmethod(to_safe_group_name) - def __init__(self): + def __init__(self) -> None: - super(BaseInventoryPlugin, self).__init__() + super().__init__() self._options = {} - self.inventory = None self.display = display - self._vars = {} - def parse(self, inventory, loader, path, cache=True): + # These attributes are set by the parse() method on this (base) class. + self.loader: DataLoader | None = None + self.inventory: InventoryData | None = None + self._vars: dict[str, t.Any] | None = None + + trusted_by_default: bool = False + """Inventory plugins that only source templates from trusted sources can set this True to have trust automatically applied to all templates.""" + + @functools.cached_property + def templar(self) -> _template.Templar: + return _template.Templar(loader=self.loader) + + def parse(self, inventory: InventoryData, loader: DataLoader, path: str, cache: bool = True) -> None: """ Populates inventory from the given data. Raises an error on any parse failure :arg inventory: a copy of the previously accumulated inventory data, to be updated with any new data this plugin provides. @@ -178,10 +202,8 @@ class BaseInventoryPlugin(AnsiblePlugin): :arg cache: a boolean that indicates if the plugin should use the cache or not you can ignore if this plugin does not implement caching. """ - self.loader = loader self.inventory = inventory - self.templar = Templar(loader=loader) self._vars = load_extra_vars(loader) def verify_file(self, path): @@ -214,11 +236,10 @@ class BaseInventoryPlugin(AnsiblePlugin): :arg path: path to common yaml format config file for this plugin """ - config = {} try: # avoid loader cache so meta: refresh_inventory can pick up config changes # if we read more than once, fs cache should be good enough - config = self.loader.load_from_file(path, cache='none') + config = self.loader.load_from_file(path, cache='none', trusted_as_template=True) except Exception as e: raise AnsibleParserError(to_native(e)) @@ -279,7 +300,11 @@ class BaseInventoryPlugin(AnsiblePlugin): return (hostnames, port) -class BaseFileInventoryPlugin(BaseInventoryPlugin): +class BaseInventoryPlugin(_BaseInventoryPlugin): + """ Parses an Inventory Source """ + + +class BaseFileInventoryPlugin(_BaseInventoryPlugin): """ Parses a File based Inventory Source""" TYPE = 'storage' @@ -289,51 +314,44 @@ class BaseFileInventoryPlugin(BaseInventoryPlugin): super(BaseFileInventoryPlugin, self).__init__() -class Cacheable(object): +class Cacheable(_plugin_exec_context.HasPluginInfo, _ConfigurablePlugin): + """Mixin for inventory plugins which support caching.""" - _cache = CacheObject() + _cache: CachePluginAdjudicator @property - def cache(self): + def cache(self) -> CachePluginAdjudicator: return self._cache - def load_cache_plugin(self): + def load_cache_plugin(self) -> None: plugin_name = self.get_option('cache_plugin') cache_option_keys = [('_uri', 'cache_connection'), ('_timeout', 'cache_timeout'), ('_prefix', 'cache_prefix')] cache_options = dict((opt[0], self.get_option(opt[1])) for opt in cache_option_keys if self.get_option(opt[1]) is not None) self._cache = get_cache_plugin(plugin_name, **cache_options) - def get_cache_key(self, path): - return "{0}_{1}".format(self.NAME, self._get_cache_prefix(path)) - - def _get_cache_prefix(self, path): - """ create predictable unique prefix for plugin/inventory """ - - m = hashlib.sha1() - m.update(to_bytes(self.NAME, errors='surrogate_or_strict')) - d1 = m.hexdigest() - - n = hashlib.sha1() - n.update(to_bytes(path, 
errors='surrogate_or_strict')) - d2 = n.hexdigest() + def get_cache_key(self, path: str) -> str: + return f'{self.ansible_name}_{self._get_cache_prefix(path)}' - return 's_'.join([d1[:5], d2[:5]]) + def _get_cache_prefix(self, path: str) -> str: + """Return a predictable unique key based on the given path.""" + # DTFIX-RELEASE: choose a better hashing approach + return 'k' + hashlib.sha256(f'{self.ansible_name}{path}'.encode(), usedforsecurity=False).hexdigest()[:6] - def clear_cache(self): - self._cache.flush() + def clear_cache(self) -> None: + self._cache.clear() - def update_cache_if_changed(self): + def update_cache_if_changed(self) -> None: self._cache.update_cache_if_changed() - def set_cache_plugin(self): + def set_cache_plugin(self) -> None: self._cache.set_cache() -class Constructable(object): - - def _compose(self, template, variables, disable_lookups=True): +class Constructable(_BaseInventoryPlugin): + def _compose(self, template, variables, disable_lookups=...): """ helper method for plugins to compose variables for Ansible based on jinja2 expression and inventory vars""" - t = self.templar + if disable_lookups is not ...: + self.display.deprecated("The disable_lookups arg has no effect.", version="2.23") try: use_extra = self.get_option('use_extra_vars') @@ -341,12 +359,11 @@ class Constructable(object): use_extra = False if use_extra: - t.available_variables = combine_vars(variables, self._vars) + self.templar.available_variables = combine_vars(variables, self._vars) else: - t.available_variables = variables + self.templar.available_variables = variables - return t.template('%s%s%s' % (t.environment.variable_start_string, template, t.environment.variable_end_string), - disable_lookups=disable_lookups) + return self.templar.evaluate_expression(template) def _set_composite_vars(self, compose, variables, host, strict=False): """ loops over compose entries to create vars for hosts """ @@ -368,10 +385,10 @@ class Constructable(object): variables = combine_vars(variables, self.inventory.get_host(host).get_vars()) self.templar.available_variables = variables for group_name in groups: - conditional = "{%% if %s %%} True {%% else %%} False {%% endif %%}" % groups[group_name] + conditional = groups[group_name] group_name = self._sanitize_group_name(group_name) try: - result = boolean(self.templar.template(conditional)) + result = self.templar.evaluate_conditional(conditional) except Exception as e: if strict: raise AnsibleParserError("Could not add host %s to group %s: %s" % (host, group_name, to_native(e))) @@ -405,13 +422,16 @@ class Constructable(object): prefix = keyed.get('prefix', '') sep = keyed.get('separator', '_') raw_parent_name = keyed.get('parent_group', None) - if raw_parent_name: - try: - raw_parent_name = self.templar.template(raw_parent_name) - except AnsibleError as e: - if strict: - raise AnsibleParserError("Could not generate parent group %s for group %s: %s" % (raw_parent_name, key, to_native(e))) - continue + + try: + raw_parent_name = self.templar.template(raw_parent_name) + except AnsibleValueOmittedError: + raw_parent_name = None + except Exception as ex: + if strict: + raise AnsibleParserError(f'Could not generate parent group {raw_parent_name!r} for group {key!r}: {ex}') from ex + + continue new_raw_group_names = [] if isinstance(key, string_types): diff --git a/lib/ansible/plugins/inventory/advanced_host_list.py b/lib/ansible/plugins/inventory/advanced_host_list.py index 7a9646ef9ac..7f03558d573 100644 --- 
a/lib/ansible/plugins/inventory/advanced_host_list.py +++ b/lib/ansible/plugins/inventory/advanced_host_list.py @@ -31,6 +31,8 @@ class InventoryModule(BaseInventoryPlugin): NAME = 'advanced_host_list' + # advanced_host_list does not set vars, so needs no special trust assistance from the inventory API + def verify_file(self, host_list): valid = False diff --git a/lib/ansible/plugins/inventory/auto.py b/lib/ansible/plugins/inventory/auto.py index 81f0352911a..9bfd10f7695 100644 --- a/lib/ansible/plugins/inventory/auto.py +++ b/lib/ansible/plugins/inventory/auto.py @@ -30,6 +30,8 @@ class InventoryModule(BaseInventoryPlugin): NAME = 'auto' + # no need to set trusted_by_default, since the consumers of this value will always consult the real plugin substituted during our parse() + def verify_file(self, path): if not path.endswith('.yml') and not path.endswith('.yaml'): return False @@ -55,6 +57,11 @@ class InventoryModule(BaseInventoryPlugin): raise AnsibleParserError("inventory source '{0}' could not be verified by inventory plugin '{1}'".format(path, plugin_name)) self.display.v("Using inventory plugin '{0}' to process inventory source '{1}'".format(plugin._load_name, path)) + + # unfortunate magic to swap the real plugin type we're proxying here into the inventory data API wrapper, so the wrapper can make the right compat + # decisions based on the metadata the real plugin provides instead of our metadata + inventory._target_plugin = plugin + plugin.parse(inventory, loader, path, cache=cache) try: plugin.update_cache_if_changed() diff --git a/lib/ansible/plugins/inventory/constructed.py b/lib/ansible/plugins/inventory/constructed.py index ee2b9b4295c..6954e3aeab5 100644 --- a/lib/ansible/plugins/inventory/constructed.py +++ b/lib/ansible/plugins/inventory/constructed.py @@ -82,12 +82,11 @@ EXAMPLES = r""" import os from ansible import constants as C -from ansible.errors import AnsibleParserError, AnsibleOptionsError +from ansible.errors import AnsibleParserError from ansible.inventory.helpers import get_group_vars from ansible.plugins.inventory import BaseInventoryPlugin, Constructable -from ansible.module_utils.common.text.converters import to_native +from ansible.plugins.loader import cache_loader from ansible.utils.vars import combine_vars -from ansible.vars.fact_cache import FactCache from ansible.vars.plugins import get_vars_from_inventory_sources @@ -96,11 +95,7 @@ class InventoryModule(BaseInventoryPlugin, Constructable): NAME = 'constructed' - def __init__(self): - - super(InventoryModule, self).__init__() - - self._cache = FactCache() + # implicit trust behavior is already added by the YAML parser invoked by the loader def verify_file(self, path): @@ -147,26 +142,28 @@ class InventoryModule(BaseInventoryPlugin, Constructable): sources = inventory.processed_sources except AttributeError: if self.get_option('use_vars_plugins'): - raise AnsibleOptionsError("The option use_vars_plugins requires ansible >= 2.11.") + raise strict = self.get_option('strict') - fact_cache = FactCache() + + cache = cache_loader.get(C.CACHE_PLUGIN) + try: # Go over hosts (less var copies) for host in inventory.hosts: # get available variables to templar hostvars = self.get_all_host_vars(inventory.hosts[host], loader, sources) - if host in fact_cache: # adds facts if cache is active - hostvars = combine_vars(hostvars, fact_cache[host]) + if cache.contains(host): # adds facts if cache is active + hostvars = combine_vars(hostvars, cache.get(host)) # create composite vars 
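                    # (Editorial example; hypothetical option values.) A 'compose' entry
                    # evaluated at this step might look like the following in a
                    # constructed inventory source:
                    #   compose:
                    #     db_priority: "10 if inventory_hostname.startswith('db') else 100"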
self._set_composite_vars(self.get_option('compose'), hostvars, host, strict=strict) # refetch host vars in case new ones have been created above hostvars = self.get_all_host_vars(inventory.hosts[host], loader, sources) - if host in self._cache: # adds facts if cache is active - hostvars = combine_vars(hostvars, self._cache[host]) + if cache.contains(host): # adds facts if cache is active + hostvars = combine_vars(hostvars, cache.get(host)) # constructed groups based on conditionals self._add_host_to_composed_groups(self.get_option('groups'), hostvars, host, strict=strict, fetch_hostvars=False) @@ -174,5 +171,5 @@ class InventoryModule(BaseInventoryPlugin, Constructable): # constructed groups based variable values self._add_host_to_keyed_groups(self.get_option('keyed_groups'), hostvars, host, strict=strict, fetch_hostvars=False) - except Exception as e: - raise AnsibleParserError("failed to parse %s: %s " % (to_native(path), to_native(e)), orig_exc=e) + except Exception as ex: + raise AnsibleParserError(f"Failed to parse {path!r}.") from ex diff --git a/lib/ansible/plugins/inventory/generator.py b/lib/ansible/plugins/inventory/generator.py index 49c8550403f..ba2570db7d8 100644 --- a/lib/ansible/plugins/inventory/generator.py +++ b/lib/ansible/plugins/inventory/generator.py @@ -84,6 +84,8 @@ class InventoryModule(BaseInventoryPlugin): NAME = 'generator' + # implicit trust behavior is already added by the YAML parser invoked by the loader + def __init__(self): super(InventoryModule, self).__init__() @@ -100,15 +102,18 @@ class InventoryModule(BaseInventoryPlugin): return valid def template(self, pattern, variables): - self.templar.available_variables = variables - return self.templar.do_template(pattern) + # Allow pass-through of data structures for templating later (if applicable). + # This limitation was part of the original plugin implementation and was updated to maintain feature parity with the new templating API. 
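+        # (Editorial illustration, not upstream.) Non-string patterns now pass through unchanged:
+        #   self.template({'port': 22}, variables)           -> {'port': 22}
+        #   self.template('{{ app }}_group', {'app': 'web'}) -> 'web_group'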
+ if not isinstance(pattern, str): + return pattern + + return self.templar.copy_with_new_env(available_variables=variables).template(pattern) def add_parents(self, inventory, child, parents, template_vars): for parent in parents: - try: - groupname = self.template(parent['name'], template_vars) - except (AttributeError, ValueError): - raise AnsibleParserError("Element %s has a parent with no name element" % child['name']) + groupname = self.template(parent.get('name'), template_vars) + if not groupname: + raise AnsibleParserError(f"Element {child} has a parent with no name.") if groupname not in inventory.groups: inventory.add_group(groupname) group = inventory.groups[groupname] diff --git a/lib/ansible/plugins/inventory/host_list.py b/lib/ansible/plugins/inventory/host_list.py index 8cfe9e50aa8..9d4ae2f6fac 100644 --- a/lib/ansible/plugins/inventory/host_list.py +++ b/lib/ansible/plugins/inventory/host_list.py @@ -35,6 +35,8 @@ class InventoryModule(BaseInventoryPlugin): NAME = 'host_list' + # host_list does not set vars, so needs no special trust assistance from the inventory API + def verify_file(self, host_list): valid = False diff --git a/lib/ansible/plugins/inventory/ini.py b/lib/ansible/plugins/inventory/ini.py index cd961bcdb06..0c90a1b1e81 100644 --- a/lib/ansible/plugins/inventory/ini.py +++ b/lib/ansible/plugins/inventory/ini.py @@ -73,7 +73,9 @@ host4 # same host as above, but member of 2 groups, will inherit vars from both """ import ast +import os import re +import typing as t import warnings from ansible.inventory.group import to_safe_group_name @@ -81,6 +83,7 @@ from ansible.plugins.inventory import BaseFileInventoryPlugin from ansible.errors import AnsibleError, AnsibleParserError from ansible.module_utils.common.text.converters import to_bytes, to_text +from ansible._internal._datatag._tags import Origin, TrustedAsTemplate from ansible.utils.shlex import shlex_split @@ -93,18 +96,22 @@ class InventoryModule(BaseFileInventoryPlugin): _COMMENT_MARKERS = frozenset((u';', u'#')) b_COMMENT_MARKERS = frozenset((b';', b'#')) - def __init__(self): + # template trust is applied internally to strings + + def __init__(self) -> None: super(InventoryModule, self).__init__() - self.patterns = {} - self._filename = None + self.patterns: dict[str, re.Pattern] = {} + self._origin: Origin | None = None - def parse(self, inventory, loader, path, cache=True): + def verify_file(self, path): + # hardcode exclusion for TOML to prevent partial parsing of things we know we don't want + return super().verify_file(path) and os.path.splitext(path)[1] != '.toml' - super(InventoryModule, self).parse(inventory, loader, path) + def parse(self, inventory, loader, path: str, cache=True): - self._filename = path + super(InventoryModule, self).parse(inventory, loader, path) try: # Read in the hosts, groups, and variables defined in the inventory file. 
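The hunks that follow replace manual "%s:%d" position bookkeeping with datatag markers. As a rough sketch of the pattern (using only the Origin and TrustedAsTemplate calls imported above; these are internal APIs, and the values shown are hypothetical):

    from ansible._internal._datatag._tags import Origin, TrustedAsTemplate

    origin = Origin(path='/etc/ansible/hosts', line_num=12)
    value = TrustedAsTemplate().tag(origin.tag('{{ listen_port }}'))
    # `value` still behaves as a str; the tags let later consumers report the
    # source file and line, and decide whether the string may be templated.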
@@ -132,14 +139,20 @@ class InventoryModule(BaseFileInventoryPlugin): # Non-comment lines still have to be valid uf-8 data.append(to_text(line, errors='surrogate_or_strict')) - self._parse(path, data) - except Exception as e: - raise AnsibleParserError(e) + self._origin = Origin(path=path, line_num=0) + + try: + self._parse(data) + finally: + self._origin = self._origin.replace(line_num=None) + + except Exception as ex: + raise AnsibleParserError('Failed to parse inventory.', obj=self._origin) from ex def _raise_error(self, message): - raise AnsibleError("%s:%d: " % (self._filename, self.lineno) + message) + raise AnsibleError(message) - def _parse(self, path, lines): + def _parse(self, lines): """ Populates self.groups from the given array of lines. Raises an error on any parse failure. @@ -155,9 +168,8 @@ class InventoryModule(BaseFileInventoryPlugin): pending_declarations = {} groupname = 'ungrouped' state = 'hosts' - self.lineno = 0 for line in lines: - self.lineno += 1 + self._origin = self._origin.replace(line_num=self._origin.line_num + 1) line = line.strip() # Skip empty lines and comments @@ -189,7 +201,7 @@ class InventoryModule(BaseFileInventoryPlugin): # declarations will take the appropriate action for a pending child group instead of # incorrectly handling it as a var state pending declaration if state == 'vars' and groupname not in pending_declarations: - pending_declarations[groupname] = dict(line=self.lineno, state=state, name=groupname) + pending_declarations[groupname] = dict(line=self._origin.line_num, state=state, name=groupname) self.inventory.add_group(groupname) @@ -229,7 +241,7 @@ class InventoryModule(BaseFileInventoryPlugin): child = self._parse_group_name(line) if child not in self.inventory.groups: if child not in pending_declarations: - pending_declarations[child] = dict(line=self.lineno, state=state, name=child, parents=[groupname]) + pending_declarations[child] = dict(line=self._origin.line_num, state=state, name=child, parents=[groupname]) else: pending_declarations[child]['parents'].append(groupname) else: @@ -242,10 +254,11 @@ class InventoryModule(BaseFileInventoryPlugin): # We report only the first such error here. 
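        # (Editorial example.) A pending declaration recorded above looks like:
        #   pending_declarations['dbservers'] == {'line': 14, 'state': 'children',
        #                                         'name': 'dbservers', 'parents': ['east']}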
for g in pending_declarations: decl = pending_declarations[g] + self._origin = self._origin.replace(line_num=decl['line']) if decl['state'] == 'vars': - raise AnsibleError("%s:%d: Section [%s:vars] not valid for undefined group: %s" % (path, decl['line'], decl['name'], decl['name'])) + raise ValueError(f"Section [{decl['name']}:vars] not valid for undefined group {decl['name']!r}.") elif decl['state'] == 'children': - raise AnsibleError("%s:%d: Section [%s:children] includes undefined group: %s" % (path, decl['line'], decl['parents'].pop(), decl['name'])) + raise ValueError(f"Section [{decl['parents'][-1]}:children] includes undefined group {decl['name']!r}.") def _add_pending_children(self, group, pending): for parent in pending[group]['parents']: @@ -279,7 +292,7 @@ class InventoryModule(BaseFileInventoryPlugin): if '=' in line: (k, v) = [e.strip() for e in line.split("=", 1)] - return (k, self._parse_value(v)) + return (self._origin.tag(k), self._parse_value(v)) self._raise_error("Expected key=value, got: %s" % (line)) @@ -312,7 +325,7 @@ class InventoryModule(BaseFileInventoryPlugin): if '=' not in t: self._raise_error("Expected key=value host variable assignment, got: %s" % (t)) (k, v) = t.split('=', 1) - variables[k] = self._parse_value(v) + variables[self._origin.tag(k)] = self._parse_value(v) return hostnames, port, variables @@ -334,8 +347,27 @@ class InventoryModule(BaseFileInventoryPlugin): return (hostnames, port) - @staticmethod - def _parse_value(v): + def _parse_recursive_coerce_types_and_tag(self, value: t.Any) -> t.Any: + if isinstance(value, str): + return TrustedAsTemplate().tag(self._origin.tag(value)) + if isinstance(value, (list, tuple, set)): + # NB: intentional coercion of tuple/set to list, deal with it + return self._origin.tag([self._parse_recursive_coerce_types_and_tag(v) for v in value]) + if isinstance(value, dict): + # FIXME: enforce keys are strings + return self._origin.tag({self._origin.tag(k): self._parse_recursive_coerce_types_and_tag(v) for k, v in value.items()}) + + if value is ...: # literal_eval parses ellipsis, but it's not a supported variable type + value = TrustedAsTemplate().tag("...") + + if isinstance(value, complex): # convert unsupported variable types recognized by literal_eval back to str + value = TrustedAsTemplate().tag(str(value)) + + value = to_text(value, nonstring='passthru', errors='surrogate_or_strict') + + return self._origin.tag(value) + + def _parse_value(self, v: str) -> t.Any: """ Attempt to transform the string value from an ini file into a basic python object (int, dict, list, unicode string, etc). @@ -352,7 +384,9 @@ class InventoryModule(BaseFileInventoryPlugin): except SyntaxError: # Is this a hash with an equals at the end? 
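            # (Editorial example.) ast.literal_eval('foo=') raises SyntaxError here,
            # so the raw string falls through and is tagged and returned as-is below.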
pass - return to_text(v, nonstring='passthru', errors='surrogate_or_strict') + + # this is mostly unnecessary, but prevents the (possible) case of bytes literals showing up in inventory + return self._parse_recursive_coerce_types_and_tag(v) def _compile_patterns(self): """ diff --git a/lib/ansible/plugins/inventory/script.py b/lib/ansible/plugins/inventory/script.py index 9c8ecf54541..a0345f638ee 100644 --- a/lib/ansible/plugins/inventory/script.py +++ b/lib/ansible/plugins/inventory/script.py @@ -153,148 +153,136 @@ EXAMPLES = r'''# fmt: code ''' +import json import os +import shlex import subprocess +import typing as t -from collections.abc import Mapping - -from ansible.errors import AnsibleError, AnsibleParserError -from ansible.module_utils.basic import json_dict_bytes_to_unicode -from ansible.module_utils.common.text.converters import to_native, to_text +from ansible.errors import AnsibleError, AnsibleJSONParserError +from ansible.inventory.data import InventoryData +from ansible.module_utils.datatag import native_type_name +from ansible.module_utils.common.json import get_decoder +from ansible.parsing.dataloader import DataLoader from ansible.plugins.inventory import BaseInventoryPlugin +from ansible._internal._datatag._tags import TrustedAsTemplate, Origin from ansible.utils.display import Display +from ansible._internal._json._profiles import _legacy, _inventory_legacy display = Display() class InventoryModule(BaseInventoryPlugin): - """ Host inventory parser for ansible using external inventory scripts. """ + """Host inventory parser for ansible using external inventory scripts.""" NAME = 'script' - def __init__(self): - + def __init__(self) -> None: super(InventoryModule, self).__init__() - self._hosts = set() - - def verify_file(self, path): - """ Verify if file is usable by this plugin, base does minimal accessibility check """ - - valid = super(InventoryModule, self).verify_file(path) - - if valid: - # not only accessible, file must be executable and/or have shebang - shebang_present = False - try: - with open(path, 'rb') as inv_file: - initial_chars = inv_file.read(2) - if initial_chars.startswith(b'#!'): - shebang_present = True - except Exception: - pass + self._hosts: set[str] = set() - if not os.access(path, os.X_OK) and not shebang_present: - valid = False - - return valid - - def parse(self, inventory, loader, path, cache=None): + def verify_file(self, path: str) -> bool: + return super(InventoryModule, self).verify_file(path) and os.access(path, os.X_OK) + def parse(self, inventory: InventoryData, loader: DataLoader, path: str, cache: bool = False) -> None: super(InventoryModule, self).parse(inventory, loader, path) - self.set_options() - # Support inventory scripts that are not prefixed with some - # path information but happen to be in the current working - # directory when '.' is not in PATH. 
- cmd = [path, "--list"] - - try: - try: - sp = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - except OSError as e: - raise AnsibleParserError("problem running %s (%s)" % (' '.join(cmd), to_native(e))) - (stdout, stderr) = sp.communicate() + self.set_options() - path = to_native(path) - err = to_native(stderr or "") + origin = Origin(description=f'<inventory script output from {path}>') - if err and not err.endswith('\n'): - err += '\n' + data, stderr, stderr_help_text = run_command(path, ['--list'], origin) - if sp.returncode != 0: - raise AnsibleError("Inventory script (%s) had an execution error: %s " % (path, err)) + try: + profile_name = detect_profile_name(data) + decoder = get_decoder(profile_name) + except Exception as ex: + raise AnsibleError( + message="Unable to get JSON decoder for inventory script result.", + help_text=stderr_help_text, + # obj will be added by inventory manager + ) from ex - # make sure script output is unicode so that json loader will output unicode strings itself + try: try: - data = to_text(stdout, errors="strict") - except Exception as e: - raise AnsibleError("Inventory {0} contained characters that cannot be interpreted as UTF-8: {1}".format(path, to_native(e))) + processed = json.loads(data, cls=decoder) + except Exception as json_ex: + AnsibleJSONParserError.handle_exception(json_ex, origin) + except Exception as ex: + raise AnsibleError( + message="Inventory script result could not be parsed as JSON.", + help_text=stderr_help_text, + # obj will be added by inventory manager + ) from ex + + # if no other errors happened, and you want to force displaying stderr, do so now + if stderr and self.get_option('always_show_stderr'): + self.display.error(msg=stderr) + + data_from_meta: dict | None = None + + # A "_meta" subelement may contain a variable "hostvars" which contains a hash for each host + # if this "hostvars" exists at all then do not call --host for each host. + # This is for efficiency and scripts should still return data + # if called with --host for backwards compat with 1.2 and earlier. + for (group, gdata) in processed.items(): + if group == '_meta': + data_from_meta = gdata.get('hostvars') + + if not isinstance(data_from_meta, dict): + raise TypeError(f"Value contains '_meta.hostvars' which is {native_type_name(data_from_meta)!r} instead of {native_type_name(dict)!r}.") + else: + self._parse_group(group, gdata, origin) + + if data_from_meta is None: + display.deprecated( + msg="Inventory scripts should always provide '_meta.hostvars'. 
" + "Host variables will be collected by running the inventory script with the '--host' option for each host.", + version='2.23', + obj=origin, + ) + + for host in self._hosts: + if data_from_meta is None: + got = self.get_host_variables(path, host, origin) + else: + got = data_from_meta.get(host, {}) - try: - processed = self.loader.load(data, json_only=True) - except Exception as e: - raise AnsibleError("failed to parse executable inventory script results from {0}: {1}\n{2}".format(path, to_native(e), err)) - - # if no other errors happened and you want to force displaying stderr, do so now - if stderr and self.get_option('always_show_stderr'): - self.display.error(msg=to_text(err)) - - if not isinstance(processed, Mapping): - raise AnsibleError("failed to parse executable inventory script results from {0}: needs to be a json dict\n{1}".format(path, err)) - - group = None - data_from_meta = None - - # A "_meta" subelement may contain a variable "hostvars" which contains a hash for each host - # if this "hostvars" exists at all then do not call --host for each # host. - # This is for efficiency and scripts should still return data - # if called with --host for backwards compat with 1.2 and earlier. - for (group, gdata) in processed.items(): - if group == '_meta': - if 'hostvars' in gdata: - data_from_meta = gdata['hostvars'] - else: - self._parse_group(group, gdata) - - for host in self._hosts: - got = {} - if data_from_meta is None: - got = self.get_host_variables(path, host) - else: - try: - got = data_from_meta.get(host, {}) - except AttributeError as e: - raise AnsibleError("Improperly formatted host information for %s: %s" % (host, to_native(e)), orig_exc=e) - - self._populate_host_vars([host], got) - - except Exception as e: - raise AnsibleParserError(to_native(e)) - - def _parse_group(self, group, data): + self._populate_host_vars([host], got) + def _parse_group(self, group: str, data: t.Any, origin: Origin) -> None: + """Normalize and ingest host/var information for the named group.""" group = self.inventory.add_group(group) if not isinstance(data, dict): data = {'hosts': data} - # is not those subkeys, then simplified syntax, host with vars + display.deprecated( + msg=f"Group {group!r} was converted to {native_type_name(dict)!r} from {native_type_name(data)!r}.", + version='2.23', + obj=origin, + ) elif not any(k in data for k in ('hosts', 'vars', 'children')): data = {'hosts': [group], 'vars': data} + display.deprecated( + msg=f"Treating malformed group {group!r} as the sole host of that group. 
Variables provided in this manner cannot be templated.", + version='2.23', + obj=origin, + ) - if 'hosts' in data: - if not isinstance(data['hosts'], list): - raise AnsibleError("You defined a group '%s' with bad data for the host list:\n %s" % (group, data)) + if (data_hosts := data.get('hosts', ...)) is not ...: + if not isinstance(data_hosts, list): + raise TypeError(f"Value contains '{group}.hosts' which is {native_type_name(data_hosts)!r} instead of {native_type_name(list)!r}.") - for hostname in data['hosts']: + for hostname in data_hosts: self._hosts.add(hostname) self.inventory.add_host(hostname, group) - if 'vars' in data: - if not isinstance(data['vars'], dict): - raise AnsibleError("You defined a group '%s' with bad data for variables:\n %s" % (group, data)) + if (data_vars := data.get('vars', ...)) is not ...: + if not isinstance(data_vars, dict): + raise TypeError(f"Value contains '{group}.vars' which is {native_type_name(data_vars)!r} instead of {native_type_name(dict)!r}.") - for k, v in data['vars'].items(): + for k, v in data_vars.items(): self.inventory.set_variable(group, k, v) if group != '_meta' and isinstance(data, dict) and 'children' in data: @@ -302,22 +290,102 @@ class InventoryModule(BaseInventoryPlugin): child_name = self.inventory.add_group(child_name) self.inventory.add_child(group, child_name) - def get_host_variables(self, path, host): - """ Runs