Fix @contextmanager leak on exception. (#21031)

* Fix @contextmanager leak on exception.
* Fix test leaks of global module args cache.
pull/20743/head
Matt Clay 8 years ago committed by GitHub
parent bb9ee0cf6f
commit 272ff10fa1

@ -36,18 +36,22 @@ def swap_stdin_and_argv(stdin_data='', argv_data=tuple()):
context manager that temporarily masks the test runner's values for stdin and argv context manager that temporarily masks the test runner's values for stdin and argv
""" """
real_stdin = sys.stdin real_stdin = sys.stdin
real_argv = sys.argv
if PY3: if PY3:
sys.stdin = StringIO(stdin_data) fake_stream = StringIO(stdin_data)
sys.stdin.buffer = BytesIO(to_bytes(stdin_data)) fake_stream.buffer = BytesIO(to_bytes(stdin_data))
else: else:
sys.stdin = BytesIO(to_bytes(stdin_data)) fake_stream = BytesIO(to_bytes(stdin_data))
real_argv = sys.argv try:
sys.argv = argv_data sys.stdin = fake_stream
yield sys.argv = argv_data
sys.stdin = real_stdin
sys.argv = real_argv yield
finally:
sys.stdin = real_stdin
sys.argv = real_argv
@contextmanager @contextmanager
@ -56,13 +60,18 @@ def swap_stdout():
context manager that temporarily replaces stdout for tests that need to verify output context manager that temporarily replaces stdout for tests that need to verify output
""" """
old_stdout = sys.stdout old_stdout = sys.stdout
if PY3: if PY3:
fake_stream = StringIO() fake_stream = StringIO()
else: else:
fake_stream = BytesIO() fake_stream = BytesIO()
sys.stdout = fake_stream
yield fake_stream try:
sys.stdout = old_stdout sys.stdout = fake_stream
yield fake_stream
finally:
sys.stdout = old_stdout
class ModuleTestCase(unittest.TestCase): class ModuleTestCase(unittest.TestCase):

@ -40,6 +40,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
from ansible.module_utils import basic from ansible.module_utils import basic
# test basic log invocation # test basic log invocation
basic._ANSIBLE_ARGS = None
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec=dict( argument_spec=dict(
foo = dict(default=True, type='bool'), foo = dict(default=True, type='bool'),

@ -311,6 +311,7 @@ class TestModuleUtilsBasic(ModuleTestCase):
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"foo":"hello", "bar": "bad", "bam": "bad"})) args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"foo":"hello", "bar": "bad", "bam": "bad"}))
with swap_stdin_and_argv(stdin_data=args): with swap_stdin_and_argv(stdin_data=args):
basic._ANSIBLE_ARGS = None
self.assertRaises( self.assertRaises(
SystemExit, SystemExit,
basic.AnsibleModule, basic.AnsibleModule,
@ -327,6 +328,7 @@ class TestModuleUtilsBasic(ModuleTestCase):
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"bam": "bad"})) args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"bam": "bad"}))
with swap_stdin_and_argv(stdin_data=args): with swap_stdin_and_argv(stdin_data=args):
basic._ANSIBLE_ARGS = None
self.assertRaises( self.assertRaises(
SystemExit, SystemExit,
basic.AnsibleModule, basic.AnsibleModule,
@ -580,12 +582,11 @@ class TestModuleUtilsBasic(ModuleTestCase):
def test_module_utils_basic_ansible_module_is_special_selinux_path(self): def test_module_utils_basic_ansible_module_is_special_selinux_path(self):
from ansible.module_utils import basic from ansible.module_utils import basic
basic._ANSIBLE_ARGS = None
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={'_ansible_selinux_special_fs': "nfs,nfsd,foos"})) args = json.dumps(dict(ANSIBLE_MODULE_ARGS={'_ansible_selinux_special_fs': "nfs,nfsd,foos"}))
with swap_stdin_and_argv(stdin_data=args): with swap_stdin_and_argv(stdin_data=args):
basic._ANSIBLE_ARGS = None
am = basic.AnsibleModule( am = basic.AnsibleModule(
argument_spec = dict(), argument_spec = dict(),
) )

@ -721,6 +721,7 @@ def test_distribution_version():
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={})) args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}))
with swap_stdin_and_argv(stdin_data=args): with swap_stdin_and_argv(stdin_data=args):
basic._ANSIBLE_ARGS = None
module = basic.AnsibleModule(argument_spec=dict()) module = basic.AnsibleModule(argument_spec=dict())
for t in TESTSETS: for t in TESTSETS:

Loading…
Cancel
Save