diff --git a/lib/ansible/executor/module_common.py b/lib/ansible/executor/module_common.py index fb9eb04d4cb..5d85bbec284 100644 --- a/lib/ansible/executor/module_common.py +++ b/lib/ansible/executor/module_common.py @@ -142,6 +142,7 @@ def debug(command, zipped_mod, json_params): # Okay to use __file__ here because we're running from a kept file basedir = os.path.abspath(os.path.dirname(__file__)) + args_path = os.path.join(basedir, 'args') if command == 'explode': # transform the ZIPDATA into an exploded directory of code and then # print the path to the code. This is an easy way for people to look @@ -163,6 +164,11 @@ def debug(command, zipped_mod, json_params): f.write(z.read(filename)) f.close() + # write the args file + f = open(args_path, 'w') + f.write(json_params) + f.close() + print('Module expanded into:') print('%%s' %% os.path.join(basedir, 'ansible')) exitcode = 0 @@ -171,7 +177,29 @@ def debug(command, zipped_mod, json_params): # Execute the exploded code instead of executing the module from the # embedded ZIPDATA. This allows people to easily run their modified # code on the remote machine to see how changes will affect it. - exitcode = invoke_module(os.path.join(basedir, 'ansible_module_%(ansible_module)s.py'), basedir, json_params) + # This differs slightly from default Ansible execution of Python modules + # as it passes the arguments to the module via a file instead of stdin. + + pythonpath = os.environ.get('PYTHONPATH') + if pythonpath: + os.environ['PYTHONPATH'] = ':'.join((basedir, pythonpath)) + else: + os.environ['PYTHONPATH'] = basedir + + p = subprocess.Popen(['%(interpreter)s', 'ansible_module_%(ansible_module)s.py', args_path], env=os.environ, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE) + (stdout, stderr) = p.communicate() + + if not isinstance(stderr, (bytes, unicode)): + stderr = stderr.read() + if not isinstance(stdout, (bytes, unicode)): + stdout = stdout.read() + if PY3: + sys.stderr.buffer.write(stderr) + sys.stdout.buffer.write(stdout) + else: + sys.stderr.write(stderr) + sys.stdout.write(stdout) + return p.returncode elif command == 'excommunicate': # This attempts to run the module in-process (by importing a main @@ -182,8 +210,9 @@ def debug(command, zipped_mod, json_params): # when using this that are only artifacts of how we're invoking here, # not actual bugs (as they don't affect the real way that we invoke # ansible modules) - sys.stdin = IOStream(json_params) - sys.path.insert(0, basedir) + + # stub the + sys.argv = ['%(ansible_module)s', args_path] from ansible_module_%(ansible_module)s import main main() print('WARNING: Module returned to wrapper instead of exiting') diff --git a/lib/ansible/module_utils/basic.py b/lib/ansible/module_utils/basic.py index 983fcb7ec6d..60f38708748 100644 --- a/lib/ansible/module_utils/basic.py +++ b/lib/ansible/module_utils/basic.py @@ -1435,11 +1435,23 @@ class AnsibleModule(object): def _load_params(self): ''' read the input and set the params attribute. Sets the constants as well.''' + # debug overrides to read args from file or cmdline + # Avoid tracebacks when locale is non-utf8 - if sys.version_info < (3,): - buffer = sys.stdin.read() + if len(sys.argv) > 1: + if os.path.isfile(sys.argv[1]): + fd = open(sys.argv[1], 'rb') + buffer = fd.read() + fd.close() + else: + buffer = sys.argv[1] + # default case, read from stdin else: - buffer = sys.stdin.buffer.read() + if sys.version_info < (3,): + buffer = sys.stdin.read() + else: + buffer = sys.stdin.buffer.read() + try: params = json.loads(buffer.decode('utf-8')) except ValueError: diff --git a/test/units/mock/procenv.py b/test/units/mock/procenv.py new file mode 100644 index 00000000000..ae0ea5abf5e --- /dev/null +++ b/test/units/mock/procenv.py @@ -0,0 +1,57 @@ +# (c) 2016, Matt Davis +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import sys + +from contextlib import contextmanager +from io import BytesIO, StringIO +from ansible.compat.six import PY3 +from ansible.utils.unicode import to_bytes + +@contextmanager +def swap_stdin_and_argv(stdin_data='', argv_data=tuple()): + """ + context manager that temporarily masks the test runner's values for stdin and argv + """ + real_stdin = sys.stdin + + if PY3: + sys.stdin = StringIO(stdin_data) + sys.stdin.buffer = BytesIO(to_bytes(stdin_data)) + else: + sys.stdin = BytesIO(to_bytes(stdin_data)) + + real_argv = sys.argv + sys.argv = argv_data + yield + sys.stdin = real_stdin + sys.argv = real_argv + +@contextmanager +def swap_stdout(): + """ + context manager that temporarily replaces stdout for tests that need to verify output + """ + old_stdout = sys.stdout + fake_stream = BytesIO() + sys.stdout = fake_stream + yield fake_stream + sys.stdout = old_stdout \ No newline at end of file diff --git a/test/units/module_utils/basic/test__log_invocation.py b/test/units/module_utils/basic/test__log_invocation.py index 34037f963c9..677eaa2c90c 100644 --- a/test/units/module_utils/basic/test__log_invocation.py +++ b/test/units/module_utils/basic/test__log_invocation.py @@ -22,75 +22,62 @@ __metaclass__ = type import sys import json -from io import BytesIO, StringIO -from ansible.compat.six import PY3 -from ansible.utils.unicode import to_bytes +from units.mock.procenv import swap_stdin_and_argv from ansible.compat.tests import unittest from ansible.compat.tests.mock import MagicMock + class TestModuleUtilsBasic(unittest.TestCase): - def setUp(self): - self.real_stdin = sys.stdin - args = json.dumps( + @unittest.skipIf(sys.version_info[0] >= 3, "Python 3 is not supported on targets (yet)") + def test_module_utils_basic__log_invocation(self): + with swap_stdin_and_argv(stdin_data=json.dumps( dict( ANSIBLE_MODULE_ARGS=dict( foo=False, bar=[1,2,3], bam="bam", baz=u'baz'), ANSIBLE_MODULE_CONSTANTS=dict() - ) - ) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - - def tearDown(self): - sys.stdin = self.real_stdin + ))): + from ansible.module_utils import basic - @unittest.skipIf(sys.version_info[0] >= 3, "Python 3 is not supported on targets (yet)") - def test_module_utils_basic__log_invocation(self): - from ansible.module_utils import basic - - # test basic log invocation - am = basic.AnsibleModule( - argument_spec=dict( - foo = dict(default=True, type='bool'), - bar = dict(default=[], type='list'), - bam = dict(default="bam"), - baz = dict(default=u"baz"), - password = dict(default=True), - no_log = dict(default="you shouldn't see me", no_log=True), - ), - ) + # test basic log invocation + am = basic.AnsibleModule( + argument_spec=dict( + foo = dict(default=True, type='bool'), + bar = dict(default=[], type='list'), + bam = dict(default="bam"), + baz = dict(default=u"baz"), + password = dict(default=True), + no_log = dict(default="you shouldn't see me", no_log=True), + ), + ) - am.log = MagicMock() - am._log_invocation() + am.log = MagicMock() + am._log_invocation() - # Message is generated from a dict so it will be in an unknown order. - # have to check this manually rather than with assert_called_with() - args = am.log.call_args[0] - self.assertEqual(len(args), 1) - message = args[0] + # Message is generated from a dict so it will be in an unknown order. + # have to check this manually rather than with assert_called_with() + args = am.log.call_args[0] + self.assertEqual(len(args), 1) + message = args[0] - self.assertEqual(len(message), len('Invoked with bam=bam bar=[1, 2, 3] foo=False baz=baz no_log=NOT_LOGGING_PARAMETER password=NOT_LOGGING_PASSWORD')) - self.assertTrue(message.startswith('Invoked with ')) - self.assertIn(' bam=bam', message) - self.assertIn(' bar=[1, 2, 3]', message) - self.assertIn(' foo=False', message) - self.assertIn(' baz=baz', message) - self.assertIn(' no_log=NOT_LOGGING_PARAMETER', message) - self.assertIn(' password=NOT_LOGGING_PASSWORD', message) + self.assertEqual(len(message), len('Invoked with bam=bam bar=[1, 2, 3] foo=False baz=baz no_log=NOT_LOGGING_PARAMETER password=NOT_LOGGING_PASSWORD')) + self.assertTrue(message.startswith('Invoked with ')) + self.assertIn(' bam=bam', message) + self.assertIn(' bar=[1, 2, 3]', message) + self.assertIn(' foo=False', message) + self.assertIn(' baz=baz', message) + self.assertIn(' no_log=NOT_LOGGING_PARAMETER', message) + self.assertIn(' password=NOT_LOGGING_PASSWORD', message) - kwargs = am.log.call_args[1] - self.assertEqual(kwargs, - dict(log_args={ - 'foo': 'False', - 'bar': '[1, 2, 3]', - 'bam': 'bam', - 'baz': 'baz', - 'password': 'NOT_LOGGING_PASSWORD', - 'no_log': 'NOT_LOGGING_PARAMETER', - }) - ) + kwargs = am.log.call_args[1] + self.assertEqual(kwargs, + dict(log_args={ + 'foo': 'False', + 'bar': '[1, 2, 3]', + 'bam': 'bam', + 'baz': 'baz', + 'password': 'NOT_LOGGING_PASSWORD', + 'no_log': 'NOT_LOGGING_PARAMETER', + }) + ) diff --git a/test/units/module_utils/basic/test_exit_json.py b/test/units/module_utils/basic/test_exit_json.py index 249dc380d93..1bd25002d4b 100644 --- a/test/units/module_utils/basic/test_exit_json.py +++ b/test/units/module_utils/basic/test_exit_json.py @@ -23,39 +23,34 @@ __metaclass__ = type import copy import json import sys -from io import BytesIO, StringIO -from ansible.compat.six import PY3 -from ansible.utils.unicode import to_bytes from ansible.compat.tests import unittest +from units.mock.procenv import swap_stdin_and_argv, swap_stdout from ansible.module_utils import basic from ansible.module_utils.basic import heuristic_log_sanitize from ansible.module_utils.basic import return_values, remove_values + empty_invocation = {u'module_args': {}} @unittest.skipIf(sys.version_info[0] >= 3, "Python 3 is not supported on targets (yet)") class TestAnsibleModuleExitJson(unittest.TestCase): - def setUp(self): - self.old_stdin = sys.stdin args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) + self.stdin_swap_ctx = swap_stdin_and_argv(stdin_data=args) + self.stdin_swap_ctx.__enter__() - self.old_stdout = sys.stdout - self.fake_stream = BytesIO() - sys.stdout = self.fake_stream + # since we can't use context managers and "with" without overriding run(), call them directly + self.stdout_swap_ctx = swap_stdout() + self.fake_stream = self.stdout_swap_ctx.__enter__() self.module = basic.AnsibleModule(argument_spec=dict()) def tearDown(self): - sys.stdout = self.old_stdout - sys.stdin = self.old_stdin + # since we can't use context managers and "with" without overriding run(), call them directly to clean up + self.stdin_swap_ctx.__exit__(None, None, None) + self.stdout_swap_ctx.__exit__(None, None, None) def test_exit_json_no_args_exits(self): with self.assertRaises(SystemExit) as ctx: @@ -123,42 +118,24 @@ class TestAnsibleModuleExitValuesRemoved(unittest.TestCase): ), ) - def setUp(self): - self.old_stdin = sys.stdin - args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - - self.old_stdout = sys.stdout - - def tearDown(self): - sys.stdin = self.old_stdin - sys.stdout = self.old_stdout - def test_exit_json_removes_values(self): self.maxDiff = None for args, return_val, expected in self.dataset: - sys.stdout = BytesIO() params = dict(ANSIBLE_MODULE_ARGS=args, ANSIBLE_MODULE_CONSTANTS={}) params = json.dumps(params) - if PY3: - sys.stdin = StringIO(params) - sys.stdin.buffer = BytesIO(to_bytes(params)) - else: - sys.stdin = BytesIO(to_bytes(params)) - module = basic.AnsibleModule( - argument_spec = dict( - username=dict(), - password=dict(no_log=True), - token=dict(no_log=True), - ), - ) - with self.assertRaises(SystemExit) as ctx: - self.assertEquals(module.exit_json(**return_val), expected) - self.assertEquals(json.loads(sys.stdout.getvalue()), expected) + + with swap_stdin_and_argv(stdin_data=params): + with swap_stdout(): + module = basic.AnsibleModule( + argument_spec = dict( + username=dict(), + password=dict(no_log=True), + token=dict(no_log=True), + ), + ) + with self.assertRaises(SystemExit) as ctx: + self.assertEquals(module.exit_json(**return_val), expected) + self.assertEquals(json.loads(sys.stdout.getvalue()), expected) def test_fail_json_removes_values(self): self.maxDiff = None @@ -166,21 +143,17 @@ class TestAnsibleModuleExitValuesRemoved(unittest.TestCase): expected = copy.deepcopy(expected) del expected['changed'] expected['failed'] = True - sys.stdout = BytesIO() params = dict(ANSIBLE_MODULE_ARGS=args, ANSIBLE_MODULE_CONSTANTS={}) params = json.dumps(params) - if PY3: - sys.stdin = StringIO(params) - sys.stdin.buffer = BytesIO(to_bytes(params)) - else: - sys.stdin = BytesIO(to_bytes(params)) - module = basic.AnsibleModule( - argument_spec = dict( - username=dict(), - password=dict(no_log=True), - token=dict(no_log=True), - ), - ) - with self.assertRaises(SystemExit) as ctx: - self.assertEquals(module.fail_json(**return_val), expected) - self.assertEquals(json.loads(sys.stdout.getvalue()), expected) + with swap_stdin_and_argv(stdin_data=params): + with swap_stdout(): + module = basic.AnsibleModule( + argument_spec = dict( + username=dict(), + password=dict(no_log=True), + token=dict(no_log=True), + ), + ) + with self.assertRaises(SystemExit) as ctx: + self.assertEquals(module.fail_json(**return_val), expected) + self.assertEquals(json.loads(sys.stdout.getvalue()), expected) diff --git a/test/units/module_utils/basic/test_log.py b/test/units/module_utils/basic/test_log.py index 0452ce7d903..c846d77096c 100644 --- a/test/units/module_utils/basic/test_log.py +++ b/test/units/module_utils/basic/test_log.py @@ -23,16 +23,14 @@ __metaclass__ = type import sys import json import syslog -from io import BytesIO, StringIO - -from ansible.compat.six import PY3 -from ansible.utils.unicode import to_bytes from ansible.compat.tests import unittest from ansible.compat.tests.mock import patch, MagicMock +from units.mock.procenv import swap_stdin_and_argv from ansible.module_utils import basic + try: # Python 3.4+ from importlib import reload @@ -44,15 +42,12 @@ except ImportError: class TestAnsibleModuleSysLogSmokeTest(unittest.TestCase): - def setUp(self): args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - self.real_stdin = sys.stdin - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) + + # unittest doesn't have a clean place to use a context manager, so we have to enter/exit manually + self.stdin_swap = swap_stdin_and_argv(stdin_data=args) + self.stdin_swap.__enter__() self.am = basic.AnsibleModule( argument_spec = dict(), @@ -64,7 +59,8 @@ class TestAnsibleModuleSysLogSmokeTest(unittest.TestCase): basic.has_journal = False def tearDown(self): - sys.stdin = self.real_stdin + # unittest doesn't have a clean place to use a context manager, so we have to enter/exit manually + self.stdin_swap.__exit__(None, None, None) basic.has_journal = self.has_journal def test_smoketest_syslog(self): @@ -84,20 +80,18 @@ class TestAnsibleModuleJournaldSmokeTest(unittest.TestCase): def setUp(self): args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - self.real_stdin = sys.stdin - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) + # unittest doesn't have a clean place to use a context manager, so we have to enter/exit manually + self.stdin_swap = swap_stdin_and_argv(stdin_data=args) + self.stdin_swap.__enter__() self.am = basic.AnsibleModule( argument_spec = dict(), ) def tearDown(self): - sys.stdin = self.real_stdin + # unittest doesn't have a clean place to use a context manager, so we have to enter/exit manually + self.stdin_swap.__exit__(None, None, None) @unittest.skipUnless(basic.has_journal, 'python systemd bindings not installed') def test_smoketest_journal(self): @@ -134,26 +128,26 @@ class TestAnsibleModuleLogSyslog(unittest.TestCase): def setUp(self): args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - self.real_stdin = sys.stdin - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - + # unittest doesn't have a clean place to use a context manager, so we have to enter/exit manually + self.stdin_swap = swap_stdin_and_argv(stdin_data=args) + self.stdin_swap.__enter__() self.am = basic.AnsibleModule( argument_spec = dict(), ) + self.has_journal = basic.has_journal if self.has_journal: # Systems with journal can still test syslog basic.has_journal = False def tearDown(self): - sys.stdin = self.real_stdin + # teardown/reset basic.has_journal = self.has_journal + # unittest doesn't have a clean place to use a context manager, so we have to enter/exit manually + self.stdin_swap.__exit__(None, None, None) + @patch('syslog.syslog', autospec=True) def test_no_log(self, mock_func): no_log = self.am.no_log @@ -191,21 +185,22 @@ class TestAnsibleModuleLogJournal(unittest.TestCase): b'non-utf8 :\xff: test': b'non-utf8 :\xff: test'.decode('utf-8', 'replace') } + # overriding run lets us use context managers for setup/teardown-esque behavior def setUp(self): args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - self.real_stdin = sys.stdin - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) + # unittest doesn't have a clean place to use a context manager, so we have to enter/exit manually + self.stdin_swap = swap_stdin_and_argv(stdin_data=args) + self.stdin_swap.__enter__() self.am = basic.AnsibleModule( argument_spec = dict(), ) self.has_journal = basic.has_journal - basic.has_journal = True + if self.has_journal: + # Systems with journal can still test syslog + basic.has_journal = False + self.module_patcher = None # In case systemd-python is not installed @@ -218,9 +213,12 @@ class TestAnsibleModuleLogJournal(unittest.TestCase): self._fake_out_reload(basic) def tearDown(self): - sys.stdin = self.real_stdin + # unittest doesn't have a clean place to use a context manager, so we have to enter/exit manually + self.stdin_swap.__exit__(None, None, None) + # teardown/reset basic.has_journal = self.has_journal + if self.module_patcher: self.module_patcher.stop() reload(basic) diff --git a/test/units/module_utils/basic/test_run_command.py b/test/units/module_utils/basic/test_run_command.py index 3c563658162..781dfa286eb 100644 --- a/test/units/module_utils/basic/test_run_command.py +++ b/test/units/module_utils/basic/test_run_command.py @@ -25,12 +25,11 @@ import sys import time from io import BytesIO, StringIO -from ansible.compat.six import PY3 -from ansible.utils.unicode import to_bytes - from ansible.compat.tests import unittest from ansible.compat.tests.mock import call, MagicMock, Mock, patch, sentinel +from units.mock.procenv import swap_stdin_and_argv + from ansible.module_utils import basic from ansible.module_utils.basic import AnsibleModule @@ -46,9 +45,7 @@ class OpenBytesIO(BytesIO): @unittest.skipIf(sys.version_info[0] >= 3, "Python 3 is not supported on targets (yet)") class TestAnsibleModuleRunCommand(unittest.TestCase): - def setUp(self): - self.cmd_out = { # os.read() is returning 'bytes', not strings sentinel.stdout: BytesIO(), @@ -66,11 +63,10 @@ class TestAnsibleModuleRunCommand(unittest.TestCase): raise OSError(errno.EPERM, "Permission denied: '/inaccessible'") args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) + # unittest doesn't have a clean place to use a context manager, so we have to enter/exit manually + self.stdin_swap = swap_stdin_and_argv(stdin_data=args) + self.stdin_swap.__enter__() + self.module = AnsibleModule(argument_spec=dict()) self.module.fail_json = MagicMock(side_effect=SystemExit) @@ -96,6 +92,11 @@ class TestAnsibleModuleRunCommand(unittest.TestCase): self.addCleanup(patch.stopall) + + def tearDown(self): + # unittest doesn't have a clean place to use a context manager, so we have to enter/exit manually + self.stdin_swap.__exit__(None, None, None) + def test_list_as_args(self): self.module.run_command(['/bin/ls', 'a', ' b', 'c ']) self.assertTrue(self.subprocess.Popen.called) diff --git a/test/units/module_utils/basic/test_safe_eval.py b/test/units/module_utils/basic/test_safe_eval.py index 36e9e1e399f..912dee17e3f 100644 --- a/test/units/module_utils/basic/test_safe_eval.py +++ b/test/units/module_utils/basic/test_safe_eval.py @@ -22,60 +22,48 @@ __metaclass__ = type import sys import json -from io import BytesIO, StringIO from ansible.compat.tests import unittest -from ansible.compat.six import PY3 -from ansible.utils.unicode import to_bytes +from units.mock.procenv import swap_stdin_and_argv class TestAnsibleModuleExitJson(unittest.TestCase): - def setUp(self): - self.real_stdin = sys.stdin - - def tearDown(self): - sys.stdin = self.real_stdin - def test_module_utils_basic_safe_eval(self): from ansible.module_utils import basic args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - am = basic.AnsibleModule( - argument_spec=dict(), - ) + with swap_stdin_and_argv(stdin_data=args): + am = basic.AnsibleModule( + argument_spec=dict(), + ) - # test some basic usage - # string (and with exceptions included), integer, bool - self.assertEqual(am.safe_eval("'a'"), 'a') - self.assertEqual(am.safe_eval("'a'", include_exceptions=True), ('a', None)) - self.assertEqual(am.safe_eval("1"), 1) - self.assertEqual(am.safe_eval("True"), True) - self.assertEqual(am.safe_eval("False"), False) - self.assertEqual(am.safe_eval("{}"), {}) - # not passing in a string to convert - self.assertEqual(am.safe_eval({'a':1}), {'a':1}) - self.assertEqual(am.safe_eval({'a':1}, include_exceptions=True), ({'a':1}, None)) - # invalid literal eval - self.assertEqual(am.safe_eval("a=1"), "a=1") - res = am.safe_eval("a=1", include_exceptions=True) - self.assertEqual(res[0], "a=1") - self.assertEqual(type(res[1]), SyntaxError) - self.assertEqual(am.safe_eval("a.foo()"), "a.foo()") - res = am.safe_eval("a.foo()", include_exceptions=True) - self.assertEqual(res[0], "a.foo()") - self.assertEqual(res[1], None) - self.assertEqual(am.safe_eval("import foo"), "import foo") - res = am.safe_eval("import foo", include_exceptions=True) - self.assertEqual(res[0], "import foo") - self.assertEqual(res[1], None) - self.assertEqual(am.safe_eval("__import__('foo')"), "__import__('foo')") - res = am.safe_eval("__import__('foo')", include_exceptions=True) - self.assertEqual(res[0], "__import__('foo')") - self.assertEqual(type(res[1]), ValueError) + # test some basic usage + # string (and with exceptions included), integer, bool + self.assertEqual(am.safe_eval("'a'"), 'a') + self.assertEqual(am.safe_eval("'a'", include_exceptions=True), ('a', None)) + self.assertEqual(am.safe_eval("1"), 1) + self.assertEqual(am.safe_eval("True"), True) + self.assertEqual(am.safe_eval("False"), False) + self.assertEqual(am.safe_eval("{}"), {}) + # not passing in a string to convert + self.assertEqual(am.safe_eval({'a':1}), {'a':1}) + self.assertEqual(am.safe_eval({'a':1}, include_exceptions=True), ({'a':1}, None)) + # invalid literal eval + self.assertEqual(am.safe_eval("a=1"), "a=1") + res = am.safe_eval("a=1", include_exceptions=True) + self.assertEqual(res[0], "a=1") + self.assertEqual(type(res[1]), SyntaxError) + self.assertEqual(am.safe_eval("a.foo()"), "a.foo()") + res = am.safe_eval("a.foo()", include_exceptions=True) + self.assertEqual(res[0], "a.foo()") + self.assertEqual(res[1], None) + self.assertEqual(am.safe_eval("import foo"), "import foo") + res = am.safe_eval("import foo", include_exceptions=True) + self.assertEqual(res[0], "import foo") + self.assertEqual(res[1], None) + self.assertEqual(am.safe_eval("__import__('foo')"), "__import__('foo')") + res = am.safe_eval("__import__('foo')", include_exceptions=True) + self.assertEqual(res[0], "__import__('foo')") + self.assertEqual(type(res[1]), ValueError) diff --git a/test/units/module_utils/test_basic.py b/test/units/module_utils/test_basic.py index f8c96c65368..8beefed6e88 100644 --- a/test/units/module_utils/test_basic.py +++ b/test/units/module_utils/test_basic.py @@ -31,8 +31,7 @@ try: except ImportError: import __builtin__ as builtins -from ansible.compat.six import PY3 -from ansible.utils.unicode import to_bytes +from units.mock.procenv import swap_stdin_and_argv from ansible.compat.tests import unittest from ansible.compat.tests.mock import patch, MagicMock, mock_open, Mock, call @@ -40,12 +39,16 @@ from ansible.compat.tests.mock import patch, MagicMock, mock_open, Mock, call realimport = builtins.__import__ class TestModuleUtilsBasic(unittest.TestCase): - + def setUp(self): - self.real_stdin = sys.stdin + args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) + # unittest doesn't have a clean place to use a context manager, so we have to enter/exit manually + self.stdin_swap = swap_stdin_and_argv(stdin_data=args) + self.stdin_swap.__enter__() def tearDown(self): - sys.stdin = self.real_stdin + # unittest doesn't have a clean place to use a context manager, so we have to enter/exit manually + self.stdin_swap.__exit__(None, None, None) def clear_modules(self, mods): for mod in mods: @@ -271,13 +274,6 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_creation(self): from ansible.module_utils import basic - args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - am = basic.AnsibleModule( argument_spec=dict(), ) @@ -293,94 +289,71 @@ class TestModuleUtilsBasic(unittest.TestCase): # should test ok args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"foo": "hello"}, ANSIBLE_MODULE_CONSTANTS={})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - am = basic.AnsibleModule( - argument_spec = arg_spec, - mutually_exclusive = mut_ex, - required_together = req_to, - no_log=True, - check_invalid_arguments=False, - add_file_common_args=True, - supports_check_mode=True, - ) + with swap_stdin_and_argv(stdin_data=args): + am = basic.AnsibleModule( + argument_spec = arg_spec, + mutually_exclusive = mut_ex, + required_together = req_to, + no_log=True, + check_invalid_arguments=False, + add_file_common_args=True, + supports_check_mode=True, + ) # FIXME: add asserts here to verify the basic config # fail, because a required param was not specified args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - - self.assertRaises( - SystemExit, - basic.AnsibleModule, - argument_spec = arg_spec, - mutually_exclusive = mut_ex, - required_together = req_to, - no_log=True, - check_invalid_arguments=False, - add_file_common_args=True, - supports_check_mode=True, - ) + + with swap_stdin_and_argv(stdin_data=args): + self.assertRaises( + SystemExit, + basic.AnsibleModule, + argument_spec = arg_spec, + mutually_exclusive = mut_ex, + required_together = req_to, + no_log=True, + check_invalid_arguments=False, + add_file_common_args=True, + supports_check_mode=True, + ) # fail because of mutually exclusive parameters args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"foo":"hello", "bar": "bad", "bam": "bad"}, ANSIBLE_MODULE_CONSTANTS={})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - - self.assertRaises( - SystemExit, - basic.AnsibleModule, - argument_spec = arg_spec, - mutually_exclusive = mut_ex, - required_together = req_to, - no_log=True, - check_invalid_arguments=False, - add_file_common_args=True, - supports_check_mode=True, - ) + + with swap_stdin_and_argv(stdin_data=args): + self.assertRaises( + SystemExit, + basic.AnsibleModule, + argument_spec = arg_spec, + mutually_exclusive = mut_ex, + required_together = req_to, + no_log=True, + check_invalid_arguments=False, + add_file_common_args=True, + supports_check_mode=True, + ) # fail because a param required due to another param was not specified args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"bam": "bad"}, ANSIBLE_MODULE_CONSTANTS={})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - - self.assertRaises( - SystemExit, - basic.AnsibleModule, - argument_spec = arg_spec, - mutually_exclusive = mut_ex, - required_together = req_to, - no_log=True, - check_invalid_arguments=False, - add_file_common_args=True, - supports_check_mode=True, - ) + + with swap_stdin_and_argv(stdin_data=args): + self.assertRaises( + SystemExit, + basic.AnsibleModule, + argument_spec = arg_spec, + mutually_exclusive = mut_ex, + required_together = req_to, + no_log=True, + check_invalid_arguments=False, + add_file_common_args=True, + supports_check_mode=True, + ) def test_module_utils_basic_ansible_module_load_file_common_arguments(self): from ansible.module_utils import basic - args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - am = basic.AnsibleModule( argument_spec = dict(), ) @@ -429,13 +402,6 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_selinux_mls_enabled(self): from ansible.module_utils import basic - args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - am = basic.AnsibleModule( argument_spec = dict(), ) @@ -455,13 +421,6 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_selinux_initial_context(self): from ansible.module_utils import basic - args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - am = basic.AnsibleModule( argument_spec = dict(), ) @@ -475,13 +434,6 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_selinux_enabled(self): from ansible.module_utils import basic - args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - am = basic.AnsibleModule( argument_spec = dict(), ) @@ -513,13 +465,6 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_selinux_default_context(self): from ansible.module_utils import basic - args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - am = basic.AnsibleModule( argument_spec = dict(), ) @@ -555,13 +500,6 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_selinux_context(self): from ansible.module_utils import basic - args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - am = basic.AnsibleModule( argument_spec = dict(), ) @@ -604,59 +542,49 @@ class TestModuleUtilsBasic(unittest.TestCase): from ansible.module_utils import basic args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={"SELINUX_SPECIAL_FS": "nfs,nfsd,foos"})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - am = basic.AnsibleModule( - argument_spec = dict(), - ) - print(am.constants) + with swap_stdin_and_argv(stdin_data=args): + + am = basic.AnsibleModule( + argument_spec = dict(), + ) + print(am.constants) - def _mock_find_mount_point(path): - if path.startswith('/some/path'): - return '/some/path' - elif path.startswith('/weird/random/fstype'): - return '/weird/random/fstype' - return '/' + def _mock_find_mount_point(path): + if path.startswith('/some/path'): + return '/some/path' + elif path.startswith('/weird/random/fstype'): + return '/weird/random/fstype' + return '/' - am.find_mount_point = MagicMock(side_effect=_mock_find_mount_point) - am.selinux_context = MagicMock(return_value=['foo_u', 'foo_r', 'foo_t', 's0']) + am.find_mount_point = MagicMock(side_effect=_mock_find_mount_point) + am.selinux_context = MagicMock(return_value=['foo_u', 'foo_r', 'foo_t', 's0']) - m = mock_open() - m.side_effect = OSError + m = mock_open() + m.side_effect = OSError - with patch.object(builtins, 'open', m, create=True): - self.assertEqual(am.is_special_selinux_path('/some/path/that/should/be/nfs'), (False, None)) + with patch.object(builtins, 'open', m, create=True): + self.assertEqual(am.is_special_selinux_path('/some/path/that/should/be/nfs'), (False, None)) - mount_data = [ - '/dev/disk1 / ext4 rw,seclabel,relatime,data=ordered 0 0\n', - '1.1.1.1:/path/to/nfs /some/path nfs ro 0 0\n', - 'whatever /weird/random/fstype foos rw 0 0\n', - ] + mount_data = [ + '/dev/disk1 / ext4 rw,seclabel,relatime,data=ordered 0 0\n', + '1.1.1.1:/path/to/nfs /some/path nfs ro 0 0\n', + 'whatever /weird/random/fstype foos rw 0 0\n', + ] - # mock_open has a broken readlines() implementation apparently... - # this should work by default but doesn't, so we fix it - m = mock_open(read_data=''.join(mount_data)) - m.return_value.readlines.return_value = mount_data + # mock_open has a broken readlines() implementation apparently... + # this should work by default but doesn't, so we fix it + m = mock_open(read_data=''.join(mount_data)) + m.return_value.readlines.return_value = mount_data - with patch.object(builtins, 'open', m, create=True): - self.assertEqual(am.is_special_selinux_path('/some/random/path'), (False, None)) - self.assertEqual(am.is_special_selinux_path('/some/path/that/should/be/nfs'), (True, ['foo_u', 'foo_r', 'foo_t', 's0'])) - self.assertEqual(am.is_special_selinux_path('/weird/random/fstype/path'), (True, ['foo_u', 'foo_r', 'foo_t', 's0'])) + with patch.object(builtins, 'open', m, create=True): + self.assertEqual(am.is_special_selinux_path('/some/random/path'), (False, None)) + self.assertEqual(am.is_special_selinux_path('/some/path/that/should/be/nfs'), (True, ['foo_u', 'foo_r', 'foo_t', 's0'])) + self.assertEqual(am.is_special_selinux_path('/weird/random/fstype/path'), (True, ['foo_u', 'foo_r', 'foo_t', 's0'])) def test_module_utils_basic_ansible_module_to_filesystem_str(self): from ansible.module_utils import basic - args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - am = basic.AnsibleModule( argument_spec = dict(), ) @@ -667,13 +595,6 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_user_and_group(self): from ansible.module_utils import basic - args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - am = basic.AnsibleModule( argument_spec = dict(), ) @@ -688,13 +609,6 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_find_mount_point(self): from ansible.module_utils import basic - args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - am = basic.AnsibleModule( argument_spec = dict(), ) @@ -718,13 +632,6 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_set_context_if_different(self): from ansible.module_utils import basic - args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - am = basic.AnsibleModule( argument_spec = dict(), ) @@ -769,13 +676,6 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_set_owner_if_different(self): from ansible.module_utils import basic - args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - am = basic.AnsibleModule( argument_spec = dict(), ) @@ -814,13 +714,6 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_set_group_if_different(self): from ansible.module_utils import basic - args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - am = basic.AnsibleModule( argument_spec = dict(), ) @@ -859,13 +752,6 @@ class TestModuleUtilsBasic(unittest.TestCase): def test_module_utils_basic_ansible_module_set_mode_if_different(self): from ansible.module_utils import basic - args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - am = basic.AnsibleModule( argument_spec = dict(), ) @@ -953,13 +839,6 @@ class TestModuleUtilsBasic(unittest.TestCase): from ansible.module_utils import basic - args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - am = basic.AnsibleModule( argument_spec = dict(), ) @@ -1137,13 +1016,6 @@ class TestModuleUtilsBasic(unittest.TestCase): from ansible.module_utils import basic - args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - am = basic.AnsibleModule( argument_spec = dict(), ) diff --git a/test/units/module_utils/test_distribution_version.py b/test/units/module_utils/test_distribution_version.py index e3efae5da81..e3e5f3d3916 100644 --- a/test/units/module_utils/test_distribution_version.py +++ b/test/units/module_utils/test_distribution_version.py @@ -26,6 +26,8 @@ from io import BytesIO, StringIO from ansible.compat.six import PY3 from ansible.utils.unicode import to_bytes +from units.mock.procenv import swap_stdin_and_argv + # for testing from ansible.compat.tests import unittest from ansible.compat.tests.mock import patch @@ -323,26 +325,17 @@ def test_distribution_version(): # needs to be in here, because the import fails with python3 still import ansible.module_utils.facts as facts - real_stdin = sys.stdin from ansible.module_utils import basic args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={})) - if PY3: - sys.stdin = StringIO(args) - sys.stdin.buffer = BytesIO(to_bytes(args)) - else: - sys.stdin = BytesIO(to_bytes(args)) - module = basic.AnsibleModule(argument_spec=dict()) - - for t in TESTSETS: - # run individual tests via generator - # set nicer stdout output for nosetest - _test_one_distribution.description = "check distribution_version for %s" % t['name'] - yield _test_one_distribution, facts, module, t - - - sys.stdin = real_stdin + with swap_stdin_and_argv(stdin_data=args): + module = basic.AnsibleModule(argument_spec=dict()) + for t in TESTSETS: + # run individual tests via generator + # set nicer stdout output for nosetest + _test_one_distribution.description = "check distribution_version for %s" % t['name'] + yield _test_one_distribution, facts, module, t def _test_one_distribution(facts, module, testcase): """run the test on one distribution testcase