diff --git a/test/units/mock/procenv.py b/test/units/mock/procenv.py index e9d470c0796..6cf69a7acc5 100644 --- a/test/units/mock/procenv.py +++ b/test/units/mock/procenv.py @@ -36,18 +36,22 @@ def swap_stdin_and_argv(stdin_data='', argv_data=tuple()): context manager that temporarily masks the test runner's values for stdin and argv """ real_stdin = sys.stdin + real_argv = sys.argv if PY3: - sys.stdin = StringIO(stdin_data) - sys.stdin.buffer = BytesIO(to_bytes(stdin_data)) + fake_stream = StringIO(stdin_data) + fake_stream.buffer = BytesIO(to_bytes(stdin_data)) else: - sys.stdin = BytesIO(to_bytes(stdin_data)) + fake_stream = BytesIO(to_bytes(stdin_data)) - real_argv = sys.argv - sys.argv = argv_data - yield - sys.stdin = real_stdin - sys.argv = real_argv + try: + sys.stdin = fake_stream + sys.argv = argv_data + + yield + finally: + sys.stdin = real_stdin + sys.argv = real_argv @contextmanager @@ -56,13 +60,18 @@ def swap_stdout(): context manager that temporarily replaces stdout for tests that need to verify output """ old_stdout = sys.stdout + if PY3: fake_stream = StringIO() else: fake_stream = BytesIO() - sys.stdout = fake_stream - yield fake_stream - sys.stdout = old_stdout + + try: + sys.stdout = fake_stream + + yield fake_stream + finally: + sys.stdout = old_stdout class ModuleTestCase(unittest.TestCase): diff --git a/test/units/module_utils/basic/test__log_invocation.py b/test/units/module_utils/basic/test__log_invocation.py index d4510c5efc5..3723697bedc 100644 --- a/test/units/module_utils/basic/test__log_invocation.py +++ b/test/units/module_utils/basic/test__log_invocation.py @@ -40,6 +40,7 @@ class TestModuleUtilsBasic(unittest.TestCase): from ansible.module_utils import basic # test basic log invocation + basic._ANSIBLE_ARGS = None am = basic.AnsibleModule( argument_spec=dict( foo = dict(default=True, type='bool'), diff --git a/test/units/module_utils/test_basic.py b/test/units/module_utils/test_basic.py index b8551cb722a..2d3e6717e08 100644 --- a/test/units/module_utils/test_basic.py +++ b/test/units/module_utils/test_basic.py @@ -311,6 +311,7 @@ class TestModuleUtilsBasic(ModuleTestCase): args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"foo":"hello", "bar": "bad", "bam": "bad"})) with swap_stdin_and_argv(stdin_data=args): + basic._ANSIBLE_ARGS = None self.assertRaises( SystemExit, basic.AnsibleModule, @@ -327,6 +328,7 @@ class TestModuleUtilsBasic(ModuleTestCase): args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"bam": "bad"})) with swap_stdin_and_argv(stdin_data=args): + basic._ANSIBLE_ARGS = None self.assertRaises( SystemExit, basic.AnsibleModule, @@ -580,12 +582,11 @@ class TestModuleUtilsBasic(ModuleTestCase): def test_module_utils_basic_ansible_module_is_special_selinux_path(self): from ansible.module_utils import basic - basic._ANSIBLE_ARGS = None args = json.dumps(dict(ANSIBLE_MODULE_ARGS={'_ansible_selinux_special_fs': "nfs,nfsd,foos"})) with swap_stdin_and_argv(stdin_data=args): - + basic._ANSIBLE_ARGS = None am = basic.AnsibleModule( argument_spec = dict(), ) diff --git a/test/units/module_utils/test_distribution_version.py b/test/units/module_utils/test_distribution_version.py index 388c1d5ca3c..75304514f17 100644 --- a/test/units/module_utils/test_distribution_version.py +++ b/test/units/module_utils/test_distribution_version.py @@ -721,6 +721,7 @@ def test_distribution_version(): args = json.dumps(dict(ANSIBLE_MODULE_ARGS={})) with swap_stdin_and_argv(stdin_data=args): + basic._ANSIBLE_ARGS = None module = basic.AnsibleModule(argument_spec=dict()) for t in TESTSETS: