diff --git a/CHANGELOG.md b/CHANGELOG.md index 7dd755e6b31..126175fb2d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ Ansible Changes By Release - Also added an ansible-config CLI to allow for listing config options and dumping current config (including origin) - TODO: build upon this to add many features detailed in ansible-config proposal https://github.com/ansible/proposals/issues/35 * Windows modules now support the use of multiple shared module_utils files in the form of Powershell modules (.psm1), via `#Requires -Module Ansible.ModuleUtils.Whatever.psm1` +* Python module argument_spec now supports custom validation logic by accepting a callable as the `type` argument. ### Deprecations * The behaviour when specifying `--tags` (or `--skip-tags`) multiple times on the command line diff --git a/lib/ansible/module_utils/basic.py b/lib/ansible/module_utils/basic.py index 0f25c6c7783..456cb4589eb 100644 --- a/lib/ansible/module_utils/basic.py +++ b/lib/ansible/module_utils/basic.py @@ -1874,22 +1874,28 @@ class AnsibleModule(object): wanted = v.get('type', None) if k not in param: continue - if wanted is None: - # Mostly we want to default to str. - # For values set to None explicitly, return None instead as - # that allows a user to unset a parameter - if param[k] is None: - continue - wanted = 'str' value = param[k] if value is None: continue - try: - type_checker = self._CHECK_ARGUMENT_TYPES_DISPATCHER[wanted] - except KeyError: - self.fail_json(msg="implementation error: unknown type %s requested for %s" % (wanted, k)) + if not callable(wanted): + if wanted is None: + # Mostly we want to default to str. + # For values set to None explicitly, return None instead as + # that allows a user to unset a parameter + if param[k] is None: + continue + wanted = 'str' + try: + type_checker = self._CHECK_ARGUMENT_TYPES_DISPATCHER[wanted] + except KeyError: + self.fail_json(msg="implementation error: unknown type %s requested for %s" % (wanted, k)) + else: + # set the type_checker to the callable, and reset wanted to the callable's name (or type if it doesn't have one, ala MagicMock) + type_checker = wanted + wanted = getattr(wanted, '__name__', to_native(type(wanted))) + try: param[k] = type_checker(value) except (TypeError, ValueError): diff --git a/test/units/module_utils/basic/test_argument_spec.py b/test/units/module_utils/basic/test_argument_spec.py new file mode 100644 index 00000000000..f6f9689f726 --- /dev/null +++ b/test/units/module_utils/basic/test_argument_spec.py @@ -0,0 +1,49 @@ +# Copyright (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division) +__metaclass__ = type + +import json + +from ansible.compat.tests import unittest +from ansible.compat.tests.mock import MagicMock +from units.mock.procenv import swap_stdin_and_argv, swap_stdout +from ansible.module_utils import basic + + +class TestCallableTypeValidation(unittest.TestCase): + def setUp(self): + args = json.dumps(dict(ANSIBLE_MODULE_ARGS=dict(arg="42"))) + self.stdin_swap_ctx = swap_stdin_and_argv(stdin_data=args) + self.stdin_swap_ctx.__enter__() + + # since we can't use context managers and "with" without overriding run(), call them directly + self.stdout_swap_ctx = swap_stdout() + self.fake_stream = self.stdout_swap_ctx.__enter__() + + basic._ANSIBLE_ARGS = None + + def tearDown(self): + # since we can't use context managers and "with" without overriding run(), call them directly to clean up + self.stdin_swap_ctx.__exit__(None, None, None) + self.stdout_swap_ctx.__exit__(None, None, None) + + def test_validate_success(self): + mock_validator = MagicMock(return_value=42) + m = basic.AnsibleModule(argument_spec=dict( + arg=dict(type=mock_validator) + )) + + self.assertTrue(mock_validator.called) + self.assertEqual(m.params['arg'], 42) + self.assertEqual(type(m.params['arg']), int) + + def test_validate_fail(self): + mock_validator = MagicMock(side_effect=TypeError("bad conversion")) + with self.assertRaises(SystemExit) as ecm: + m = basic.AnsibleModule(argument_spec=dict( + arg=dict(type=mock_validator) + )) + + self.assertIn("bad conversion", json.loads(self.fake_stream.getvalue())['msg'])