Allow AnsibleModules to be instantiated more than once in a module

Fix SELINUX monkeypatch in test_basic
pull/15562/head
Toshio Kuratomi 9 years ago
parent 0f373c1767
commit 44e21f7062

@ -223,6 +223,12 @@ from ansible import __version__
# Backwards compat. New code should just import and use __version__
ANSIBLE_VERSION = __version__
# Internal global holding passed in params and constants. This is consulted
# in case multiple AnsibleModules are created. Otherwise each AnsibleModule
# would attempt to read from stdin. Other code should not use this directly
# as it is an internal implementation detail
_ANSIBLE_ARGS = None
FILE_COMMON_ARGUMENTS=dict(
src = dict(),
mode = dict(type='raw'),
@ -1457,6 +1463,10 @@ class AnsibleModule(object):
''' read the input and set the params attribute. Sets the constants as well.'''
# debug overrides to read args from file or cmdline
global _ANSIBLE_ARGS
if _ANSIBLE_ARGS is not None:
buffer = _ANSIBLE_ARGS
else:
# Avoid tracebacks when locale is non-utf8
# We control the args and we pass them as utf8
if len(sys.argv) > 1:
@ -1474,6 +1484,7 @@ class AnsibleModule(object):
buffer = sys.stdin.read()
else:
buffer = sys.stdin.buffer.read()
_ANSIBLE_ARGS = buffer
try:
params = json.loads(buffer.decode('utf-8'))

@ -45,6 +45,7 @@ class TestAnsibleModuleExitJson(unittest.TestCase):
self.stdout_swap_ctx = swap_stdout()
self.fake_stream = self.stdout_swap_ctx.__enter__()
reload(basic)
self.module = basic.AnsibleModule(argument_spec=dict())
def tearDown(self):
@ -125,6 +126,7 @@ class TestAnsibleModuleExitValuesRemoved(unittest.TestCase):
params = json.dumps(params)
with swap_stdin_and_argv(stdin_data=params):
reload(basic)
with swap_stdout():
module = basic.AnsibleModule(
argument_spec = dict(
@ -146,6 +148,7 @@ class TestAnsibleModuleExitValuesRemoved(unittest.TestCase):
params = dict(ANSIBLE_MODULE_ARGS=args, ANSIBLE_MODULE_CONSTANTS={})
params = json.dumps(params)
with swap_stdin_and_argv(stdin_data=params):
reload(basic)
with swap_stdout():
module = basic.AnsibleModule(
argument_spec = dict(

@ -49,6 +49,7 @@ class TestAnsibleModuleSysLogSmokeTest(unittest.TestCase):
self.stdin_swap = swap_stdin_and_argv(stdin_data=args)
self.stdin_swap.__enter__()
reload(basic)
self.am = basic.AnsibleModule(
argument_spec = dict(),
)
@ -85,6 +86,7 @@ class TestAnsibleModuleJournaldSmokeTest(unittest.TestCase):
self.stdin_swap = swap_stdin_and_argv(stdin_data=args)
self.stdin_swap.__enter__()
reload(basic)
self.am = basic.AnsibleModule(
argument_spec = dict(),
)
@ -132,6 +134,7 @@ class TestAnsibleModuleLogSyslog(unittest.TestCase):
self.stdin_swap = swap_stdin_and_argv(stdin_data=args)
self.stdin_swap.__enter__()
reload(basic)
self.am = basic.AnsibleModule(
argument_spec = dict(),
)
@ -192,6 +195,7 @@ class TestAnsibleModuleLogJournal(unittest.TestCase):
self.stdin_swap = swap_stdin_and_argv(stdin_data=args)
self.stdin_swap.__enter__()
reload(basic)
self.am = basic.AnsibleModule(
argument_spec = dict(),
)

@ -67,6 +67,7 @@ class TestAnsibleModuleRunCommand(unittest.TestCase):
self.stdin_swap = swap_stdin_and_argv(stdin_data=args)
self.stdin_swap.__enter__()
reload(basic)
self.module = AnsibleModule(argument_spec=dict())
self.module.fail_json = MagicMock(side_effect=SystemExit)

@ -26,6 +26,12 @@ import json
from ansible.compat.tests import unittest
from units.mock.procenv import swap_stdin_and_argv
try:
from importlib import reload
except:
# Py2 has reload as a builtin
pass
class TestAnsibleModuleExitJson(unittest.TestCase):
def test_module_utils_basic_safe_eval(self):
@ -34,6 +40,7 @@ class TestAnsibleModuleExitJson(unittest.TestCase):
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
with swap_stdin_and_argv(stdin_data=args):
reload(basic)
am = basic.AnsibleModule(
argument_spec=dict(),
)

@ -31,6 +31,12 @@ try:
except ImportError:
import __builtin__ as builtins
try:
from importlib import reload
except:
# Py2 has reload as a builtin
pass
from units.mock.procenv import swap_stdin_and_argv
from ansible.compat.tests import unittest
@ -291,6 +297,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"foo": "hello"}, ANSIBLE_MODULE_CONSTANTS={}))
with swap_stdin_and_argv(stdin_data=args):
reload(basic)
am = basic.AnsibleModule(
argument_spec = arg_spec,
mutually_exclusive = mut_ex,
@ -307,6 +314,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={}))
with swap_stdin_and_argv(stdin_data=args):
reload(basic)
self.assertRaises(
SystemExit,
basic.AnsibleModule,
@ -353,6 +361,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_load_file_common_arguments(self):
from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule(
argument_spec = dict(),
@ -401,6 +410,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_selinux_mls_enabled(self):
from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule(
argument_spec = dict(),
@ -420,6 +430,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_selinux_initial_context(self):
from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule(
argument_spec = dict(),
@ -433,6 +444,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_selinux_enabled(self):
from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule(
argument_spec = dict(),
@ -464,6 +476,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_selinux_default_context(self):
from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule(
argument_spec = dict(),
@ -499,6 +512,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_selinux_context(self):
from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule(
argument_spec = dict(),
@ -540,6 +554,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_is_special_selinux_path(self):
from ansible.module_utils import basic
reload(basic)
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}, ANSIBLE_MODULE_CONSTANTS={"SELINUX_SPECIAL_FS": "nfs,nfsd,foos"}))
@ -584,6 +599,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_to_filesystem_str(self):
from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule(
argument_spec = dict(),
@ -608,6 +624,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_find_mount_point(self):
from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule(
argument_spec = dict(),
@ -631,18 +648,19 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_set_context_if_different(self):
from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule(
argument_spec = dict(),
)
basic.HAS_SELINUX = False
basic.HAVE_SELINUX = False
am.selinux_enabled = MagicMock(return_value=False)
self.assertEqual(am.set_context_if_different('/path/to/file', ['foo_u', 'foo_r', 'foo_t', 's0'], True), True)
self.assertEqual(am.set_context_if_different('/path/to/file', ['foo_u', 'foo_r', 'foo_t', 's0'], False), False)
basic.HAS_SELINUX = True
basic.HAVE_SELINUX = True
am.selinux_enabled = MagicMock(return_value=True)
am.selinux_context = MagicMock(return_value=['bar_u', 'bar_r', None, None])
@ -675,6 +693,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_set_owner_if_different(self):
from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule(
argument_spec = dict(),
@ -713,6 +732,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_set_group_if_different(self):
from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule(
argument_spec = dict(),
@ -751,6 +771,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module_set_mode_if_different(self):
from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule(
argument_spec = dict(),
@ -838,6 +859,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
):
from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule(
argument_spec = dict(),
@ -1015,6 +1037,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
def test_module_utils_basic_ansible_module__symbolic_mode_to_octal(self):
from ansible.module_utils import basic
reload(basic)
am = basic.AnsibleModule(
argument_spec = dict(),

Loading…
Cancel
Save