From be6db1a730270a8e89636da9630dcac8e3e093fc Mon Sep 17 00:00:00 2001 From: Toshio Kuratomi Date: Mon, 29 Jun 2015 08:05:58 -0700 Subject: [PATCH] Refactor the argspec type checking and add path as a type --- lib/ansible/module_utils/basic.py | 146 ++++++++++++++++++------------ 1 file changed, 90 insertions(+), 56 deletions(-) diff --git a/lib/ansible/module_utils/basic.py b/lib/ansible/module_utils/basic.py index ffd159601d6..e89809ff12e 100644 --- a/lib/ansible/module_utils/basic.py +++ b/lib/ansible/module_utils/basic.py @@ -351,9 +351,9 @@ class AnsibleModule(object): self.check_mode = False self.no_log = no_log self.cleanup_files = [] - + self.aliases = {} - + if add_file_common_args: for k, v in FILE_COMMON_ARGUMENTS.iteritems(): if k not in self.argument_spec: @@ -366,7 +366,7 @@ class AnsibleModule(object): self.params = self._load_params() self._legal_inputs = ['_ansible_check_mode', '_ansible_no_log'] - + self.aliases = self._handle_aliases() if check_invalid_arguments: @@ -380,6 +380,16 @@ class AnsibleModule(object): self._set_defaults(pre=True) + + self._CHECK_ARGUMENT_TYPES_DISPATCHER = { + 'str': self._check_type_str, + 'list': self._check_type_list, + 'dict': self._check_type_dict, + 'bool': self._check_type_bool, + 'int': self._check_type_int, + 'float': self._check_type_float, + 'path': self._check_type_path, + } if not bypass_checks: self._check_required_arguments() self._check_argument_values() @@ -1021,6 +1031,76 @@ class AnsibleModule(object): return (str, e) return str + def _check_type_str(self, value): + if isinstance(value, basestring): + return value + # Note: This could throw a unicode error if value's __str__() method + # returns non-ascii. Have to port utils.to_bytes() if that happens + return str(value) + + def _check_type_list(self, value): + if isinstance(value, list): + return value + + if isinstance(value, basestring): + return value.split(",") + elif isinstance(value, int) or isinstance(value, float): + return [ str(value) ] + + raise TypeError('%s cannot be converted to a list' % type(value)) + + def _check_type_dict(self, value): + if isinstance(value, dict): + return value + + if isinstance(value, basestring): + if value.startswith("{"): + try: + return json.loads(value) + except: + (result, exc) = self.safe_eval(value, dict(), include_exceptions=True) + if exc is not None: + raise TypeError('unable to evaluate string as dictionary') + return result + elif '=' in value: + return dict([x.strip().split("=", 1) for x in value.split(",")]) + else: + raise TypeError("dictionary requested, could not parse JSON or key=value") + + raise TypeError('%s cannot be converted to a dict' % type(value)) + + def _check_type_bool(self, value): + if isinstance(value, bool): + return value + + if isinstance(value, basestring): + return self.boolean(value) + + raise TypeError('%s cannot be converted to a bool' % type(value)) + + def _check_type_int(self, value): + if isinstance(value, int): + return value + + if isinstance(value, basestring): + return int(value) + + raise TypeError('%s cannot be converted to an int' % type(value)) + + def _check_type_float(self, value): + if isinstance(value, float): + return value + + if isinstance(value, basestring): + return float(value) + + raise TypeError('%s cannot be converted to a float' % type(value)) + + def _check_type_path(self, value): + value = self._check_type_str(value) + return os.path.expanduser(os.path.expandvars(value)) + + def _check_argument_types(self): ''' ensure all arguments have the requested type ''' for (k, v) in self.argument_spec.iteritems(): @@ -1034,59 +1114,13 @@ class AnsibleModule(object): is_invalid = False try: - if wanted == 'str': - if not isinstance(value, basestring): - self.params[k] = str(value) - elif wanted == 'list': - if not isinstance(value, list): - if isinstance(value, basestring): - self.params[k] = value.split(",") - elif isinstance(value, int) or isinstance(value, float): - self.params[k] = [ str(value) ] - else: - is_invalid = True - elif wanted == 'dict': - if not isinstance(value, dict): - if isinstance(value, basestring): - if value.startswith("{"): - try: - self.params[k] = json.loads(value) - except: - (result, exc) = self.safe_eval(value, dict(), include_exceptions=True) - if exc is not None: - self.fail_json(msg="unable to evaluate dictionary for %s" % k) - self.params[k] = result - elif '=' in value: - self.params[k] = dict([x.strip().split("=", 1) for x in value.split(",")]) - else: - self.fail_json(msg="dictionary requested, could not parse JSON or key=value") - else: - is_invalid = True - elif wanted == 'bool': - if not isinstance(value, bool): - if isinstance(value, basestring): - self.params[k] = self.boolean(value) - else: - is_invalid = True - elif wanted == 'int': - if not isinstance(value, int): - if isinstance(value, basestring): - self.params[k] = int(value) - else: - is_invalid = True - elif wanted == 'float': - if not isinstance(value, float): - if isinstance(value, basestring): - self.params[k] = float(value) - else: - is_invalid = True - else: - self.fail_json(msg="implementation error: unknown type %s requested for %s" % (wanted, k)) - - if is_invalid: - self.fail_json(msg="argument %s is of invalid type: %s, required: %s" % (k, type(value), wanted)) - except ValueError: - self.fail_json(msg="value of argument %s is not of type %s and we were unable to automatically convert" % (k, wanted)) + type_checker = self._CHECK_ARGUMENT_TYPES_DISPATCHER[wanted] + except KeyError: + self.fail_json(msg="implementation error: unknown type %s requested for %s" % (wanted, k)) + try: + self.params[k] = type_checker(value) + except (TypeError, ValueError): + self.fail_json(msg="argument %s is of type %s and we were unable to convert to %s" % (k, type(value), wanted)) def _set_defaults(self, pre=True): for (k,v) in self.argument_spec.iteritems():