diff --git a/lib/ansible/module_utils/basic.py b/lib/ansible/module_utils/basic.py index c11dca8795a..43a79dd3d35 100644 --- a/lib/ansible/module_utils/basic.py +++ b/lib/ansible/module_utils/basic.py @@ -2605,7 +2605,7 @@ class AnsibleModule(object): def run_command(self, args, check_rc=False, close_fds=True, executable=None, data=None, binary_data=False, path_prefix=None, cwd=None, use_unsafe_shell=False, prompt_regex=None, environ_update=None, umask=None, encoding='utf-8', errors='surrogate_or_strict', - expand_user_and_vars=True, pass_fds=None, before_communicate_callback=None, raise_timeouts=False): + expand_user_and_vars=True, pass_fds=None, before_communicate_callback=None): ''' Execute a command, returns rc, stdout, and stderr. @@ -2655,9 +2655,6 @@ class AnsibleModule(object): after ``Popen`` object will be created but before communicating to the process. (``Popen`` object will be passed to callback as a first argument) - :kw raise_timeouts: This is a boolean, which when True, will allow the - caller to deal with timeout exceptions. When false we use the previous - behaviour of having run_command directly call fail_json when they occur. :returns: A 3-tuple of return code (integer), stdout (native string), and stderr (native string). On python2, stdout and stderr are both byte strings. On python3, stdout and stderr are text strings converted @@ -2831,12 +2828,6 @@ class AnsibleModule(object): cmd.stderr.close() rc = cmd.returncode - except TimeoutError as e: - self.log("Timeout Executing CMD:%s Timeout :%s" % (self._clean_args(args), to_native(e))) - if raise_timeouts: - raise e - else: - self.fail_json(rc=e.errno, msg=to_native(e), cmd=self._clean_args(args)) except (OSError, IOError) as e: self.log("Error Executing CMD:%s Exception:%s" % (self._clean_args(args), to_native(e))) self.fail_json(rc=e.errno, msg=to_native(e), cmd=self._clean_args(args)) diff --git a/lib/ansible/module_utils/facts/timeout.py b/lib/ansible/module_utils/facts/timeout.py index 2927b31c822..934e7aff683 100644 --- a/lib/ansible/module_utils/facts/timeout.py +++ b/lib/ansible/module_utils/facts/timeout.py @@ -16,7 +16,8 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type -import signal +import multiprocessing +import multiprocessing.pool as mp # timeout function to make sure some fact gathering # steps do not exceed a time limit @@ -30,24 +31,25 @@ class TimeoutError(Exception): def timeout(seconds=None, error_message="Timer expired"): - + """ + Timeout decorator to expire after a set number of seconds. This raises an + ansible.module_utils.facts.TimeoutError if the timeout is hit before the + function completes. + """ def decorator(func): - def _handle_timeout(signum, frame): - msg = 'Timer expired after %s seconds' % globals().get('GATHER_TIMEOUT') - raise TimeoutError(msg) - def wrapper(*args, **kwargs): - local_seconds = seconds - if local_seconds is None: - local_seconds = globals().get('GATHER_TIMEOUT') or DEFAULT_GATHER_TIMEOUT - signal.signal(signal.SIGALRM, _handle_timeout) - signal.alarm(local_seconds) + timeout_value = seconds + if timeout_value is None: + timeout_value = globals().get('GATHER_TIMEOUT') or DEFAULT_GATHER_TIMEOUT + pool = mp.ThreadPool(processes=1) + res = pool.apply_async(func, args, kwargs) + pool.close() try: - result = func(*args, **kwargs) - finally: - signal.alarm(0) - return result + return res.get(timeout_value) + except multiprocessing.TimeoutError: + # This is an ansible.module_utils.common.facts.timeout.TimeoutError + raise TimeoutError('Timer expired after %s seconds' % timeout_value) return wrapper diff --git a/test/units/module_utils/facts/test_timeout.py b/test/units/module_utils/facts/test_timeout.py index 36adbfabd19..f54fcf141a2 100644 --- a/test/units/module_utils/facts/test_timeout.py +++ b/test/units/module_utils/facts/test_timeout.py @@ -20,13 +20,11 @@ from __future__ import (absolute_import, division) __metaclass__ = type +import sys import time import pytest -from units.compat import unittest -from units.compat.mock import patch, MagicMock - from ansible.module_utils.facts import timeout @@ -67,6 +65,10 @@ def sleep_amount_explicit_lower(amount): return 'Succeeded after {0} sec'.format(amount) +# +# Tests for how the timeout decorator is specified +# + def test_defaults_still_within_bounds(): # If the default changes outside of these bounds, some of the tests will # no longer test the right thing. Need to review and update the timeouts @@ -110,3 +112,58 @@ def test_explicit_timeout(): sleep_time = 3 with pytest.raises(timeout.TimeoutError): assert sleep_amount_explicit_lower(sleep_time) == '(Not expected to succeed)' + + +# +# Test that exception handling works +# + +@timeout.timeout(1) +def function_times_out(): + time.sleep(2) + + +# This is just about the same test as function_times_out but uses a separate process which is where +# we normally have our timeouts. It's more of an integration test than a unit test. +@timeout.timeout(1) +def function_times_out_in_run_command(am): + am.run_command([sys.executable, '-c', 'import time ; time.sleep(2)']) + + +@timeout.timeout(1) +def function_other_timeout(): + raise TimeoutError('Vanilla Timeout') + + +@timeout.timeout(1) +def function_raises(): + 1 / 0 + + +@timeout.timeout(1) +def function_catches_all_exceptions(): + try: + time.sleep(10) + except BaseException: + raise RuntimeError('We should not have gotten here') + + +def test_timeout_raises_timeout(): + with pytest.raises(timeout.TimeoutError): + assert function_times_out() == '(Not expected to succeed)' + + +@pytest.mark.parametrize('stdin', ({},), indirect=['stdin']) +def test_timeout_raises_timeout_integration_test(am): + with pytest.raises(timeout.TimeoutError): + assert function_times_out_in_run_command(am) == '(Not expected to succeed)' + + +def test_timeout_raises_other_exception(): + with pytest.raises(ZeroDivisionError): + assert function_raises() == '(Not expected to succeed)' + + +def test_exception_not_caught_by_called_code(): + with pytest.raises(timeout.TimeoutError): + assert function_catches_all_exceptions() == '(Not expected to succeed)'