diff --git a/lib/ansible/module_utils/facts/timeout.py b/lib/ansible/module_utils/facts/timeout.py index 2927b31c822..934e7aff683 100644 --- a/lib/ansible/module_utils/facts/timeout.py +++ b/lib/ansible/module_utils/facts/timeout.py @@ -16,7 +16,8 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type -import signal +import multiprocessing +import multiprocessing.pool as mp # timeout function to make sure some fact gathering # steps do not exceed a time limit @@ -30,24 +31,25 @@ class TimeoutError(Exception): def timeout(seconds=None, error_message="Timer expired"): - + """ + Timeout decorator to expire after a set number of seconds. This raises an + ansible.module_utils.facts.TimeoutError if the timeout is hit before the + function completes. + """ def decorator(func): - def _handle_timeout(signum, frame): - msg = 'Timer expired after %s seconds' % globals().get('GATHER_TIMEOUT') - raise TimeoutError(msg) - def wrapper(*args, **kwargs): - local_seconds = seconds - if local_seconds is None: - local_seconds = globals().get('GATHER_TIMEOUT') or DEFAULT_GATHER_TIMEOUT - signal.signal(signal.SIGALRM, _handle_timeout) - signal.alarm(local_seconds) + timeout_value = seconds + if timeout_value is None: + timeout_value = globals().get('GATHER_TIMEOUT') or DEFAULT_GATHER_TIMEOUT + pool = mp.ThreadPool(processes=1) + res = pool.apply_async(func, args, kwargs) + pool.close() try: - result = func(*args, **kwargs) - finally: - signal.alarm(0) - return result + return res.get(timeout_value) + except multiprocessing.TimeoutError: + # This is an ansible.module_utils.common.facts.timeout.TimeoutError + raise TimeoutError('Timer expired after %s seconds' % timeout_value) return wrapper diff --git a/test/units/module_utils/facts/test_timeout.py b/test/units/module_utils/facts/test_timeout.py index bd18655f29d..f54fcf141a2 100644 --- a/test/units/module_utils/facts/test_timeout.py +++ b/test/units/module_utils/facts/test_timeout.py @@ -20,13 +20,11 @@ from __future__ import (absolute_import, division) __metaclass__ = type +import sys import time import pytest -from ansible.compat.tests import unittest -from ansible.compat.tests.mock import patch, MagicMock - from ansible.module_utils.facts import timeout @@ -67,6 +65,10 @@ def sleep_amount_explicit_lower(amount): return 'Succeeded after {0} sec'.format(amount) +# +# Tests for how the timeout decorator is specified +# + def test_defaults_still_within_bounds(): # If the default changes outside of these bounds, some of the tests will # no longer test the right thing. Need to review and update the timeouts @@ -110,3 +112,58 @@ def test_explicit_timeout(): sleep_time = 3 with pytest.raises(timeout.TimeoutError): assert sleep_amount_explicit_lower(sleep_time) == '(Not expected to succeed)' + + +# +# Test that exception handling works +# + +@timeout.timeout(1) +def function_times_out(): + time.sleep(2) + + +# This is just about the same test as function_times_out but uses a separate process which is where +# we normally have our timeouts. It's more of an integration test than a unit test. +@timeout.timeout(1) +def function_times_out_in_run_command(am): + am.run_command([sys.executable, '-c', 'import time ; time.sleep(2)']) + + +@timeout.timeout(1) +def function_other_timeout(): + raise TimeoutError('Vanilla Timeout') + + +@timeout.timeout(1) +def function_raises(): + 1 / 0 + + +@timeout.timeout(1) +def function_catches_all_exceptions(): + try: + time.sleep(10) + except BaseException: + raise RuntimeError('We should not have gotten here') + + +def test_timeout_raises_timeout(): + with pytest.raises(timeout.TimeoutError): + assert function_times_out() == '(Not expected to succeed)' + + +@pytest.mark.parametrize('stdin', ({},), indirect=['stdin']) +def test_timeout_raises_timeout_integration_test(am): + with pytest.raises(timeout.TimeoutError): + assert function_times_out_in_run_command(am) == '(Not expected to succeed)' + + +def test_timeout_raises_other_exception(): + with pytest.raises(ZeroDivisionError): + assert function_raises() == '(Not expected to succeed)' + + +def test_exception_not_caught_by_called_code(): + with pytest.raises(timeout.TimeoutError): + assert function_catches_all_exceptions() == '(Not expected to succeed)'