Clean up TE error handling, wrap sigalrm handler (#85232)

* Clean up TE error handling, wrap sigalrm handler

* Preserve error detail on AnsibleAction and Connection exceptions.
* Remove multiple layers of unreachable or redundant error handling.
* Wrap manual alarm signal/timeout handling into a context manager, add tests.

Co-authored-by: Matt Clay <matt@mystile.com>

* update error message check in test

* update test timeout message assertions

---------

Co-authored-by: Matt Clay <matt@mystile.com>
(cherry picked from commit cbcefc53a3)
pull/85255/head
Matt Davis 6 months ago committed by Matt Davis
parent 311ef75245
commit 85283e7238

@ -0,0 +1,7 @@
bugfixes:
- task timeout - Specifying a negative task timeout now results in an error.
- error handling - Error details and tracebacks from connection and built-in action exceptions are preserved.
Previously, much of the detail was lost or mixed into the error message.
minor_changes:
- task timeout - Specifying a timeout greater than 100,000,000 now results in an error.

@ -0,0 +1,66 @@
from __future__ import annotations
import contextlib
import signal
import types
import typing as _t
from ansible.module_utils import datatag
class AnsibleTimeoutError(BaseException):
"""A general purpose timeout."""
_MAX_TIMEOUT = 100_000_000
"""
The maximum supported timeout value.
This value comes from BSD's alarm limit, which is due to that function using setitimer.
"""
def __init__(self, timeout: int) -> None:
self.timeout = timeout
super().__init__(f"Timed out after {timeout} second(s).")
@classmethod
@contextlib.contextmanager
def alarm_timeout(cls, timeout: int | None) -> _t.Iterator[None]:
"""
Context for running code under an optional timeout.
Raises an instance of this class if the timeout occurs.
New usages of this timeout mechanism are discouraged.
"""
if timeout is not None:
if not isinstance(timeout, int):
raise TypeError(f"Timeout requires 'int' argument, not {datatag.native_type_name(timeout)!r}.")
if timeout < 0 or timeout > cls._MAX_TIMEOUT:
# On BSD based systems, alarm is implemented using setitimer.
# If out-of-bounds values are passed to alarm, they will return -1, which would be interpreted as an existing timer being set.
# To avoid that, bounds checking is performed in advance.
raise ValueError(f'Timeout {timeout} is invalid, it must be between 0 and {cls._MAX_TIMEOUT}.')
if not timeout:
yield # execute the context manager's body
return # no timeout to deal with, exit immediately
def on_alarm(_signal: int, _frame: types.FrameType) -> None:
raise cls(timeout)
if signal.signal(signal.SIGALRM, on_alarm):
raise RuntimeError("An existing alarm handler was present.")
try:
try:
if signal.alarm(timeout):
raise RuntimeError("An existing alarm was set.")
yield # execute the context manager's body
finally:
# Disable the alarm.
# If the alarm fires inside this finally block, the alarm is still disabled.
# This guarantees the cleanup code in the outer finally block runs without risk of encountering the `TaskTimeoutError` from the alarm.
signal.alarm(0)
finally:
signal.signal(signal.SIGALRM, signal.SIG_DFL)

@ -1,8 +1,10 @@
from __future__ import annotations
import collections.abc as _c
import dataclasses
import typing as t
from ansible._internal._errors import _error_utils
from ansible.errors import AnsibleRuntimeError
from ansible.module_utils._internal import _messages
@ -25,31 +27,28 @@ class AnsibleCapturedError(AnsibleRuntimeError):
self._event = event
class AnsibleResultCapturedError(AnsibleCapturedError):
"""An exception representing error detail captured in a foreign context where an action/module result dictionary is involved."""
class AnsibleResultCapturedError(AnsibleCapturedError, _error_utils.ContributesToTaskResult):
"""
An exception representing error detail captured in a foreign context where an action/module result dictionary is involved.
This exception provides a result dictionary via the ContributesToTaskResult mixin.
"""
def __init__(self, event: _messages.Event, result: dict[str, t.Any]) -> None:
super().__init__(event=event)
self._result = result
@property
def result_contribution(self) -> _c.Mapping[str, object]:
return self._result
@classmethod
def maybe_raise_on_result(cls, result: dict[str, t.Any]) -> None:
"""Normalize the result and raise an exception if the result indicated failure."""
if error_summary := cls.normalize_result_exception(result):
raise error_summary.error_type(error_summary.event, result)
@classmethod
def find_first_remoted_error(cls, exception: BaseException) -> t.Self | None:
"""Find the first captured module error in the cause chain, starting with the given exception, returning None if not found."""
while exception:
if isinstance(exception, cls):
return exception
exception = exception.__cause__
return None
@classmethod
def normalize_result_exception(cls, result: dict[str, t.Any]) -> CapturedErrorSummary | None:
"""

@ -1,5 +1,7 @@
from __future__ import annotations
import abc
import collections.abc as _c
import dataclasses
import itertools
import pathlib
@ -8,18 +10,30 @@ import typing as t
from ansible._internal._datatag._tags import Origin
from ansible._internal._errors import _error_factory
from ansible.module_utils._internal import _ambient_context, _event_utils
from ansible.module_utils._internal import _ambient_context, _event_utils, _messages, _traceback
class RedactAnnotatedSourceContext(_ambient_context.AmbientContextBase):
"""
When active, this context will redact annotated source lines, showing only the origin.
"""
class ContributesToTaskResult(metaclass=abc.ABCMeta):
"""Exceptions may include this mixin to contribute task result dictionary data directly to the final result."""
@property
@abc.abstractmethod
def result_contribution(self) -> _c.Mapping[str, object]:
"""Mapping of results to apply to the task result."""
def format_exception_message(exception: BaseException) -> str:
"""Return the full chain of exception messages by concatenating the cause(s) until all are exhausted."""
return _event_utils.format_event_brief_message(_error_factory.ControllerEventFactory.from_exception(exception, False))
@property
def omit_exception_key(self) -> bool:
"""Non-error exceptions (e.g., `AnsibleActionSkip`) must return `True` to ensure omission of the `exception` key."""
return False
@property
def omit_failed_key(self) -> bool:
"""Exceptions representing non-failure scenarios (e.g., `skipped`, `unreachable`) must return `True` to ensure omisson of the `failed` key."""
return False
class RedactAnnotatedSourceContext(_ambient_context.AmbientContextBase):
"""When active, this context will redact annotated source lines, showing only the origin."""
@dataclasses.dataclass(kw_only=True, frozen=True)
@ -159,3 +173,68 @@ class SourceContext:
annotated_source_lines=annotated_source_lines,
target_line=lines[-1].rstrip('\n'), # universal newline default mode on `open` ensures we'll never see anything but \n
)
def format_exception_message(exception: BaseException) -> str:
"""Return the full chain of exception messages by concatenating the cause(s) until all are exhausted."""
return _event_utils.format_event_brief_message(_error_factory.ControllerEventFactory.from_exception(exception, False))
def result_dict_from_exception(exception: BaseException, accept_result_contribution: bool = False) -> dict[str, object]:
"""Return a failed task result dict from the given exception."""
event = _error_factory.ControllerEventFactory.from_exception(exception, _traceback.is_traceback_enabled(_traceback.TracebackEvent.ERROR))
result: dict[str, object] = {}
omit_failed_key = False
omit_exception_key = False
if accept_result_contribution:
while exception:
if isinstance(exception, ContributesToTaskResult):
result = dict(exception.result_contribution)
omit_failed_key = exception.omit_failed_key
omit_exception_key = exception.omit_exception_key
break
exception = exception.__cause__
if omit_failed_key:
result.pop('failed', None)
else:
result.update(failed=True)
if omit_exception_key:
result.pop('exception', None)
else:
result.update(exception=_messages.ErrorSummary(event=event))
if 'msg' not in result:
# if nothing contributed `msg`, generate one from the exception messages
result.update(msg=_event_utils.format_event_brief_message(event))
return result
def result_dict_from_captured_errors(
msg: str,
*,
errors: list[_messages.ErrorSummary] | None = None,
) -> dict[str, object]:
"""Return a failed task result dict from the given error message and captured errors."""
_skip_stackwalk = True
event = _messages.Event(
msg=msg,
formatted_traceback=_traceback.maybe_capture_traceback(msg, _traceback.TracebackEvent.ERROR),
events=tuple(error.event for error in errors) if errors else None,
)
result = dict(
failed=True,
exception=_messages.ErrorSummary(
event=event,
),
msg=_event_utils.format_event_brief_message(event),
)
return result

@ -0,0 +1,28 @@
from __future__ import annotations
from collections import abc as _c
from ansible._internal._errors._alarm_timeout import AnsibleTimeoutError
from ansible._internal._errors._error_utils import ContributesToTaskResult
from ansible.module_utils.datatag import deprecate_value
class TaskTimeoutError(AnsibleTimeoutError, ContributesToTaskResult):
"""
A task-specific timeout.
This exception provides a result dictionary via the ContributesToTaskResult mixin.
"""
@property
def result_contribution(self) -> _c.Mapping[str, object]:
help_text = "Configure `DISPLAY_TRACEBACK` to see a traceback on timeout errors."
frame = deprecate_value(
value=help_text,
msg="The `timedout.frame` task result key is deprecated.",
help_text=help_text,
version="2.23",
)
return dict(timedout=dict(frame=frame, period=self.timeout))

@ -7,7 +7,7 @@ import typing as t
from yaml import MarkedYAMLError
from yaml.constructor import ConstructorError
from ansible._internal._errors import _utils
from ansible._internal._errors import _error_utils
from ansible.errors import AnsibleParserError
from ansible._internal._datatag._tags import Origin
@ -34,7 +34,7 @@ class AnsibleYAMLParserError(AnsibleParserError):
if isinstance(exception, MarkedYAMLError):
origin = origin.replace(line_num=exception.problem_mark.line + 1, col_num=exception.problem_mark.column + 1)
source_context = _utils.SourceContext.from_origin(origin)
source_context = _error_utils.SourceContext.from_origin(origin)
target_line = source_context.target_line or '' # for these cases, we don't need to distinguish between None and empty string
@ -66,12 +66,12 @@ class AnsibleYAMLParserError(AnsibleParserError):
# There may be cases where there is a valid tab in a line that has other errors.
# That's OK, users should "fix" their tab usage anyway -- at which point later error handling logic will hopefully find the real issue.
elif (tab_idx := target_line.find('\t')) >= 0:
source_context = _utils.SourceContext.from_origin(origin.replace(col_num=tab_idx + 1))
source_context = _error_utils.SourceContext.from_origin(origin.replace(col_num=tab_idx + 1))
message = "Tabs are usually invalid in YAML."
# Check for unquoted templates.
elif match := re.search(r'^\s*(?:-\s+)*(?:[\w\s]+:\s+)?(?P<value>\{\{.*}})', target_line):
source_context = _utils.SourceContext.from_origin(origin.replace(col_num=match.start('value') + 1))
source_context = _error_utils.SourceContext.from_origin(origin.replace(col_num=match.start('value') + 1))
message = 'This may be an issue with missing quotes around a template block.'
# FIXME: Use the captured value to show the actual fix required.
help_text = """
@ -95,7 +95,7 @@ Should be:
# look for an unquoted colon in the value
and (colon_match := re.search(r':($| )', target_fragment))
):
source_context = _utils.SourceContext.from_origin(origin.replace(col_num=value_match.start('value') + colon_match.start() + 1))
source_context = _error_utils.SourceContext.from_origin(origin.replace(col_num=value_match.start('value') + colon_match.start() + 1))
message = 'Colons in unquoted values must be followed by a non-space character.'
# FIXME: Use the captured value to show the actual fix required.
help_text = """
@ -114,7 +114,7 @@ Should be:
first, last = suspected_value[0], suspected_value[-1]
if first != last: # "foo" in bar
source_context = _utils.SourceContext.from_origin(origin.replace(col_num=match.start('value') + 1))
source_context = _error_utils.SourceContext.from_origin(origin.replace(col_num=match.start('value') + 1))
message = 'Values starting with a quote must end with the same quote.'
# FIXME: Use the captured value to show the actual fix required, and use that same logic to improve the origin further.
help_text = """
@ -127,7 +127,7 @@ Should be:
raw: '"foo" in bar'
"""
elif first == last and target_line.count(first) > 2: # "foo" and "bar"
source_context = _utils.SourceContext.from_origin(origin.replace(col_num=match.start('value') + 1))
source_context = _error_utils.SourceContext.from_origin(origin.replace(col_num=match.start('value') + 1))
message = 'Values starting with a quote must end with the same quote, and not contain that quote.'
# FIXME: Use the captured value to show the actual fix required, and use that same logic to improve the origin further.
help_text = """

@ -3,20 +3,17 @@
from __future__ import annotations
import collections.abc as _c
import enum
import traceback
import sys
import types
import typing as t
from collections.abc import Sequence
from json import JSONDecodeError
from ansible.module_utils.common.text.converters import to_text
from ..module_utils.datatag import native_type_name
from ansible._internal._datatag import _tags
from .._internal._errors import _utils
from .._internal._errors import _error_utils
from ansible.module_utils._internal import _text_utils
if t.TYPE_CHECKING:
@ -112,7 +109,7 @@ class AnsibleError(Exception):
Return the original message with cause message(s) appended.
The cause will not be followed on any `AnsibleError` with `_include_cause_message=False`.
"""
return _utils.format_exception_message(self)
return _error_utils.format_exception_message(self)
@message.setter
def message(self, val) -> None:
@ -120,8 +117,8 @@ class AnsibleError(Exception):
@property
def _formatted_source_context(self) -> str | None:
with _utils.RedactAnnotatedSourceContext.when(not self._show_content):
if source_context := _utils.SourceContext.from_value(self.obj):
with _error_utils.RedactAnnotatedSourceContext.when(not self._show_content):
if source_context := _error_utils.SourceContext.from_value(self.obj):
return str(source_context)
return None
@ -237,8 +234,20 @@ class AnsibleModuleError(AnsibleRuntimeError):
"""A module failed somehow."""
class AnsibleConnectionFailure(AnsibleRuntimeError):
"""The transport / connection_plugin had a fatal error."""
class AnsibleConnectionFailure(AnsibleRuntimeError, _error_utils.ContributesToTaskResult):
"""
The transport / connection_plugin had a fatal error.
This exception provides a result dictionary via the ContributesToTaskResult mixin.
"""
@property
def result_contribution(self) -> t.Mapping[str, object]:
return dict(unreachable=True)
@property
def omit_failed_key(self) -> bool:
return True
class AnsibleAuthenticationFailure(AnsibleConnectionFailure):
@ -318,7 +327,7 @@ class AnsibleFileNotFound(AnsibleRuntimeError):
else:
message += "Could not find file"
if self.paths and isinstance(self.paths, Sequence):
if self.paths and isinstance(self.paths, _c.Sequence):
searched = to_text('\n\t'.join(self.paths))
if message:
message += "\n"
@ -330,47 +339,76 @@ class AnsibleFileNotFound(AnsibleRuntimeError):
suppress_extended_error=suppress_extended_error, orig_exc=orig_exc)
# These Exceptions are temporary, using them as flow control until we can get a better solution.
# DO NOT USE as they will probably be removed soon.
# We will port the action modules in our tree to use a context manager instead.
class AnsibleAction(AnsibleRuntimeError):
class AnsibleAction(AnsibleRuntimeError, _error_utils.ContributesToTaskResult):
"""Base Exception for Action plugin flow control."""
def __init__(self, message="", obj=None, show_content=True, suppress_extended_error=..., orig_exc=None, result=None):
super(AnsibleAction, self).__init__(message=message, obj=obj, show_content=show_content,
suppress_extended_error=suppress_extended_error, orig_exc=orig_exc)
if result is None:
self.result = {}
else:
self.result = result
super().__init__(message=message, obj=obj, show_content=show_content, suppress_extended_error=suppress_extended_error, orig_exc=orig_exc)
self._result = result or {}
@property
def result_contribution(self) -> _c.Mapping[str, object]:
return self._result
@property
def result(self) -> dict[str, object]:
"""Backward compatibility property returning a mutable dictionary."""
return dict(self.result_contribution)
class AnsibleActionSkip(AnsibleAction):
"""An action runtime skip."""
"""
An action runtime skip.
def __init__(self, message="", obj=None, show_content=True, suppress_extended_error=..., orig_exc=None, result=None):
super(AnsibleActionSkip, self).__init__(message=message, obj=obj, show_content=show_content,
suppress_extended_error=suppress_extended_error, orig_exc=orig_exc, result=result)
self.result.update({'skipped': True, 'msg': message})
This exception provides a result dictionary via the ContributesToTaskResult mixin.
"""
@property
def result_contribution(self) -> _c.Mapping[str, object]:
return self._result | dict(
skipped=True,
msg=self.message,
)
@property
def omit_failed_key(self) -> bool:
return True
@property
def omit_exception_key(self) -> bool:
return True
class AnsibleActionFail(AnsibleAction):
"""An action runtime failure."""
"""
An action runtime failure.
def __init__(self, message="", obj=None, show_content=True, suppress_extended_error=..., orig_exc=None, result=None):
super(AnsibleActionFail, self).__init__(message=message, obj=obj, show_content=show_content,
suppress_extended_error=suppress_extended_error, orig_exc=orig_exc, result=result)
This exception provides a result dictionary via the ContributesToTaskResult mixin.
"""
@property
def result_contribution(self) -> _c.Mapping[str, object]:
return self._result | dict(
failed=True,
msg=self.message,
)
result_overrides = {'failed': True, 'msg': message}
# deprecated: description='use sys.exception()' python_version='3.11'
if sys.exc_info()[1]: # DTFIX-FUTURE: remove this hack once TaskExecutor is no longer shucking AnsibleActionFail and returning its result
result_overrides['exception'] = traceback.format_exc()
self.result.update(result_overrides)
class _ActionDone(AnsibleAction):
"""
Imports as `_AnsibleActionDone` are deprecated. An action runtime early exit.
This exception provides a result dictionary via the ContributesToTaskResult mixin.
"""
@property
def omit_failed_key(self) -> bool:
return not self._result.get('failed')
class _AnsibleActionDone(AnsibleAction):
"""An action runtime early exit."""
@property
def omit_exception_key(self) -> bool:
return not self._result.get('failed')
class AnsiblePluginError(AnsibleError):
@ -421,13 +459,23 @@ def __getattr__(name: str) -> t.Any:
"""Inject import-time deprecation warnings."""
from ..utils.display import Display
if name == 'AnsibleFilterTypeError':
Display().deprecated(
msg="Importing 'AnsibleFilterTypeError' is deprecated.",
help_text=f"Import {AnsibleTypeError.__name__!r} instead.",
version="2.23",
)
match name:
case 'AnsibleFilterTypeError':
Display().deprecated(
msg=f"Importing {name!r} is deprecated.",
help_text=f"Import {AnsibleTypeError.__name__!r} instead.",
version="2.23",
)
return AnsibleTypeError
case '_AnsibleActionDone':
Display().deprecated(
msg=f"Importing {name!r} is deprecated.",
help_text="Return directly from action plugins instead.",
version="2.23",
)
return AnsibleTypeError
return _ActionDone
raise AttributeError(f'module {__name__!r} has no attribute {name!r}')

@ -25,29 +25,25 @@ import textwrap
import traceback
import types
import typing as t
from multiprocessing.queues import Queue
from ansible import context
from ansible._internal import _task
from ansible.errors import AnsibleConnectionFailure, AnsibleError
from ansible._internal._errors import _error_utils
from ansible.errors import AnsibleError
from ansible.executor.task_executor import TaskExecutor
from ansible.executor.task_queue_manager import FinalQueue, STDIN_FILENO, STDOUT_FILENO, STDERR_FILENO
from ansible.executor.task_result import _RawTaskResult
from ansible.inventory.host import Host
from ansible.module_utils.common.collections import is_sequence
from ansible.module_utils.common.text.converters import to_text
from ansible.parsing.dataloader import DataLoader
from ansible.playbook.task import Task
from ansible.playbook.play_context import PlayContext
from ansible.plugins.loader import init_plugin_loader
from ansible.utils.context_objects import CLIArgs
from ansible.plugins.action import ActionBase
from ansible.utils.display import Display
from ansible.utils.multiprocessing import context as multiprocessing_context
from ansible.vars.manager import VariableManager
from jinja2.exceptions import TemplateNotFound
__all__ = ['WorkerProcess']
display = Display()
@ -204,120 +200,49 @@ class WorkerProcess(multiprocessing_context.Process): # type: ignore[name-defin
signify that they are ready for their next task.
"""
# import cProfile, pstats, StringIO
# pr = cProfile.Profile()
# pr.enable()
global current_worker
current_worker = self
if multiprocessing_context.get_start_method() != 'fork':
# This branch is unused currently, as we hardcode fork
# TODO
# * move into a setup func run in `run`, before `_detach`
# * playbook relative content
# * display verbosity
# * ???
context.CLIARGS = self._cliargs
# Initialize plugin loader after parse, so that the init code can utilize parsed arguments
cli_collections_path = context.CLIARGS.get('collections_path') or []
if not is_sequence(cli_collections_path):
# In some contexts ``collections_path`` is singular
cli_collections_path = [cli_collections_path]
init_plugin_loader(cli_collections_path)
executor_result = TaskExecutor(
self._host,
self._task,
self._task_vars,
self._play_context,
self._loader,
self._shared_loader_obj,
self._final_q,
self._variable_manager,
).run()
self._host.vars = dict()
self._host.groups = []
for name, stdio in (('stdout', sys.stdout), ('stderr', sys.stderr)):
if data := stdio.getvalue(): # type: ignore[union-attr]
display.warning(
(
f'WorkerProcess for [{self._host}/{self._task}] errantly sent data directly to {name} instead of using Display:\n'
f'{textwrap.indent(data[:256], " ")}\n'
),
formatted=True
)
try:
# execute the task and build a _RawTaskResult from the result
display.debug("running TaskExecutor() for %s/%s" % (self._host, self._task))
executor_result = TaskExecutor(
self._host,
self._task,
self._task_vars,
self._play_context,
self._loader,
self._shared_loader_obj,
self._final_q,
self._variable_manager,
).run()
display.debug("done running TaskExecutor() for %s/%s [%s]" % (self._host, self._task, self._task._uuid))
self._host.vars = dict()
self._host.groups = []
for name, stdio in (('stdout', sys.stdout), ('stderr', sys.stderr)):
if data := stdio.getvalue(): # type: ignore[union-attr]
display.warning(
(
f'WorkerProcess for [{self._host}/{self._task}] errantly sent data directly to {name} instead of using Display:\n'
f'{textwrap.indent(data[:256], " ")}\n'
),
formatted=True
)
# put the result on the result queue
display.debug("sending task result for task %s" % self._task._uuid)
try:
self._final_q.send_task_result(_RawTaskResult(
host=self._host,
task=self._task,
return_data=executor_result,
task_fields=self._task.dump_attrs(),
))
except Exception as ex:
try:
raise AnsibleError("Task result omitted due to queue send failure.") from ex
except Exception as ex_wrapper:
self._final_q.send_task_result(_RawTaskResult(
host=self._host,
task=self._task,
return_data=ActionBase.result_dict_from_exception(ex_wrapper), # Overriding the task result, to represent the failure
task_fields={}, # The failure pickling may have been caused by the task attrs, omit for safety
))
display.debug("done sending task result for task %s" % self._task._uuid)
except AnsibleConnectionFailure as ex:
return_data = ActionBase.result_dict_from_exception(ex)
return_data.pop('failed')
return_data.update(unreachable=True)
self._host.vars = dict()
self._host.groups = []
self._final_q.send_task_result(_RawTaskResult(
host=self._host,
task=self._task,
return_data=return_data,
return_data=executor_result,
task_fields=self._task.dump_attrs(),
))
except Exception as ex:
if not isinstance(ex, (IOError, EOFError, KeyboardInterrupt, SystemExit)) or isinstance(ex, TemplateNotFound):
try:
self._host.vars = dict()
self._host.groups = []
self._final_q.send_task_result(_RawTaskResult(
host=self._host,
task=self._task,
return_data=ActionBase.result_dict_from_exception(ex),
task_fields=self._task.dump_attrs(),
))
except Exception:
display.debug(u"WORKER EXCEPTION: %s" % to_text(ex))
display.debug(u"WORKER TRACEBACK: %s" % to_text(traceback.format_exc()))
finally:
self._clean_up()
display.debug("WORKER PROCESS EXITING")
# pr.disable()
# s = StringIO.StringIO()
# sortby = 'time'
# ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
# ps.print_stats()
# with open('worker_%06d.stats' % os.getpid(), 'w') as f:
# f.write(s.getvalue())
def _clean_up(self) -> None:
# NOTE: see note in init about forks
# ensure we cleanup all temp files for this worker
self._loader.cleanup_all_tmp_files()
try:
raise AnsibleError("Task result omitted due to queue send failure.") from ex
except Exception as ex_wrapper:
self._final_q.send_task_result(_RawTaskResult(
host=self._host,
task=self._task,
# ignore the real task result and don't allow result object contribution from the exception (in case the pickling error was related)
return_data=_error_utils.result_dict_from_exception(ex_wrapper),
task_fields={}, # The failure pickling may have been caused by the task attrs, omit for safety
))

@ -7,7 +7,6 @@ import os
import time
import json
import pathlib
import signal
import subprocess
import sys
@ -17,7 +16,7 @@ import typing as t
from ansible import constants as C
from ansible.cli import scripts
from ansible.errors import (
AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleConnectionFailure, AnsibleActionFail, AnsibleActionSkip, AnsibleTaskError,
AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleTaskError,
AnsibleValueOmittedError,
)
from ansible.executor.task_result import _RawTaskResult
@ -31,7 +30,6 @@ from ansible.module_utils.connection import write_to_stream
from ansible.module_utils.six import string_types
from ansible.playbook.task import Task
from ansible.plugins import get_plugin_class
from ansible.plugins.action import ActionBase
from ansible.plugins.loader import become_loader, cliconf_loader, connection_loader, httpapi_loader, netconf_loader, terminal_loader
from ansible._internal._templating._jinja_plugins import _invoke_lookup, _DirectCall
from ansible._internal._templating._engine import TemplateEngine
@ -41,7 +39,7 @@ from ansible.utils.display import Display, _DeferredWarningContext
from ansible.utils.vars import combine_vars
from ansible.vars.clean import namespace_facts, clean_facts
from ansible.vars.manager import _deprecate_top_level_fact
from ansible._internal._errors import _captured
from ansible._internal._errors import _captured, _task_timeout, _error_utils
if t.TYPE_CHECKING:
from ansible.executor.task_queue_manager import FinalQueue
@ -54,24 +52,6 @@ RETURN_VARS = [x for x in C.MAGIC_VARIABLE_MAPPING.items() if 'become' not in x
__all__ = ['TaskExecutor']
class TaskTimeoutError(BaseException):
def __init__(self, message="", frame=None):
if frame is not None:
orig = frame
root = pathlib.Path(__file__).parent
while not pathlib.Path(frame.f_code.co_filename).is_relative_to(root):
frame = frame.f_back
self.frame = 'Interrupted at %s called from %s' % (orig, frame)
super(TaskTimeoutError, self).__init__(message)
def task_timeout(signum, frame):
raise TaskTimeoutError(frame=frame)
class TaskExecutor:
"""
@ -176,7 +156,7 @@ class TaskExecutor:
return res
except Exception as ex:
result = ActionBase.result_dict_from_exception(ex)
result = _error_utils.result_dict_from_exception(ex)
self._task.update_result_no_log(self._task_templar, result)
@ -442,11 +422,11 @@ class TaskExecutor:
result = self._execute_internal(templar, variables)
self._apply_task_result_compat(result, warning_ctx)
_captured.AnsibleActionCapturedError.maybe_raise_on_result(result)
except Exception as ex:
except (Exception, _task_timeout.TaskTimeoutError) as ex: # TaskTimeoutError is BaseException
try:
raise AnsibleTaskError(obj=self._task.get_ds()) from ex
except AnsibleTaskError as atex:
result = ActionBase.result_dict_from_exception(atex)
result = _error_utils.result_dict_from_exception(atex, accept_result_contribution=True)
result.setdefault('changed', False)
self._task.update_result_no_log(templar, result)
@ -636,24 +616,9 @@ class TaskExecutor:
for attempt in range(1, retries + 1):
display.debug("running the handler")
try:
if self._task.timeout:
old_sig = signal.signal(signal.SIGALRM, task_timeout)
signal.alarm(self._task.timeout)
result = self._handler.run(task_vars=vars_copy)
# DTFIX0: nuke this, it hides a lot of error detail- remove the active exception propagation hack from AnsibleActionFail at the same time
except (AnsibleActionFail, AnsibleActionSkip) as e:
return e.result
except AnsibleConnectionFailure as e:
return dict(unreachable=True, msg=to_text(e))
except TaskTimeoutError as e:
msg = 'The %s action failed to execute in the expected time frame (%d) and was terminated' % (self._task.action, self._task.timeout)
return dict(failed=True, msg=msg, timedout={'frame': e.frame, 'period': self._task.timeout})
with _task_timeout.TaskTimeoutError.alarm_timeout(self._task.timeout):
result = self._handler.run(task_vars=vars_copy)
finally:
if self._task.timeout:
signal.alarm(0)
old_sig = signal.signal(signal.SIGALRM, old_sig)
self._handler.cleanup()
display.debug("handler run complete")

@ -15,7 +15,7 @@ import typing as t
from ansible import constants as C
from ansible.errors import AnsibleFileNotFound, AnsibleParserError
from ansible._internal._errors import _utils
from ansible._internal._errors import _error_utils
from ansible.module_utils.basic import is_executable
from ansible._internal._datatag._tags import Origin, TrustedAsTemplate, SourceWasEncrypted
from ansible.module_utils._internal._datatag import AnsibleTagHelper
@ -86,7 +86,7 @@ class DataLoader:
json_only: bool = False,
) -> t.Any:
"""Backwards compat for now"""
with _utils.RedactAnnotatedSourceContext.when(not show_content):
with _error_utils.RedactAnnotatedSourceContext.when(not show_content):
return from_yaml(data=data, file_name=file_name, json_only=json_only)
def load_from_file(self, file_name: str, cache: str = 'all', unsafe: bool = False, json_only: bool = False, trusted_as_template: bool = False) -> t.Any:

@ -11,7 +11,7 @@ import typing as t
import yaml
from ansible.errors import AnsibleJSONParserError
from ansible._internal._errors import _utils
from ansible._internal._errors import _error_utils
from ansible.parsing.vault import VaultSecret
from ansible.parsing.yaml.loader import AnsibleLoader
from ansible._internal._yaml._errors import AnsibleYAMLParserError
@ -34,7 +34,7 @@ def from_yaml(
data = origin.tag(data)
with _utils.RedactAnnotatedSourceContext.when(not show_content):
with _error_utils.RedactAnnotatedSourceContext.when(not show_content):
try:
# we first try to load this data as JSON.
# Fixes issues with extra vars json strings not being parsed correctly by the yaml parser

@ -20,11 +20,11 @@ from abc import ABC, abstractmethod
from collections.abc import Sequence
from ansible import constants as C
from ansible._internal._errors import _captured, _error_factory
from ansible._internal._errors import _captured, _error_utils
from ansible.errors import AnsibleError, AnsibleConnectionFailure, AnsibleActionSkip, AnsibleActionFail, AnsibleAuthenticationFailure
from ansible.executor.module_common import modify_module, _BuiltModule
from ansible.executor.interpreter_discovery import discover_interpreter, InterpreterDiscoveryRequiredError
from ansible.module_utils._internal import _traceback, _event_utils, _messages
from ansible.module_utils._internal import _traceback
from ansible.module_utils.common.arg_spec import ArgumentSpecValidator
from ansible.module_utils.errors import UnsupportedError
from ansible.module_utils.json_utils import _filter_non_json_lines
@ -1252,7 +1252,7 @@ class ActionBase(ABC, _AnsiblePluginInfoMixin):
except AnsibleError as ansible_ex:
sentinel = object()
data = self.result_dict_from_exception(ansible_ex)
data = _error_utils.result_dict_from_exception(ansible_ex)
data.update(
_ansible_parsed=False,
module_stdout=res.get('stdout', ''),
@ -1433,50 +1433,3 @@ class ActionBase(ABC, _AnsiblePluginInfoMixin):
# if missing it will return a file not found exception
return self._loader.path_dwim_relative_stack(path_stack, dirname, needle)
@staticmethod
def result_dict_from_exception(exception: BaseException) -> dict[str, t.Any]:
"""Return a failed task result dict from the given exception."""
if ansible_remoted_error := _captured.AnsibleResultCapturedError.find_first_remoted_error(exception):
result = ansible_remoted_error._result.copy()
else:
result = {}
event = _error_factory.ControllerEventFactory.from_exception(exception, _traceback.is_traceback_enabled(_traceback.TracebackEvent.ERROR))
result.update(
failed=True,
exception=_messages.ErrorSummary(
event=event,
),
)
if 'msg' not in result:
result.update(msg=_event_utils.format_event_brief_message(event))
return result
def _result_dict_from_captured_errors(
self,
msg: str,
*,
errors: list[_messages.ErrorSummary] | None = None,
) -> dict[str, t.Any]:
"""Return a failed task result dict from the given error message and captured errors."""
_skip_stackwalk = True
event = _messages.Event(
msg=msg,
formatted_traceback=_traceback.maybe_capture_traceback(msg, _traceback.TracebackEvent.ERROR),
events=tuple(error.event for error in errors) if errors else None,
)
result = dict(
failed=True,
exception=_messages.ErrorSummary(
event=event,
),
msg=_event_utils.format_event_brief_message(event),
)
return result

@ -77,7 +77,7 @@ class ActionModule(ActionBase):
elif isinstance(groups, string_types):
group_list = groups.split(",")
else:
raise AnsibleActionFail("Groups must be specified as a list.", obj=self._task)
raise AnsibleActionFail("Groups must be specified as a list.", obj=groups)
for group_name in group_list:
if group_name not in new_groups:

@ -25,8 +25,8 @@ import re
import tempfile
from ansible import constants as C
from ansible.errors import AnsibleError, AnsibleAction, _AnsibleActionDone, AnsibleActionFail
from ansible.module_utils.common.text.converters import to_native, to_text
from ansible.errors import AnsibleActionFail
from ansible.module_utils.common.text.converters import to_text
from ansible.module_utils.parsing.convert_bool import boolean
from ansible.plugins.action import ActionBase
from ansible.utils.hashing import checksum_s
@ -83,7 +83,7 @@ class ActionModule(ActionBase):
self._supports_check_mode = False
result = super(ActionModule, self).run(tmp, task_vars)
super(ActionModule, self).run(tmp, task_vars)
del tmp # tmp no longer has any effect
if task_vars is None:
@ -104,13 +104,9 @@ class ActionModule(ActionBase):
if boolean(remote_src, strict=False):
# call assemble via ansible.legacy to allow library/ overrides of the module without collection search
result.update(self._execute_module(module_name='ansible.legacy.assemble', task_vars=task_vars))
raise _AnsibleActionDone()
else:
try:
src = self._find_needle('files', src)
except AnsibleError as e:
raise AnsibleActionFail(to_native(e))
return self._execute_module(module_name='ansible.legacy.assemble', task_vars=task_vars)
src = self._find_needle('files', src)
if not os.path.isdir(src):
raise AnsibleActionFail(u"Source (%s) is not a directory" % src)
@ -153,13 +149,9 @@ class ActionModule(ActionBase):
res = self._execute_module(module_name='ansible.legacy.copy', module_args=new_module_args, task_vars=task_vars)
if diff:
res['diff'] = diff
result.update(res)
return res
else:
result.update(self._execute_module(module_name='ansible.legacy.file', module_args=new_module_args, task_vars=task_vars))
return self._execute_module(module_name='ansible.legacy.file', module_args=new_module_args, task_vars=task_vars)
except AnsibleAction as e:
result.update(e.result)
finally:
self._remove_tmp_path(self._connection._shell.tmpdir)
return result

@ -27,7 +27,7 @@ import tempfile
from ansible import constants as C
from ansible.errors import AnsibleError, AnsibleActionFail, AnsibleFileNotFound
from ansible.module_utils.basic import FILE_COMMON_ARGUMENTS
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
from ansible.module_utils.common.text.converters import to_bytes, to_text
from ansible.module_utils.parsing.convert_bool import boolean
from ansible.plugins.action import ActionBase
from ansible.utils.hashing import checksum
@ -409,6 +409,7 @@ class ActionModule(ActionBase):
task_vars = dict()
result = super(ActionModule, self).run(tmp, task_vars)
del tmp # tmp no longer has any effect
# ensure user is not setting internal parameters
@ -450,10 +451,10 @@ class ActionModule(ActionBase):
else:
content_tempfile = self._create_content_tempfile(content)
source = content_tempfile
except Exception as err:
result['failed'] = True
result['msg'] = "could not write content temp file: %s" % to_native(err)
return self._ensure_invocation(result)
except Exception as ex:
self._ensure_invocation(result)
raise AnsibleActionFail(message="could not write content temp file", result=result) from ex
# if we have first_available_file in our vars
# look up the files and use the first one we find as src
@ -470,9 +471,9 @@ class ActionModule(ActionBase):
# find in expected paths
source = self._find_needle('files', source)
except AnsibleError as ex:
result.update(self.result_dict_from_exception(ex))
self._ensure_invocation(result)
return self._ensure_invocation(result)
raise AnsibleActionFail(result=result) from ex
if trailing_slash != source.endswith(os.path.sep):
if source[-1] == os.path.sep:

@ -13,6 +13,7 @@ from ansible.executor.module_common import _apply_action_arg_defaults
from ansible.module_utils.parsing.convert_bool import boolean
from ansible.plugins.action import ActionBase
from ansible.utils.vars import merge_hash
from ansible._internal._errors import _error_utils
class ActionModule(ActionBase):
@ -184,7 +185,7 @@ class ActionModule(ActionBase):
if failed:
result['failed_modules'] = failed
result.update(self._result_dict_from_captured_errors(
result.update(_error_utils.result_dict_from_captured_errors(
msg=f"The following modules failed to execute: {', '.join(failed.keys())}.",
errors=[r['exception'] for r in failed.values()],
))

@ -16,7 +16,7 @@
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations
from ansible.errors import AnsibleAction, AnsibleActionFail
from ansible.errors import AnsibleActionFail
from ansible.executor.module_common import _apply_action_arg_defaults
from ansible.module_utils.facts.system.pkg_mgr import PKG_MGRS
from ansible.plugins.action import ActionBase
@ -38,7 +38,7 @@ class ActionModule(ActionBase):
self._supports_check_mode = True
self._supports_async = True
result = super(ActionModule, self).run(tmp, task_vars)
super(ActionModule, self).run(tmp, task_vars)
module = self._task.args.get('use', 'auto')
@ -99,11 +99,8 @@ class ActionModule(ActionBase):
module = 'ansible.legacy.' + module
display.vvvv("Running %s" % module)
result.update(self._execute_module(module_name=module, module_args=new_module_args, task_vars=task_vars, wrap_async=self._task.async_val))
return self._execute_module(module_name=module, module_args=new_module_args, task_vars=task_vars, wrap_async=self._task.async_val)
else:
raise AnsibleActionFail('Could not detect which package manager to use. Try gathering facts or setting the "use" option.')
except AnsibleAction as e:
result.update(e.result)
return result
finally:
pass # avoid de-dent all on refactor

@ -21,7 +21,7 @@ import pathlib
import re
import shlex
from ansible.errors import AnsibleError, AnsibleAction, _AnsibleActionDone, AnsibleActionFail, AnsibleActionSkip
from ansible.errors import AnsibleError, AnsibleActionFail, AnsibleActionSkip
from ansible.executor.powershell import module_manifest as ps_manifest
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
from ansible.plugins.action import ActionBase
@ -53,7 +53,7 @@ class ActionModule(ActionBase):
mutually_exclusive=[['_raw_params', 'cmd']],
)
result = super(ActionModule, self).run(tmp, task_vars)
super(ActionModule, self).run(tmp, task_vars)
del tmp # tmp no longer has any effect
try:
@ -105,16 +105,11 @@ class ActionModule(ActionBase):
# check mode is supported if 'creates' or 'removes' are provided
# the task has already been skipped if a change would not occur
if new_module_args['creates'] or new_module_args['removes']:
result['changed'] = True
raise _AnsibleActionDone(result=result)
return dict(changed=True)
# If the script doesn't return changed in the result, it defaults to True,
# but since the script may override 'changed', just skip instead of guessing.
else:
result['changed'] = False
raise AnsibleActionSkip('Check mode is not supported for this task.', result=result)
# now we execute script, always assume changed.
result['changed'] = True
raise AnsibleActionSkip('Check mode is not supported for this task.', result=dict(changed=False))
# transfer the file to a remote tmp location
tmp_src = self._connection._shell.join_path(self._connection._shell.tmpdir,
@ -168,14 +163,12 @@ class ActionModule(ActionBase):
# full manual exec of KEEP_REMOTE_FILES
script_cmd = self._connection._shell.build_module_command(env_string='', shebang='#!powershell', cmd='')
result.update(self._low_level_execute_command(cmd=script_cmd, in_data=exec_data, sudoable=True, chdir=chdir))
# now we execute script, always assume changed.
result = dict(self._low_level_execute_command(cmd=script_cmd, in_data=exec_data, sudoable=True, chdir=chdir), changed=True)
if 'rc' in result and result['rc'] != 0:
raise AnsibleActionFail('non-zero return code')
raise AnsibleActionFail('non-zero return code', result=result)
except AnsibleAction as e:
result.update(e.result)
return result
finally:
self._remove_tmp_path(self._connection._shell.tmpdir)
return result

@ -16,7 +16,7 @@
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations
from ansible.errors import AnsibleAction, AnsibleActionFail
from ansible.errors import AnsibleActionFail
from ansible.executor.module_common import _apply_action_arg_defaults
from ansible.plugins.action import ActionBase
@ -39,7 +39,7 @@ class ActionModule(ActionBase):
self._supports_check_mode = True
self._supports_async = True
result = super(ActionModule, self).run(tmp, task_vars)
super(ActionModule, self).run(tmp, task_vars)
del tmp # tmp no longer has any effect
module = self._task.args.get('use', 'auto').lower()
@ -84,14 +84,10 @@ class ActionModule(ActionBase):
module = 'ansible.legacy.' + module
self._display.vvvv("Running %s" % module)
result.update(self._execute_module(module_name=module, module_args=new_module_args, task_vars=task_vars, wrap_async=self._task.async_val))
return self._execute_module(module_name=module, module_args=new_module_args, task_vars=task_vars, wrap_async=self._task.async_val)
else:
raise AnsibleActionFail('Could not detect which service manager to use. Try gathering facts or setting the "use" option.')
except AnsibleAction as e:
result.update(e.result)
finally:
if not self._task.async_val:
self._remove_tmp_path(self._connection._shell.tmpdir)
return result

@ -20,7 +20,7 @@ from jinja2.defaults import (
from ansible import constants as C
from ansible.config.manager import ensure_type
from ansible.errors import AnsibleError, AnsibleAction, AnsibleActionFail
from ansible.errors import AnsibleError, AnsibleActionFail
from ansible.module_utils.common.text.converters import to_bytes, to_text, to_native
from ansible.module_utils.parsing.convert_bool import boolean
from ansible.module_utils.six import string_types
@ -39,7 +39,7 @@ class ActionModule(ActionBase):
if task_vars is None:
task_vars = dict()
result = super(ActionModule, self).run(tmp, task_vars)
super(ActionModule, self).run(tmp, task_vars)
del tmp # tmp no longer has any effect
# Options type validation
@ -167,13 +167,8 @@ class ActionModule(ActionBase):
loader=self._loader,
templar=self._templar,
shared_loader_obj=self._shared_loader_obj)
result.update(copy_action.run(task_vars=task_vars))
return copy_action.run(task_vars=task_vars)
finally:
shutil.rmtree(to_bytes(local_tempdir, errors='surrogate_or_strict'))
except AnsibleAction as e:
result.update(e.result)
finally:
self._remove_tmp_path(self._connection._shell.tmpdir)
return result

@ -19,8 +19,7 @@ from __future__ import annotations
import os
from ansible.errors import AnsibleError, AnsibleAction, AnsibleActionFail, AnsibleActionSkip
from ansible.module_utils.common.text.converters import to_text
from ansible.errors import AnsibleActionFail, AnsibleActionSkip
from ansible.module_utils.parsing.convert_bool import boolean
from ansible.plugins.action import ActionBase
@ -34,7 +33,7 @@ class ActionModule(ActionBase):
if task_vars is None:
task_vars = dict()
result = super(ActionModule, self).run(tmp, task_vars)
super(ActionModule, self).run(tmp, task_vars)
del tmp # tmp no longer has any effect
source = self._task.args.get('src', None)
@ -68,15 +67,9 @@ class ActionModule(ActionBase):
source = os.path.expanduser(source)
if not remote_src:
try:
source = self._loader.get_real_file(self._find_needle('files', source), decrypt=decrypt)
except AnsibleError as e:
raise AnsibleActionFail(to_text(e))
source = self._loader.get_real_file(self._find_needle('files', source), decrypt=decrypt)
try:
remote_stat = self._execute_remote_stat(dest, all_vars=task_vars, follow=True)
except AnsibleError as e:
raise AnsibleActionFail(to_text(e))
remote_stat = self._execute_remote_stat(dest, all_vars=task_vars, follow=True)
if not remote_stat['exists'] or not remote_stat['isdir']:
raise AnsibleActionFail("dest '%s' must be an existing dir" % dest)
@ -102,9 +95,6 @@ class ActionModule(ActionBase):
# execute the unarchive module now, with the updated args (using ansible.legacy prefix to eliminate collections
# collisions with local override
result.update(self._execute_module(module_name='ansible.legacy.unarchive', module_args=new_module_args, task_vars=task_vars))
except AnsibleAction as e:
result.update(e.result)
return self._execute_module(module_name='ansible.legacy.unarchive', module_args=new_module_args, task_vars=task_vars)
finally:
self._remove_tmp_path(self._connection._shell.tmpdir)
return result

@ -5,11 +5,10 @@
from __future__ import annotations
import collections.abc as _c
import os
from ansible.errors import AnsibleError, AnsibleAction, _AnsibleActionDone, AnsibleActionFail
from ansible.module_utils.common.text.converters import to_native
from ansible.module_utils.common.collections import Mapping, MutableMapping
from ansible.errors import AnsibleActionFail
from ansible.module_utils.parsing.convert_bool import boolean
from ansible.plugins.action import ActionBase
@ -25,7 +24,7 @@ class ActionModule(ActionBase):
if task_vars is None:
task_vars = dict()
result = super(ActionModule, self).run(tmp, task_vars)
super(ActionModule, self).run(tmp, task_vars)
del tmp # tmp no longer has any effect
body_format = self._task.args.get('body_format', 'raw')
@ -38,38 +37,31 @@ class ActionModule(ActionBase):
# everything is remote, so we just execute the module
# without changing any of the module arguments
# call with ansible.legacy prefix to prevent collections collisions while allowing local override
raise _AnsibleActionDone(result=self._execute_module(module_name='ansible.legacy.uri',
task_vars=task_vars, wrap_async=self._task.async_val))
return self._execute_module(module_name='ansible.legacy.uri', task_vars=task_vars, wrap_async=self._task.async_val)
kwargs = {}
if src:
try:
src = self._find_needle('files', src)
except AnsibleError as e:
raise AnsibleActionFail(to_native(e))
src = self._find_needle('files', src)
tmp_src = self._connection._shell.join_path(self._connection._shell.tmpdir, os.path.basename(src))
kwargs['src'] = tmp_src
self._transfer_file(src, tmp_src)
self._fixup_perms2((self._connection._shell.tmpdir, tmp_src))
elif body_format == 'form-multipart':
if not isinstance(body, Mapping):
if not isinstance(body, _c.Mapping):
raise AnsibleActionFail(
'body must be mapping, cannot be type %s' % body.__class__.__name__
)
for field, value in body.items():
if not isinstance(value, MutableMapping):
if not isinstance(value, _c.MutableMapping):
continue
content = value.get('content')
filename = value.get('filename')
if not filename or content:
continue
try:
filename = self._find_needle('files', filename)
except AnsibleError as e:
raise AnsibleActionFail(to_native(e))
filename = self._find_needle('files', filename)
tmp_src = self._connection._shell.join_path(
self._connection._shell.tmpdir,
@ -83,10 +75,7 @@ class ActionModule(ActionBase):
new_module_args = self._task.args | kwargs
# call with ansible.legacy prefix to prevent collections collisions while allowing local override
result.update(self._execute_module('ansible.legacy.uri', module_args=new_module_args, task_vars=task_vars, wrap_async=self._task.async_val))
except AnsibleAction as e:
result.update(e.result)
return self._execute_module('ansible.legacy.uri', module_args=new_module_args, task_vars=task_vars, wrap_async=self._task.async_val)
finally:
if not self._task.async_val:
self._remove_tmp_path(self._connection._shell.tmpdir)
return result

@ -51,7 +51,7 @@ from struct import unpack, pack
from ansible import constants as C
from ansible.constants import config
from ansible.errors import AnsibleAssertionError, AnsiblePromptInterrupt, AnsiblePromptNoninteractive, AnsibleError
from ansible._internal._errors import _utils, _error_factory
from ansible._internal._errors import _error_utils, _error_factory
from ansible._internal import _event_formatting
from ansible.module_utils._internal import _ambient_context, _deprecator, _messages
from ansible.module_utils.common.text.converters import to_bytes, to_text
@ -731,7 +731,7 @@ class Display(metaclass=Singleton):
raise AnsibleError(formatted_msg)
if source_context := _utils.SourceContext.from_value(obj):
if source_context := _error_utils.SourceContext.from_value(obj):
formatted_source_context = str(source_context)
else:
formatted_source_context = None
@ -791,7 +791,7 @@ class Display(metaclass=Singleton):
# This is the pre-proxy half of the `warning` implementation.
# Any logic that must occur on workers needs to be implemented here.
if source_context := _utils.SourceContext.from_value(obj):
if source_context := _error_utils.SourceContext.from_value(obj):
formatted_source_context = str(source_context)
else:
formatted_source_context = None
@ -877,15 +877,29 @@ class Display(metaclass=Singleton):
(out, err) = cmd.communicate()
self.display(u"%s\n" % to_text(out), color=color)
def error_as_warning(self, msg: str | None, exception: BaseException) -> None:
def error_as_warning(
self,
msg: str | None,
exception: BaseException,
*,
help_text: str | None = None,
obj: t.Any = None,
) -> None:
"""Display an exception as a warning."""
_skip_stackwalk = True
event = _error_factory.ControllerEventFactory.from_exception(exception, _traceback.is_traceback_enabled(_traceback.TracebackEvent.WARNING))
if msg:
if source_context := _error_utils.SourceContext.from_value(obj):
formatted_source_context = str(source_context)
else:
formatted_source_context = None
event = _messages.Event(
msg=msg,
help_text=help_text,
formatted_source_context=formatted_source_context,
formatted_traceback=_traceback.maybe_capture_traceback(msg, _traceback.TracebackEvent.WARNING),
chain=_messages.EventChain(
msg_reason=_errors.MSG_REASON_DIRECT_CAUSE,

@ -3,7 +3,7 @@
set -eux
# run type tests
ansible -a 'sleep 5' --task-timeout 1 localhost |grep 'The command action failed to execute in the expected time frame (1) and was terminated'
ansible -a 'sleep 5' --task-timeout 1 localhost |grep 'Timed out after'
# -a parsing with json
ansible --task-timeout 5 localhost -m command -a '{"cmd": "whoami"}' | grep 'rc=0'

@ -36,7 +36,7 @@
- assert:
that:
- incompatible.failed
- not incompatible.msg.startswith("The command action failed to execute in the expected time frame")
- not incompatible.msg is contains 'Timed out after'
- '"Failed to resolve the requested dependencies map" in incompatible.stderr'
- '"* namespace1.name1:1.0.9 (direct request)" in incompatible.stderr'
- '"* namespace1.name1:0.0.5 (dependency of ns.coll:1.0.0)" in incompatible.stderr'

@ -41,4 +41,4 @@
- assert:
that:
- timeout_cmd.msg == 'The win_shell action failed to execute in the expected time frame (5) and was terminated'
- timeout_cmd.msg is contains 'Timed out after'

@ -40,7 +40,7 @@
- assert:
that:
- timeout_cmd.msg == 'The win_shell action failed to execute in the expected time frame (5) and was terminated'
- timeout_cmd.msg is contains 'Timed out after'
- name: Test WinRM HTTP connection
win_ping:

@ -9,5 +9,5 @@
- assert:
that:
- time is failed
- '"The shell action failed to execute in the expected time frame" in time["msg"]'
- time.msg is contains 'Timed out after'
- '"timedout" in time'

@ -0,0 +1,2 @@
shippable/posix/group4
context/controller

@ -0,0 +1,61 @@
- name: run a task which times out
command: sleep 10
timeout: 1
register: result
ignore_errors: yes
- name: verify the task timed out
assert:
that:
- result is failed
- result is timedout
- result.timedout.period == 1
- result.msg is contains "Timed out after 1 second"
- name: run a task with a negative timeout
command: sleep 3
timeout: -1
register: result
ignore_errors: yes
- name: verify the task failed
assert:
that:
- result is failed
- result is not timedout
- result.msg is contains "Timeout -1 is invalid"
- name: run a task with a timeout that is too large
command: sleep 3
timeout: 100000001
register: result
ignore_errors: yes
- name: verify the task failed
assert:
that:
- result is failed
- result is not timedout
- result.msg is contains "Timeout 100000001 is invalid"
- name: run a task with a zero timeout
command: sleep 3
timeout: 0
register: result
- name: verify the task did not time out
assert:
that:
- result is not timedout
- result.delta is search '^0:00:0[3-9]\.' # delta must be between 3 and 9 seconds
- name: run a task with a large timeout that is not triggered
command: sleep 3
timeout: 100000000
register: result
- name: verify the task did not time out
assert:
that:
- result is not timedout
- result.delta is search '^0:00:0[3-9]\.' # delta must be between 3 and 9 seconds

@ -0,0 +1,123 @@
from __future__ import annotations
import contextlib
import signal
import time
import typing as t
import pytest
from ansible._internal._errors import _alarm_timeout
from ansible._internal._errors._alarm_timeout import AnsibleTimeoutError
pytestmark = pytest.mark.usefixtures("assert_sigalrm_state")
@pytest.fixture
def assert_sigalrm_state() -> t.Iterator[None]:
"""Fixture to ensure that SIGALRM state is as-expected before and after each test."""
assert signal.alarm(0) == 0 # disable alarm before resetting the default handler
assert signal.signal(signal.SIGALRM, signal.SIG_DFL) == signal.SIG_DFL
try:
yield
finally:
assert signal.alarm(0) == 0
assert signal.signal(signal.SIGALRM, signal.SIG_DFL) == signal.SIG_DFL
@pytest.mark.parametrize("timeout", (0, 1, None))
def test_alarm_timeout_success(timeout: int | None) -> None:
"""Validate a non-timeout success scenario."""
ran = False
with _alarm_timeout.AnsibleTimeoutError.alarm_timeout(timeout):
time.sleep(0.01)
ran = True
assert ran
def test_alarm_timeout_timeout() -> None:
"""Validate a happy-path timeout scenario."""
ran = False
timeout_sec = 1
with pytest.raises(AnsibleTimeoutError) as error:
with _alarm_timeout.AnsibleTimeoutError.alarm_timeout(timeout_sec):
time.sleep(timeout_sec + 1)
ran = True # pragma: nocover
assert not ran
assert error.value.timeout == timeout_sec
@pytest.mark.parametrize("timeout,expected_error_type,expected_error_pattern", (
(-1, ValueError, "Timeout.*invalid.*between"),
(100_000_001, ValueError, "Timeout.*invalid.*between"),
(0.1, TypeError, "requires 'int' argument.*'float'"),
("1", TypeError, "requires 'int' argument.*'str'"),
))
def test_alarm_timeout_bad_values(timeout: t.Any, expected_error_type: type[Exception], expected_error_pattern: str) -> None:
"""Validate behavior for invalid inputs."""
ran = False
with pytest.raises(expected_error_type, match=expected_error_pattern):
with _alarm_timeout.AnsibleTimeoutError.alarm_timeout(timeout):
ran = True # pragma: nocover
assert not ran
def test_alarm_timeout_bad_state() -> None:
"""Validate alarm state error handling."""
def call_it():
ran = False
with pytest.raises(RuntimeError, match="existing alarm"):
with _alarm_timeout.AnsibleTimeoutError.alarm_timeout(1):
ran = True # pragma: nocover
assert not ran
try:
# non-default SIGALRM handler present
signal.signal(signal.SIGALRM, lambda _s, _f: None)
call_it()
finally:
signal.signal(signal.SIGALRM, signal.SIG_DFL)
try:
# alarm already set
signal.alarm(10000)
call_it()
finally:
signal.signal(signal.SIGALRM, signal.SIG_DFL)
ran_outer = ran_inner = False
# nested alarm_timeouts
with pytest.raises(RuntimeError, match="existing alarm"):
with _alarm_timeout.AnsibleTimeoutError.alarm_timeout(1):
ran_outer = True
with _alarm_timeout.AnsibleTimeoutError.alarm_timeout(1):
ran_inner = True # pragma: nocover
assert not ran_inner
assert ran_outer
def test_alarm_timeout_raise():
"""Ensure that an exception raised in the wrapped scope propagates correctly."""
with pytest.raises(NotImplementedError):
with _alarm_timeout.AnsibleTimeoutError.alarm_timeout(1):
raise NotImplementedError()
def test_alarm_timeout_escape_broad_exception():
"""Ensure that the timeout exception can escape a broad exception handler in the wrapped scope."""
with pytest.raises(AnsibleTimeoutError):
with _alarm_timeout.AnsibleTimeoutError.alarm_timeout(1):
with contextlib.suppress(Exception):
time.sleep(3)

@ -0,0 +1,64 @@
from __future__ import annotations
import collections.abc as c
import typing as t
import pytest
from ansible._internal._errors import _error_utils
from ansible.module_utils._internal import _messages
from units.mock.error_helper import raise_exceptions
class _TestContributesError(Exception, _error_utils.ContributesToTaskResult):
@property
def result_contribution(self) -> c.Mapping[str, object]:
return dict(some_flag=True)
class _TestContributesUnreachable(Exception, _error_utils.ContributesToTaskResult):
@property
def omit_failed_key(self) -> bool:
return True
@property
def result_contribution(self) -> c.Mapping[str, object]:
return dict(unreachable=True)
class _TestContributesMsg(Exception, _error_utils.ContributesToTaskResult):
@property
def result_contribution(self) -> c.Mapping[str, object]:
return dict(msg="contributed msg")
@pytest.mark.parametrize("exceptions,expected", (
(
(Exception("e0"), _TestContributesError("e1"), ValueError("e2")),
dict(failed=True, some_flag=True, msg="e0: e1: e2"),
),
(
(Exception("e0"), ValueError("e1"), _TestContributesError("e2")),
dict(failed=True, some_flag=True, msg="e0: e1: e2"),
),
(
(Exception("e0"), _TestContributesUnreachable("e1")),
dict(unreachable=True, msg="e0: e1"),
),
(
(Exception("e0"), _TestContributesMsg()),
dict(failed=True, msg="contributed msg"),
),
))
def test_exception_result_contribution(exceptions: t.Sequence[BaseException], expected: dict[str, t.Any]) -> None:
"""Validate result dict augmentation by exceptions conforming to the ContributeToTaskResult protocol."""
with pytest.raises(Exception) as error:
raise_exceptions(exceptions)
result = _error_utils.result_dict_from_exception(error.value, accept_result_contribution=True)
summary = result.pop('exception')
assert isinstance(summary, _messages.ErrorSummary)
assert result == expected

@ -0,0 +1,27 @@
from __future__ import annotations
from ansible._internal._errors._task_timeout import TaskTimeoutError
from ansible.module_utils._internal._datatag._tags import Deprecated
def test_task_timeout_result_contribution() -> None:
"""Validate the result contribution shape."""
try:
raise TaskTimeoutError(99)
except TaskTimeoutError as tte:
contrib = tte.result_contribution
assert isinstance(contrib, dict)
timedout = contrib.get('timedout')
assert isinstance(timedout, dict)
frame = timedout.get('frame')
assert isinstance(frame, str)
assert Deprecated.is_tagged_on(frame)
period = timedout.get('period')
assert period == 99

@ -4,26 +4,23 @@ import traceback
from ansible._internal._errors import _error_factory
from ansible._internal._event_formatting import format_event_traceback
from units.mock.error_helper import raise_exceptions
import pytest
def test_traceback_formatting() -> None:
"""Verify our traceback formatting mimics the Python traceback formatting."""
try:
try:
try:
try:
raise Exception('one')
except Exception as ex:
raise Exception('two') from ex
except Exception:
raise Exception('three')
except Exception as ex:
raise Exception('four') from ex
except Exception as ex:
saved_ex = ex
with pytest.raises(Exception) as error:
raise_exceptions((
Exception('a'),
Exception('b'),
Exception('c'),
Exception('d'),
))
event = _error_factory.ControllerEventFactory.from_exception(saved_ex, True) # pylint: disable=used-before-assignment
event = _error_factory.ControllerEventFactory.from_exception(error.value, True)
ansible_tb = format_event_traceback(event)
python_tb = ''.join(traceback.format_exception(saved_ex))
python_tb = ''.join(traceback.format_exception(error.value))
assert ansible_tb == python_tb

@ -5,7 +5,7 @@ import pathlib
import pytest
from ansible.errors import AnsibleError, AnsibleVariableTypeError
from ansible._internal._errors._utils import SourceContext
from ansible._internal._errors._error_utils import SourceContext
from ansible._internal._datatag._tags import Origin
from ..test_utils.controller.display import emits_warnings

@ -6,19 +6,10 @@ from ansible._internal._errors import _error_factory
from ansible.errors import AnsibleError
from ansible._internal._datatag._tags import Origin
from ansible._internal._errors._utils import format_exception_message
from ansible._internal._errors._error_utils import format_exception_message
from ansible.utils.display import _format_message
from ansible.module_utils._internal import _messages
def raise_exceptions(exceptions: list[BaseException]) -> None:
if len(exceptions) > 1:
try:
raise_exceptions(exceptions[1:])
except Exception as ex:
raise exceptions[0] from ex
raise exceptions[0]
from units.mock.error_helper import raise_exceptions
_shared_cause = Exception('shared cause')

@ -0,0 +1,17 @@
from __future__ import annotations
import collections.abc as c
def raise_exceptions(exceptions: c.Sequence[BaseException]) -> None:
"""
Raise a chain of exceptions from the given exception list.
Exceptions will be raised starting from the end of the list.
"""
if len(exceptions) > 1:
try:
raise_exceptions(exceptions[1:])
except Exception as ex:
raise exceptions[0] from ex
raise exceptions[0]

@ -6,7 +6,7 @@ import tempfile
import pytest
from ansible.errors import AnsibleJSONParserError
from ansible._internal._errors._utils import format_exception_message
from ansible._internal._errors._error_utils import format_exception_message
from ansible._internal._datatag._tags import Origin
from ansible.parsing.utils.yaml import from_yaml

@ -11,7 +11,7 @@ import pytest
import pytest_mock
from ansible import constants as C
from ansible._internal._errors._utils import format_exception_message
from ansible._internal._errors._error_utils import format_exception_message
from ansible._internal._datatag._tags import Origin
from ansible.parsing.utils.yaml import from_yaml
from ansible._internal._yaml._errors import AnsibleYAMLParserError

@ -0,0 +1,21 @@
from __future__ import annotations
import pytest
from ansible import errors
from units.test_utils.controller.display import emits_warnings
@pytest.mark.parametrize("name", (
"AnsibleFilterTypeError",
"_AnsibleActionDone",
))
def test_deprecated(name: str) -> None:
with emits_warnings(deprecation_pattern='is deprecated'):
getattr(errors, name)
def test_deprecated_attribute_error() -> None:
with pytest.raises(AttributeError):
getattr(errors, 'bogus')
Loading…
Cancel
Save