From 9076f8eb3187830744233c96a03756d1028bd17f Mon Sep 17 00:00:00 2001 From: Daniel Hokka Zakrisson Date: Sat, 23 Feb 2013 19:43:50 +0100 Subject: [PATCH] Add type checking for module arguments, converting as much as possible Converts to list from comma-separated strings, and to dicts from comma-separated, key=value strings. Fixes #2126. --- lib/ansible/module_common.py | 42 +++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/lib/ansible/module_common.py b/lib/ansible/module_common.py index 7b3b37af846..06527c7680d 100644 --- a/lib/ansible/module_common.py +++ b/lib/ansible/module_common.py @@ -179,6 +179,7 @@ class AnsibleModule(object): if not bypass_checks: self._check_required_arguments() + self._check_argument_values() self._check_argument_types() self._check_mutually_exclusive(mutually_exclusive) self._check_required_together(required_together) @@ -535,7 +536,7 @@ class AnsibleModule(object): if len(missing) > 0: self.fail_json(msg="missing required arguments: %s" % ",".join(missing)) - def _check_argument_types(self): + def _check_argument_values(self): ''' ensure all arguments have the requested values, and there are no stray arguments ''' for (k,v) in self.argument_spec.iteritems(): choices = v.get('choices',None) @@ -550,6 +551,45 @@ class AnsibleModule(object): else: self.fail_json(msg="internal error: do not know how to interpret argument_spec") + def _check_argument_types(self): + ''' ensure all arguments have the requested type ''' + for (k, v) in self.argument_spec.iteritems(): + wanted = v.get('type', None) + if wanted is None: + continue + if k not in self.params: + continue + + value = self.params[k] + is_invalid = False + + if wanted == 'str': + if not isinstance(value, basestring): + self.params[k] = str(value) + elif wanted == 'list': + if not isinstance(value, list): + if isinstance(value, basestring): + self.params[k] = value.split(",") + else: + is_invalid = True + elif wanted == 'dict': + if not isinstance(value, dict): + if isinstance(value, basestring): + self.params[k] = dict([x.split("=", 1) for x in value.split(",")]) + else: + is_invalid = True + elif wanted == 'bool': + if not isinstance(value, bool): + if isinstance(value, basestring): + self.params[k] = self.boolean(value) + else: + is_invalid = True + else: + self.fail_json(msg="implementation error: unknown type %s requested for %s" % (wanted, k)) + + if is_invalid: + self.fail_json(msg="argument %s is of invalid type: %s, required: %s" % (k, type(value), wanted)) + def _set_defaults(self, pre=True): for (k,v) in self.argument_spec.iteritems(): default = v.get('default', None)