Use ArgumentSpecValidator in AnsibleModule (#73703)

* Begin using ArgumentSpecValidator in AnsibleModule

* Add check parameters to ArgumentSpecValidator

Add additional parameters for specifying required and mutually exclusive parameters.
Add code to the .validate() method that runs these additional checks.
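
A rough sketch of how the expanded constructor is meant to be used (the argument_spec and
terms here are made up for illustration; calling .validate() is shown in a later example below):

    from ansible.module_utils.common.arg_spec import ArgumentSpecValidator

    argument_spec = {
        'state': {'type': 'str', 'choices': ['absent', 'present'], 'default': 'present'},
        'path': {'type': 'path'},
        'content': {'type': 'str'},
    }

    validator = ArgumentSpecValidator(
        argument_spec,
        mutually_exclusive=[['path', 'content']],
        required_if=[['state', 'present', ['path', 'content'], True]],
    )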

* Make errors related to unsupported parameters match existing behavior

Update the punctuation in the message slightly to make it more readable.
Add a property to ArgumentSpecValidator to hold valid parameter names.

* Set default values after performing checks

* Fix sanity test failure

* Use correct parameters when checking sub options

* Use a dict when iterating over check functions

Referencing by key names makes things a bit more readable IMO.
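
The shape that ended up in the diff is a table of dicts keyed by 'func', 'attr', and 'err'.
A self-contained sketch of the pattern (the _run_additional_checks() wrapper is illustrative;
the real code runs this loop inline in ArgumentSpecValidator.validate()):

    from ansible.module_utils.common.text.converters import to_native
    from ansible.module_utils.common.validation import (
        check_required_by,
        check_required_if,
        check_required_one_of,
        check_required_together,
    )
    from ansible.module_utils.errors import (
        RequiredByError,
        RequiredIfError,
        RequiredOneOfError,
        RequiredTogetherError,
    )

    _ADDITIONAL_CHECKS = (
        {'func': check_required_together, 'attr': 'required_together', 'err': RequiredTogetherError},
        {'func': check_required_one_of, 'attr': 'required_one_of', 'err': RequiredOneOfError},
        {'func': check_required_if, 'attr': 'required_if', 'err': RequiredIfError},
        {'func': check_required_by, 'attr': 'required_by', 'err': RequiredByError},
    )

    def _run_additional_checks(validator, parameters, errors):
        # 'attr' doubles as the name of the private attribute on the validator
        # (e.g. validator._required_if) and 'err' is the error class appended on failure.
        for check in _ADDITIONAL_CHECKS:
            try:
                check['func'](getattr(validator, '_{attr}'.format(attr=check['attr'])), parameters)
            except TypeError as te:
                errors.append(check['err'](to_native(te)))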

* Fix bug in comparison for sub options evaluation

* Add options_context to check functions

This allows the parent parameter to be added to the error message if a validation
error occurs in a sub option.

* Fix bug in apply_defaults behavior of sub spec validation

* Accept options_context in get_unsupported_parameters()

If options_context is supplied, a tuple of the parent key names of the unsupported parameter will be
created. This allows the full "path" to the unsupported parameter to be reported.
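
A small self-contained sketch of how those tuples get turned into a readable "path" in the
error message (the helper name is hypothetical; the same flattening happens inline in
ArgumentSpecValidator.validate() in the diff below):

    def _flatten_unsupported(unsupported_parameters):
        # Entries are either a top-level key or a tuple of parent keys ending with the
        # unsupported key, e.g. ('options', 'nested', 'bogus') -> 'options.nested.bogus'.
        flattened_names = []
        for item in unsupported_parameters:
            if isinstance(item, tuple):
                flattened_names.append('.'.join(item))
            else:
                flattened_names.append(item)
        return sorted(flattened_names)

    print(_flatten_unsupported({'extra', ('options', 'nested', 'bogus')}))
    # ['extra', 'options.nested.bogus']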

* Build path to the unsupported parameter for error messages.

* Remove unused import

* Update recursive finder test

* Skip if running in check mode

This was done in the _check_arguments() method. That was moved to a function that has no
way of calling fail_json(), so it must be done outside of validation.

This is a slight change in behavior, but I believe the correct one.

Previously, only unsupported parameters would cause a failure. All other checks would not be executed
if the module did not support check mode. This would hide validation failures in check mode.

* The great purge

Remove all methods related to argument spec validation from AnsibleModule

* Keep _name and kind in the caller and out of the validator

This seems a bit awkward since the caller could end up with {name} and {kind} in
the error message if they don't run the messages through the .format() method
with name and kind parameters.

* Double moustaches work

I wasn't sure if they get stripped or not. Looks like they do. Neat trick.
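
A quick illustration of the trick (plain str.format(), nothing Ansible-specific; the final
code ends up keeping name and kind in the caller, as noted above):

    # Doubled braces survive the first .format() call as literal placeholders...
    msg = "Unsupported parameters for ({{name}}) {{kind}}: {0}".format("foo, bar")
    print(msg)
    # Unsupported parameters for ({name}) {kind}: foo, bar

    # ...so the caller can fill them in later.
    print(msg.format(name='ping', kind='module'))
    # Unsupported parameters for (ping) module: foo, bar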

* Add changelog

* Update unsupported parameter test

The error message changed to include name and kind.

* Remove unused import

* Add better documentation for ArgumentSpecValidator class

* Fix example

* Few more docs fixes

* Mark required and mutually exclusive attributes as private

* Mark validate functions as private

* Reorganize functions in validation.py

* Remove unused imports in basic.py related to argument spec validation

* Create errors in module_utils

We have errors in lib/ansible/errors/ but those cannot be used by modules.

* Update recursive finder test

* Move errors to file rather than __init__.py

* Change ArgumentSpecValidator.validate() interface

Raise AnsibleValidationErrorMultiple on validation errors; it contains all of the AnsibleValidationError
exceptions raised for the individual validation failures.

Return the validated parameters if validation is successful rather than True/False.

Update docs and tests.

* Get attribute in loop so that the attribute name can also be used as a parameter

* Shorten line

* Update calling code in AnsibleModule for new validator interface

* Update calling code in validate_argument_spec based on the new validation interface

* Base custom exception class off of Exception

* Call the __init__ method of the base Exception class to populate args

* Ensure no_log values are always updated

* Make custom exceptions more hierarchical

This redefines AnsibleError from lib/ansible/errors with a different signature since that cannot
be used by modules. This may be a bad idea. Maybe lib/ansible/errors should be moved to
module_utils, or AnsibleError defined in this commit should use the same signature as the original.

* Just go back to basing off Exception

* Return ValidationResult object on successful validation

Create a ValidationResult class.
Return a ValidationResult from ArgumentSpecValidator.validate() when validation is successful.
Update class and method docs.
Update unit tests based on interface change.
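
The resulting usage looks roughly like this (adapted from the docstring example in the diff;
the argument_spec is made up):

    import sys

    from ansible.module_utils.common.arg_spec import ArgumentSpecValidator

    argument_spec = {
        'name': {'type': 'str'},
        'count': {'type': 'int', 'default': 1},
    }

    validator = ArgumentSpecValidator(argument_spec)
    result = validator.validate({'name': 'test', 'count': '3'})

    if result.error_messages:
        sys.exit("Validation failed: {0}".format(", ".join(result.error_messages)))

    valid_params = result.validated_parameters   # {'name': 'test', 'count': 3}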

* Make it easier to get error objects from AnsibleValidationErrorMultiple

This makes the interface cleaner when getting individual error objects contained in a single
AnsibleValidationErrorMultiple instance.

* Define custom exception for each type of validation failure

These errors indicate where a validation error occurred. Currently they are empty but could
contain specific data for each exception type in the future.
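
A short sketch of how a caller can pull out and type-check individual errors (this mirrors
what AnsibleModule does in the diff below; the argument_spec is made up and the printed
message is approximate):

    from ansible.module_utils.common.arg_spec import ArgumentSpecValidator
    from ansible.module_utils.errors import UnsupportedError

    validator = ArgumentSpecValidator({'name': {'type': 'str'}})
    result = validator.validate({'name': 'test', 'bogus': True})

    # AnsibleValidationErrorMultiple supports indexing, so individual error objects
    # can be retrieved and inspected by type.
    try:
        error = result.errors[0]
    except IndexError:
        error = None

    if isinstance(error, UnsupportedError):
        print(result.errors.msg)   # roughly: "bogus. Supported parameters include: name."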

* Update tests based on (yet another) interface change

* Mark several more functions as private

These are all doing rather "internal" things. The ArgumentSpecValidator class is the preferred
public interface.

* Move warnings and deprecations to result object

Rather than calling deprecate() and warn() directly, store them on the result object so the
caller can decide what to do with them.

* Use subclass for module arg spec validation

The subclass uses the global warning and deprecation feature.
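
Reassembled for readability, the subclass amounts to roughly this (the imports and docstring
are added here for self-containment; the method body matches the hunk in the diff below):

    from ansible.module_utils.common.arg_spec import ArgumentSpecValidator
    from ansible.module_utils.common.warnings import deprecate, warn

    class ModuleArgumentSpecValidator(ArgumentSpecValidator):
        """Validator used by AnsibleModule; reports alias warnings and deprecations globally."""

        def validate(self, parameters):
            result = super(ModuleArgumentSpecValidator, self).validate(parameters)

            # The base class only records these on the result; the subclass forwards
            # them to the global warning and deprecation stores.
            for d in result._deprecations:
                deprecate("Alias '{name}' is deprecated. See the module docs for more information".format(name=d['name']),
                          version=d.get('version'), date=d.get('date'),
                          collection_name=d.get('collection_name'))

            for w in result._warnings:
                warn('Both option {option} and its alias {alias} are set.'.format(option=w['option'], alias=w['alias']))

            return result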

* Fix up docs

* Remove legal_inputs munging from _handle_aliases()

This is done in AnsibleModule by the _set_internal_properties() method. It only makes sense
to do that for an AnsibleModule instance (it should update the parameters before performing
validation) and shouldn't be done by the validator.

Create a private function just for getting legal inputs since that is done in a couple of places.

It may make sense to store that on the ValidationResult object.
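
For reference, the private helper in the diff is just the following (the comment is added here;
internal '_ansible_*' keys are no longer added to the legal inputs because AnsibleModule handles
them in _set_internal_properties() before validation, as described above):

    def _get_legal_inputs(argument_spec, parameters, aliases=None):
        # Legal inputs are the parameter names from the spec plus any aliases.
        if aliases is None:
            aliases = _handle_aliases(argument_spec, parameters)

        return list(aliases.keys()) + list(argument_spec.keys())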

* Increase test coverage

* Remove unnecessary conditional

ci_complete

* Mark warnings and deprecations as private in the ValidationResult

They can be made public once we come up with a way to make them more generally useful,
probably by creating custom objects to store the data in a more structured way.

* Mark valid_parameter_names as private and populate it during initialization

* Use a global for storing the list of additional checks to perform

This list is used by the main validate method as well as the sub spec validation.

@ -0,0 +1,5 @@
major_changes:
- >-
    AnsibleModule - use ``ArgumentSpecValidator`` class for validating argument spec and remove
    private methods related to argument spec validation. Any modules using private methods
    should now use the ``ArgumentSpecValidator`` class or the appropriate validation function.

@ -90,6 +90,8 @@ from ansible.module_utils.common.text.converters import (
container_to_text as json_dict_bytes_to_unicode,
)
+from ansible.module_utils.common.arg_spec import ModuleArgumentSpecValidator
from ansible.module_utils.common.text.formatters import (
lenient_lowercase,
bytes_to_human,
@ -155,25 +157,15 @@ from ansible.module_utils.common.sys_info import (
)
from ansible.module_utils.pycompat24 import get_exception, literal_eval
from ansible.module_utils.common.parameters import (
-_remove_values_conditions,
-_sanitize_keys_conditions,
-sanitize_keys,
env_fallback,
-get_unsupported_parameters,
-get_type_validator,
-handle_aliases,
-list_deprecations,
-list_no_log_values,
remove_values,
-set_defaults,
+sanitize_keys,
-set_fallbacks,
-validate_argument_types,
-AnsibleFallbackNotFound,
DEFAULT_TYPE_VALIDATORS,
PASS_VARS,
PASS_BOOLS,
)
+from ansible.module_utils.errors import AnsibleFallbackNotFound, AnsibleValidationErrorMultiple, UnsupportedError
from ansible.module_utils.six import (
PY2,
PY3,
@ -187,24 +179,6 @@ from ansible.module_utils.six import (
from ansible.module_utils.six.moves import map, reduce, shlex_quote
from ansible.module_utils.common.validation import (
check_missing_parameters,
-check_mutually_exclusive,
-check_required_arguments,
-check_required_by,
-check_required_if,
-check_required_one_of,
-check_required_together,
-count_terms,
-check_type_bool,
-check_type_bits,
-check_type_bytes,
-check_type_float,
-check_type_int,
-check_type_jsonarg,
-check_type_list,
-check_type_dict,
-check_type_path,
-check_type_raw,
-check_type_str,
safe_eval,
)
from ansible.module_utils.common._utils import get_all_subclasses as _get_all_subclasses
@ -507,48 +481,43 @@ class AnsibleModule(object):
# Save parameter values that should never be logged
self.no_log_values = set()
-self._load_params()
-self._set_fallbacks()
-# append to legal_inputs and then possibly check against them
-try:
-self.aliases = self._handle_aliases()
-except (ValueError, TypeError) as e:
-# Use exceptions here because it isn't safe to call fail_json until no_log is processed
-print('\n{"failed": true, "msg": "Module alias error: %s"}' % to_native(e))
-sys.exit(1)
-self._handle_no_log_values()
# check the locale as set by the current environment, and reset to
# a known valid (LANG=C) if it's an invalid/unavailable locale
self._check_locale()
+self._load_params()
self._set_internal_properties()
-self._check_arguments()
-# check exclusive early
-if not bypass_checks:
-self._check_mutually_exclusive(mutually_exclusive)
-self._set_defaults(pre=True)
-# This is for backwards compatibility only.
-self._CHECK_ARGUMENT_TYPES_DISPATCHER = DEFAULT_TYPE_VALIDATORS
-if not bypass_checks:
-self._check_required_arguments()
-self._check_argument_types()
-self._check_argument_values()
-self._check_required_together(required_together)
-self._check_required_one_of(required_one_of)
-self._check_required_if(required_if)
-self._check_required_by(required_by)
-self._set_defaults(pre=False)
-# deal with options sub-spec
-self._handle_options()
+self.validator = ModuleArgumentSpecValidator(self.argument_spec,
+self.mutually_exclusive,
+self.required_together,
+self.required_one_of,
+self.required_if,
+self.required_by,
+)
+self.validation_result = self.validator.validate(self.params)
+self.params.update(self.validation_result.validated_parameters)
+self.no_log_values.update(self.validation_result._no_log_values)
+try:
+error = self.validation_result.errors[0]
+except IndexError:
+error = None
+# Fail for validation errors, even in check mode
+if error:
+msg = self.validation_result.errors.msg
+if isinstance(error, UnsupportedError):
+msg = "Unsupported parameters for ({name}) {kind}: {msg}".format(name=self._name, kind='module', msg=msg)
+self.fail_json(msg=msg)
+if self.check_mode and not self.supports_check_mode:
+self.exit_json(skipped=True, msg="remote module (%s) does not support check mode" % self._name)
+# This is for backwards compatibility only.
+self._CHECK_ARGUMENT_TYPES_DISPATCHER = DEFAULT_TYPE_VALIDATORS
if not self.no_log:
self._log_invocation()
@ -1274,42 +1243,6 @@ class AnsibleModule(object):
self.fail_json(msg="An unknown error was encountered while attempting to validate the locale: %s" % self.fail_json(msg="An unknown error was encountered while attempting to validate the locale: %s" %
to_native(e), exception=traceback.format_exc()) to_native(e), exception=traceback.format_exc())
def _handle_aliases(self, spec=None, param=None, option_prefix=''):
if spec is None:
spec = self.argument_spec
if param is None:
param = self.params
# this uses exceptions as it happens before we can safely call fail_json
alias_warnings = []
alias_deprecations = []
alias_results, self._legal_inputs = handle_aliases(spec, param, alias_warnings, alias_deprecations)
for option, alias in alias_warnings:
warn('Both option %s and its alias %s are set.' % (option_prefix + option, option_prefix + alias))
for deprecation in alias_deprecations:
deprecate("Alias '%s' is deprecated. See the module docs for more information" % deprecation['name'],
version=deprecation.get('version'), date=deprecation.get('date'),
collection_name=deprecation.get('collection_name'))
return alias_results
def _handle_no_log_values(self, spec=None, param=None):
if spec is None:
spec = self.argument_spec
if param is None:
param = self.params
try:
self.no_log_values.update(list_no_log_values(spec, param))
except TypeError as te:
self.fail_json(msg="Failure when processing no_log parameters. Module invocation will be hidden. "
"%s" % to_native(te), invocation={'module_args': 'HIDDEN DUE TO FAILURE'})
for message in list_deprecations(spec, param):
deprecate(message['msg'], version=message.get('version'), date=message.get('date'),
collection_name=message.get('collection_name'))
def _set_internal_properties(self, argument_spec=None, module_parameters=None):
if argument_spec is None:
argument_spec = self.argument_spec
@ -1333,344 +1266,9 @@ class AnsibleModule(object):
if not hasattr(self, PASS_VARS[k][0]):
setattr(self, PASS_VARS[k][0], PASS_VARS[k][1])
def _check_arguments(self, spec=None, param=None, legal_inputs=None):
unsupported_parameters = set()
if spec is None:
spec = self.argument_spec
if param is None:
param = self.params
if legal_inputs is None:
legal_inputs = self._legal_inputs
unsupported_parameters = get_unsupported_parameters(spec, param, legal_inputs)
if unsupported_parameters:
msg = "Unsupported parameters for (%s) module: %s" % (self._name, ', '.join(sorted(list(unsupported_parameters))))
if self._options_context:
msg += " found in %s." % " -> ".join(self._options_context)
supported_parameters = list()
for key in sorted(spec.keys()):
if 'aliases' in spec[key] and spec[key]['aliases']:
supported_parameters.append("%s (%s)" % (key, ', '.join(sorted(spec[key]['aliases']))))
else:
supported_parameters.append(key)
msg += " Supported parameters include: %s" % (', '.join(supported_parameters))
self.fail_json(msg=msg)
if self.check_mode and not self.supports_check_mode:
self.exit_json(skipped=True, msg="remote module (%s) does not support check mode" % self._name)
def _count_terms(self, check, param=None):
if param is None:
param = self.params
return count_terms(check, param)
def _check_mutually_exclusive(self, spec, param=None):
if param is None:
param = self.params
try:
check_mutually_exclusive(spec, param)
except TypeError as e:
msg = to_native(e)
if self._options_context:
msg += " found in %s" % " -> ".join(self._options_context)
self.fail_json(msg=msg)
def _check_required_one_of(self, spec, param=None):
if spec is None:
return
if param is None:
param = self.params
try:
check_required_one_of(spec, param)
except TypeError as e:
msg = to_native(e)
if self._options_context:
msg += " found in %s" % " -> ".join(self._options_context)
self.fail_json(msg=msg)
def _check_required_together(self, spec, param=None):
if spec is None:
return
if param is None:
param = self.params
try:
check_required_together(spec, param)
except TypeError as e:
msg = to_native(e)
if self._options_context:
msg += " found in %s" % " -> ".join(self._options_context)
self.fail_json(msg=msg)
def _check_required_by(self, spec, param=None):
if spec is None:
return
if param is None:
param = self.params
try:
check_required_by(spec, param)
except TypeError as e:
self.fail_json(msg=to_native(e))
def _check_required_arguments(self, spec=None, param=None):
if spec is None:
spec = self.argument_spec
if param is None:
param = self.params
try:
check_required_arguments(spec, param)
except TypeError as e:
msg = to_native(e)
if self._options_context:
msg += " found in %s" % " -> ".join(self._options_context)
self.fail_json(msg=msg)
def _check_required_if(self, spec, param=None):
''' ensure that parameters which conditionally required are present '''
if spec is None:
return
if param is None:
param = self.params
try:
check_required_if(spec, param)
except TypeError as e:
msg = to_native(e)
if self._options_context:
msg += " found in %s" % " -> ".join(self._options_context)
self.fail_json(msg=msg)
def _check_argument_values(self, spec=None, param=None):
''' ensure all arguments have the requested values, and there are no stray arguments '''
if spec is None:
spec = self.argument_spec
if param is None:
param = self.params
for (k, v) in spec.items():
choices = v.get('choices', None)
if choices is None:
continue
if isinstance(choices, SEQUENCETYPE) and not isinstance(choices, (binary_type, text_type)):
if k in param:
# Allow one or more when type='list' param with choices
if isinstance(param[k], list):
diff_list = ", ".join([item for item in param[k] if item not in choices])
if diff_list:
choices_str = ", ".join([to_native(c) for c in choices])
msg = "value of %s must be one or more of: %s. Got no match for: %s" % (k, choices_str, diff_list)
if self._options_context:
msg += " found in %s" % " -> ".join(self._options_context)
self.fail_json(msg=msg)
elif param[k] not in choices:
# PyYaml converts certain strings to bools. If we can unambiguously convert back, do so before checking
# the value. If we can't figure this out, module author is responsible.
lowered_choices = None
if param[k] == 'False':
lowered_choices = lenient_lowercase(choices)
overlap = BOOLEANS_FALSE.intersection(choices)
if len(overlap) == 1:
# Extract from a set
(param[k],) = overlap
if param[k] == 'True':
if lowered_choices is None:
lowered_choices = lenient_lowercase(choices)
overlap = BOOLEANS_TRUE.intersection(choices)
if len(overlap) == 1:
(param[k],) = overlap
if param[k] not in choices:
choices_str = ", ".join([to_native(c) for c in choices])
msg = "value of %s must be one of: %s, got: %s" % (k, choices_str, param[k])
if self._options_context:
msg += " found in %s" % " -> ".join(self._options_context)
self.fail_json(msg=msg)
else:
msg = "internal error: choices for argument %s are not iterable: %s" % (k, choices)
if self._options_context:
msg += " found in %s" % " -> ".join(self._options_context)
self.fail_json(msg=msg)
def safe_eval(self, value, locals=None, include_exceptions=False):
return safe_eval(value, locals, include_exceptions)
def _check_type_str(self, value, param=None, prefix=''):
opts = {
'error': False,
'warn': False,
'ignore': True
}
# Ignore, warn, or error when converting to a string.
allow_conversion = opts.get(self._string_conversion_action, True)
try:
return check_type_str(value, allow_conversion)
except TypeError:
common_msg = 'quote the entire value to ensure it does not change.'
from_msg = '{0!r}'.format(value)
to_msg = '{0!r}'.format(to_text(value))
if param is not None:
if prefix:
param = '{0}{1}'.format(prefix, param)
from_msg = '{0}: {1!r}'.format(param, value)
to_msg = '{0}: {1!r}'.format(param, to_text(value))
if self._string_conversion_action == 'error':
msg = common_msg.capitalize()
raise TypeError(to_native(msg))
elif self._string_conversion_action == 'warn':
msg = ('The value "{0}" (type {1.__class__.__name__}) was converted to "{2}" (type string). '
'If this does not look like what you expect, {3}').format(from_msg, value, to_msg, common_msg)
self.warn(to_native(msg))
return to_native(value, errors='surrogate_or_strict')
def _check_type_list(self, value):
return check_type_list(value)
def _check_type_dict(self, value):
return check_type_dict(value)
def _check_type_bool(self, value):
return check_type_bool(value)
def _check_type_int(self, value):
return check_type_int(value)
def _check_type_float(self, value):
return check_type_float(value)
def _check_type_path(self, value):
return check_type_path(value)
def _check_type_jsonarg(self, value):
return check_type_jsonarg(value)
def _check_type_raw(self, value):
return check_type_raw(value)
def _check_type_bytes(self, value):
return check_type_bytes(value)
def _check_type_bits(self, value):
return check_type_bits(value)
def _handle_options(self, argument_spec=None, params=None, prefix=''):
''' deal with options to create sub spec '''
if argument_spec is None:
argument_spec = self.argument_spec
if params is None:
params = self.params
for (k, v) in argument_spec.items():
wanted = v.get('type', None)
if wanted == 'dict' or (wanted == 'list' and v.get('elements', '') == 'dict'):
spec = v.get('options', None)
if v.get('apply_defaults', False):
if spec is not None:
if params.get(k) is None:
params[k] = {}
else:
continue
elif spec is None or k not in params or params[k] is None:
continue
self._options_context.append(k)
if isinstance(params[k], dict):
elements = [params[k]]
else:
elements = params[k]
for idx, param in enumerate(elements):
if not isinstance(param, dict):
self.fail_json(msg="value of %s must be of type dict or list of dict" % k)
new_prefix = prefix + k
if wanted == 'list':
new_prefix += '[%d]' % idx
new_prefix += '.'
self._set_fallbacks(spec, param)
options_aliases = self._handle_aliases(spec, param, option_prefix=new_prefix)
options_legal_inputs = list(spec.keys()) + list(options_aliases.keys())
self._check_arguments(spec, param, options_legal_inputs)
# check exclusive early
if not self.bypass_checks:
self._check_mutually_exclusive(v.get('mutually_exclusive', None), param)
self._set_defaults(pre=True, spec=spec, param=param)
if not self.bypass_checks:
self._check_required_arguments(spec, param)
self._check_argument_types(spec, param, new_prefix)
self._check_argument_values(spec, param)
self._check_required_together(v.get('required_together', None), param)
self._check_required_one_of(v.get('required_one_of', None), param)
self._check_required_if(v.get('required_if', None), param)
self._check_required_by(v.get('required_by', None), param)
self._set_defaults(pre=False, spec=spec, param=param)
# handle multi level options (sub argspec)
self._handle_options(spec, param, new_prefix)
self._options_context.pop()
def _get_wanted_type(self, wanted, k):
# Use the private method for 'str' type to handle the string conversion warning.
if wanted == 'str':
type_checker, wanted = self._check_type_str, 'str'
else:
type_checker, wanted = get_type_validator(wanted)
if type_checker is None:
self.fail_json(msg="implementation error: unknown type %s requested for %s" % (wanted, k))
return type_checker, wanted
def _check_argument_types(self, spec=None, param=None, prefix=''):
''' ensure all arguments have the requested type '''
if spec is None:
spec = self.argument_spec
if param is None:
param = self.params
errors = []
validate_argument_types(spec, param, errors=errors)
if errors:
self.fail_json(msg=errors[0])
def _set_defaults(self, pre=True, spec=None, param=None):
if spec is None:
spec = self.argument_spec
if param is None:
param = self.params
# The interface for set_defaults is different than _set_defaults()
# The third parameter controls whether or not defaults are actually set.
set_default = not pre
self.no_log_values.update(set_defaults(spec, param, set_default))
def _set_fallbacks(self, spec=None, param=None):
if spec is None:
spec = self.argument_spec
if param is None:
param = self.params
self.no_log_values.update(set_fallbacks(spec, param))
def _load_params(self):
''' read the input and set the params attribute.

@ -5,71 +5,146 @@
from __future__ import absolute_import, division, print_function
__metaclass__ = type
from copy import deepcopy
from ansible.module_utils.common._collections_compat import (
Sequence,
)
from ansible.module_utils.common.parameters import ( from ansible.module_utils.common.parameters import (
get_unsupported_parameters, _ADDITIONAL_CHECKS,
handle_aliases, _get_legal_inputs,
list_no_log_values, _get_unsupported_parameters,
remove_values, _handle_aliases,
set_defaults, _list_no_log_values,
_set_defaults,
_validate_argument_types,
_validate_argument_values,
_validate_sub_spec,
set_fallbacks, set_fallbacks,
validate_argument_types,
validate_argument_values,
validate_sub_spec,
) )
from ansible.module_utils.common.text.converters import to_native
from ansible.module_utils.common.warnings import deprecate, warn
from ansible.module_utils.common.validation import (
check_mutually_exclusive,
check_required_arguments,
check_required_by,
check_required_if,
check_required_one_of,
check_required_together,
)
from ansible.module_utils.errors import (
AliasError,
AnsibleValidationErrorMultiple,
MutuallyExclusiveError,
NoLogError,
RequiredByError,
RequiredDefaultError,
RequiredError,
RequiredIfError,
RequiredOneOfError,
RequiredTogetherError,
UnsupportedError,
) )
from ansible.module_utils.six import string_types
+class ValidationResult:
+"""Result of argument spec validation.
+:param parameters: Terms to be validated and coerced to the correct type.
+:type parameters: dict
+"""
+def __init__(self, parameters):
+self._no_log_values = set()
+self._unsupported_parameters = set()
+self._validated_parameters = deepcopy(parameters)
+self._deprecations = []
+self._warnings = []
+self.errors = AnsibleValidationErrorMultiple()
+@property
+def validated_parameters(self):
+return self._validated_parameters
+@property
+def unsupported_parameters(self):
+return self._unsupported_parameters
+@property
+def error_messages(self):
+return self.errors.messages
-class ArgumentSpecValidator():
-"""Argument spec validation class"""
-def __init__(self, argument_spec, parameters):
-self._error_messages = []
-self._no_log_values = set()
-self.argument_spec = argument_spec
-# Make a copy of the original parameters to avoid changing them
-self._validated_parameters = deepcopy(parameters)
-self._unsupported_parameters = set()
-@property
-def error_messages(self):
-return self._error_messages
-@property
-def validated_parameters(self):
-return self._validated_parameters
-def _add_error(self, error):
-if isinstance(error, string_types):
-self._error_messages.append(error)
-elif isinstance(error, Sequence):
-self._error_messages.extend(error)
+class ArgumentSpecValidator:
+"""Argument spec validation class
Creates a validator based on the ``argument_spec`` that can be used to
validate a number of parameters using the ``validate()`` method.
:param argument_spec: Specification of valid parameters and their type. May
include nested argument specs.
:type argument_spec: dict
:param mutually_exclusive: List or list of lists of terms that should not
be provided together.
:type mutually_exclusive: list, optional
:param required_together: List of lists of terms that are required together.
:type required_together: list, optional
:param required_one_of: List of lists of terms, one of which in each list
is required.
:type required_one_of: list, optional
:param required_if: List of lists of ``[parameter, value, [parameters]]`` where
one of [parameters] is required if ``parameter`` == ``value``.
:type required_if: list, optional
:param required_by: Dictionary of parameter names that contain a list of
parameters required by each key in the dictionary.
:type required_by: dict, optional
"""
def __init__(self, argument_spec,
mutually_exclusive=None,
required_together=None,
required_one_of=None,
required_if=None,
required_by=None,
):
self._mutually_exclusive = mutually_exclusive
self._required_together = required_together
self._required_one_of = required_one_of
self._required_if = required_if
self._required_by = required_by
self._valid_parameter_names = set()
self.argument_spec = argument_spec
for key in sorted(self.argument_spec.keys()):
aliases = self.argument_spec[key].get('aliases')
if aliases:
self._valid_parameter_names.update(["{key} ({aliases})".format(key=key, aliases=", ".join(sorted(aliases)))])
else: else:
raise ValueError('Error messages must be a string or sequence not a %s' % type(error)) self._valid_parameter_names.update([key])
def _sanitize_error_messages(self): def validate(self, parameters, *args, **kwargs):
self._error_messages = remove_values(self._error_messages, self._no_log_values) """Validate module parameters against argument spec. Returns a
ValidationResult object.
def validate(self, *args, **kwargs): Error messages in the ValidationResult may contain no_log values and should be
"""Validate module parameters against argument spec. sanitized before logging or displaying.
:Example: :Example:
validator = ArgumentSpecValidator(argument_spec, parameters) validator = ArgumentSpecValidator(argument_spec)
passeded = validator.validate() result = validator.validate(parameters)
if result.error_messages:
sys.exit("Validation failed: {0}".format(", ".join(result.error_messages))
valid_params = result.validated_parameters
:param argument_spec: Specification of parameters, type, and valid values :param argument_spec: Specification of parameters, type, and valid values
:type argument_spec: dict :type argument_spec: dict
@ -77,58 +152,104 @@ class ArgumentSpecValidator():
:param parameters: Parameters provided to the role
:type parameters: dict
:returns: True if no errors were encountered, False if any errors were encountered. :return: Object containing validated parameters.
:rtype: bool :rtype: ValidationResult
""" """
self._no_log_values.update(set_fallbacks(self.argument_spec, self._validated_parameters)) result = ValidationResult(parameters)
result._no_log_values.update(set_fallbacks(self.argument_spec, result._validated_parameters))
alias_warnings = []
alias_deprecations = []
try:
alias_results, legal_inputs = handle_aliases(self.argument_spec, self._validated_parameters, alias_warnings, alias_deprecations) aliases = _handle_aliases(self.argument_spec, result._validated_parameters, alias_warnings, alias_deprecations)
except (TypeError, ValueError) as e: except (TypeError, ValueError) as e:
alias_results = {} aliases = {}
legal_inputs = None result.errors.append(AliasError(to_native(e)))
self._add_error(to_native(e))
legal_inputs = _get_legal_inputs(self.argument_spec, result._validated_parameters, aliases)
for option, alias in alias_warnings: for option, alias in alias_warnings:
warn('Both option %s and its alias %s are set.' % (option, alias)) result._warnings.append({'option': option, 'alias': alias})
for deprecation in alias_deprecations: for deprecation in alias_deprecations:
deprecate("Alias '%s' is deprecated. See the module docs for more information" % deprecation['name'], result._deprecations.append({
version=deprecation.get('version'), date=deprecation.get('date'), 'name': deprecation['name'],
collection_name=deprecation.get('collection_name')) 'version': deprecation.get('version'),
'date': deprecation.get('date'),
'collection_name': deprecation.get('collection_name'),
})
try:
result._no_log_values.update(_list_no_log_values(self.argument_spec, result._validated_parameters))
except TypeError as te:
result.errors.append(NoLogError(to_native(te)))
self._no_log_values.update(list_no_log_values(self.argument_spec, self._validated_parameters)) try:
result._unsupported_parameters.update(_get_unsupported_parameters(self.argument_spec, result._validated_parameters, legal_inputs))
except TypeError as te:
result.errors.append(RequiredDefaultError(to_native(te)))
except ValueError as ve:
result.errors.append(AliasError(to_native(ve)))
if legal_inputs is None: try:
legal_inputs = list(alias_results.keys()) + list(self.argument_spec.keys()) check_mutually_exclusive(self._mutually_exclusive, result._validated_parameters)
self._unsupported_parameters.update(get_unsupported_parameters(self.argument_spec, self._validated_parameters, legal_inputs)) except TypeError as te:
result.errors.append(MutuallyExclusiveError(to_native(te)))
self._no_log_values.update(set_defaults(self.argument_spec, self._validated_parameters, False)) result._no_log_values.update(_set_defaults(self.argument_spec, result._validated_parameters, False))
try: try:
check_required_arguments(self.argument_spec, self._validated_parameters) check_required_arguments(self.argument_spec, result._validated_parameters)
except TypeError as e: except TypeError as e:
self._add_error(to_native(e)) result.errors.append(RequiredError(to_native(e)))
_validate_argument_types(self.argument_spec, result._validated_parameters, errors=result.errors)
_validate_argument_values(self.argument_spec, result._validated_parameters, errors=result.errors)
validate_argument_types(self.argument_spec, self._validated_parameters, errors=self._error_messages) for check in _ADDITIONAL_CHECKS:
validate_argument_values(self.argument_spec, self._validated_parameters, errors=self._error_messages) try:
check['func'](getattr(self, "_{attr}".format(attr=check['attr'])), result._validated_parameters)
except TypeError as te:
result.errors.append(check['err'](to_native(te)))
result._no_log_values.update(_set_defaults(self.argument_spec, result._validated_parameters))
_validate_sub_spec(self.argument_spec, result._validated_parameters,
errors=result.errors,
no_log_values=result._no_log_values,
unsupported_parameters=result._unsupported_parameters)
if result._unsupported_parameters:
flattened_names = []
for item in result._unsupported_parameters:
if isinstance(item, tuple):
flattened_names.append(".".join(item))
else:
flattened_names.append(item)
self._no_log_values.update(set_defaults(self.argument_spec, self._validated_parameters)) unsupported_string = ", ".join(sorted(list(flattened_names)))
supported_string = ", ".join(self._valid_parameter_names)
result.errors.append(
UnsupportedError("{0}. Supported parameters include: {1}.".format(unsupported_string, supported_string)))
validate_sub_spec(self.argument_spec, self._validated_parameters, return result
errors=self._error_messages,
no_log_values=self._no_log_values,
unsupported_parameters=self._unsupported_parameters)
if self._unsupported_parameters:
self._add_error('Unsupported parameters: %s' % ', '.join(sorted(list(self._unsupported_parameters))))
self._sanitize_error_messages() class ModuleArgumentSpecValidator(ArgumentSpecValidator):
def __init__(self, *args, **kwargs):
super(ModuleArgumentSpecValidator, self).__init__(*args, **kwargs)
if self.error_messages: def validate(self, parameters):
return False result = super(ModuleArgumentSpecValidator, self).validate(parameters)
else:
return True for d in result._deprecations:
deprecate("Alias '{name}' is deprecated. See the module docs for more information".format(name=d['name']),
version=d.get('version'), date=d.get('date'),
collection_name=d.get('collection_name'))
for w in result._warnings:
warn('Both option {option} and its alias {alias} are set.'.format(option=w['option'], alias=w['alias']))
return result

@ -15,6 +15,22 @@ from ansible.module_utils.common.collections import is_iterable
from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text
from ansible.module_utils.common.text.formatters import lenient_lowercase
from ansible.module_utils.common.warnings import warn
from ansible.module_utils.errors import (
AliasError,
AnsibleFallbackNotFound,
AnsibleValidationErrorMultiple,
ArgumentTypeError,
ArgumentValueError,
ElementError,
MutuallyExclusiveError,
NoLogError,
RequiredByError,
RequiredError,
RequiredIfError,
RequiredOneOfError,
RequiredTogetherError,
SubParameterTypeError,
)
from ansible.module_utils.parsing.convert_bool import BOOLEANS_FALSE, BOOLEANS_TRUE
from ansible.module_utils.common._collections_compat import (
@ -59,6 +75,13 @@ from ansible.module_utils.common.validation import (
# Python2 & 3 way to get NoneType # Python2 & 3 way to get NoneType
NoneType = type(None) NoneType = type(None)
_ADDITIONAL_CHECKS = (
{'func': check_required_together, 'attr': 'required_together', 'err': RequiredTogetherError},
{'func': check_required_one_of, 'attr': 'required_one_of', 'err': RequiredOneOfError},
{'func': check_required_if, 'attr': 'required_if', 'err': RequiredIfError},
{'func': check_required_by, 'attr': 'required_by', 'err': RequiredByError},
)
# if adding boolean attribute, also add to PASS_BOOL
# some of this dupes defaults from controller config
PASS_VARS = {
@ -97,8 +120,221 @@ DEFAULT_TYPE_VALIDATORS = {
}
-class AnsibleFallbackNotFound(Exception):
-pass
+def _get_type_validator(wanted):
+"""Returns the callable used to validate a wanted type and the type name.
:arg wanted: String or callable. If a string, get the corresponding
validation function from DEFAULT_TYPE_VALIDATORS. If callable,
get the name of the custom callable and return that for the type_checker.
:returns: Tuple of callable function or None, and a string that is the name
of the wanted type.
"""
# Use one of our builtin validators.
if not callable(wanted):
if wanted is None:
# Default type for parameters
wanted = 'str'
type_checker = DEFAULT_TYPE_VALIDATORS.get(wanted)
# Use the custom callable for validation.
else:
type_checker = wanted
wanted = getattr(wanted, '__name__', to_native(type(wanted)))
return type_checker, wanted
def _get_legal_inputs(argument_spec, parameters, aliases=None):
if aliases is None:
aliases = _handle_aliases(argument_spec, parameters)
return list(aliases.keys()) + list(argument_spec.keys())
def _get_unsupported_parameters(argument_spec, parameters, legal_inputs=None, options_context=None):
"""Check keys in parameters against those provided in legal_inputs
to ensure they contain legal values. If legal_inputs are not supplied,
they will be generated using the argument_spec.
:arg argument_spec: Dictionary of parameters, their type, and valid values.
:arg parameters: Dictionary of parameters.
:arg legal_inputs: List of valid key names property names. Overrides values
in argument_spec.
:arg options_context: List of parent keys for tracking the context of where
a parameter is defined.
:returns: Set of unsupported parameters. Empty set if no unsupported parameters
are found.
"""
if legal_inputs is None:
legal_inputs = _get_legal_inputs(argument_spec, parameters)
unsupported_parameters = set()
for k in parameters.keys():
if k not in legal_inputs:
context = k
if options_context:
context = tuple(options_context + [k])
unsupported_parameters.add(context)
return unsupported_parameters
def _handle_aliases(argument_spec, parameters, alias_warnings=None, alias_deprecations=None):
"""Process aliases from an argument_spec including warnings and deprecations.
Modify ``parameters`` by adding a new key for each alias with the supplied
value from ``parameters``.
If a list is provided to the alias_warnings parameter, it will be filled with tuples
(option, alias) in every case where both an option and its alias are specified.
If a list is provided to alias_deprecations, it will be populated with dictionaries,
each containing deprecation information for each alias found in argument_spec.
:param argument_spec: Dictionary of parameters, their type, and valid values.
:type argument_spec: dict
:param parameters: Dictionary of parameters.
:type parameters: dict
:param alias_warnings:
:type alias_warnings: list
:param alias_deprecations:
:type alias_deprecations: list
"""
aliases_results = {} # alias:canon
for (k, v) in argument_spec.items():
aliases = v.get('aliases', None)
default = v.get('default', None)
required = v.get('required', False)
if alias_deprecations is not None:
for alias in argument_spec[k].get('deprecated_aliases', []):
if alias.get('name') in parameters:
alias_deprecations.append(alias)
if default is not None and required:
# not alias specific but this is a good place to check this
raise ValueError("internal error: required and default are mutually exclusive for %s" % k)
if aliases is None:
continue
if not is_iterable(aliases) or isinstance(aliases, (binary_type, text_type)):
raise TypeError('internal error: aliases must be a list or tuple')
for alias in aliases:
aliases_results[alias] = k
if alias in parameters:
if k in parameters and alias_warnings is not None:
alias_warnings.append((k, alias))
parameters[k] = parameters[alias]
return aliases_results
def _list_deprecations(argument_spec, parameters, prefix=''):
"""Return a list of deprecations
:arg argument_spec: An argument spec dictionary
:arg parameters: Dictionary of parameters
:returns: List of dictionaries containing a message and version in which
the deprecated parameter will be removed, or an empty list::
[{'msg': "Param 'deptest' is deprecated. See the module docs for more information", 'version': '2.9'}]
"""
deprecations = []
for arg_name, arg_opts in argument_spec.items():
if arg_name in parameters:
if prefix:
sub_prefix = '%s["%s"]' % (prefix, arg_name)
else:
sub_prefix = arg_name
if arg_opts.get('removed_at_date') is not None:
deprecations.append({
'msg': "Param '%s' is deprecated. See the module docs for more information" % sub_prefix,
'date': arg_opts.get('removed_at_date'),
'collection_name': arg_opts.get('removed_from_collection'),
})
elif arg_opts.get('removed_in_version') is not None:
deprecations.append({
'msg': "Param '%s' is deprecated. See the module docs for more information" % sub_prefix,
'version': arg_opts.get('removed_in_version'),
'collection_name': arg_opts.get('removed_from_collection'),
})
# Check sub-argument spec
sub_argument_spec = arg_opts.get('options')
if sub_argument_spec is not None:
sub_arguments = parameters[arg_name]
if isinstance(sub_arguments, Mapping):
sub_arguments = [sub_arguments]
if isinstance(sub_arguments, list):
for sub_params in sub_arguments:
if isinstance(sub_params, Mapping):
deprecations.extend(_list_deprecations(sub_argument_spec, sub_params, prefix=sub_prefix))
return deprecations
def _list_no_log_values(argument_spec, params):
"""Return set of no log values
:arg argument_spec: An argument spec dictionary
:arg params: Dictionary of all parameters
:returns: Set of strings that should be hidden from output::
{'secret_dict_value', 'secret_list_item_one', 'secret_list_item_two', 'secret_string'}
"""
no_log_values = set()
for arg_name, arg_opts in argument_spec.items():
if arg_opts.get('no_log', False):
# Find the value for the no_log'd param
no_log_object = params.get(arg_name, None)
if no_log_object:
try:
no_log_values.update(_return_datastructure_name(no_log_object))
except TypeError as e:
raise TypeError('Failed to convert "%s": %s' % (arg_name, to_native(e)))
# Get no_log values from suboptions
sub_argument_spec = arg_opts.get('options')
if sub_argument_spec is not None:
wanted_type = arg_opts.get('type')
sub_parameters = params.get(arg_name)
if sub_parameters is not None:
if wanted_type == 'dict' or (wanted_type == 'list' and arg_opts.get('elements', '') == 'dict'):
# Sub parameters can be a dict or list of dicts. Ensure parameters are always a list.
if not isinstance(sub_parameters, list):
sub_parameters = [sub_parameters]
for sub_param in sub_parameters:
# Validate dict fields in case they came in as strings
if isinstance(sub_param, string_types):
sub_param = check_type_dict(sub_param)
if not isinstance(sub_param, Mapping):
raise TypeError("Value '{1}' in the sub parameter field '{0}' must by a {2}, "
"not '{1.__class__.__name__}'".format(arg_name, sub_param, wanted_type))
no_log_values.update(_list_no_log_values(sub_argument_spec, sub_param))
return no_log_values
def _return_datastructure_name(obj): def _return_datastructure_name(obj):
@ -217,6 +453,43 @@ def _remove_values_conditions(value, no_log_strings, deferred_removals):
return value
def _set_defaults(argument_spec, parameters, set_default=True):
"""Set default values for parameters when no value is supplied.
Modifies parameters directly.
:param argument_spec: Argument spec
:type argument_spec: dict
:param parameters: Parameters to evaluate
:type parameters: dict
:param set_default: Whether or not to set the default values
:type set_default: bool
:returns: Set of strings that should not be logged.
:rtype: set
"""
no_log_values = set()
for param, value in argument_spec.items():
# TODO: Change the default value from None to Sentinel to differentiate between
# user supplied None and a default value set by this function.
default = value.get('default', None)
# This prevents setting defaults on required items on the 1st run,
# otherwise will set things without a default to None on the 2nd.
if param not in parameters and (default is not None or set_default):
# Make sure any default value for no_log fields are masked.
if value.get('no_log', False) and default:
no_log_values.add(default)
parameters[param] = default
return no_log_values
def _sanitize_keys_conditions(value, no_log_strings, ignore_keys, deferred_removals):
""" Helper method to sanitize_keys() to build deferred_removals and avoid deep recursion. """
if isinstance(value, (text_type, binary_type)): if isinstance(value, (text_type, binary_type)):
@ -255,366 +528,23 @@ def _sanitize_keys_conditions(value, no_log_strings, ignore_keys, deferred_remov
raise TypeError('Value of unknown type: %s, %s' % (type(value), value))
+def _validate_elements(wanted_type, parameter, values, options_context=None, errors=None):
+if errors is None:
+errors = AnsibleValidationErrorMultiple()
+type_checker, wanted_element_type = _get_type_validator(wanted_type)
+validated_parameters = []
+# Get param name for strings so we can later display this value in a useful error message if needed
+# Only pass 'kwargs' to our checkers and ignore custom callable checkers
+kwargs = {}
+if wanted_element_type == 'str' and isinstance(wanted_type, string_types):
+if isinstance(parameter, string_types):
+kwargs['param'] = parameter
+elif isinstance(parameter, dict):
+kwargs['param'] = list(parameter.keys())[0]
+for value in values:
-def env_fallback(*args, **kwargs):
-"""Load value from environment variable"""
-for arg in args:
-if arg in os.environ:
-return os.environ[arg]
-raise AnsibleFallbackNotFound
-def set_fallbacks(argument_spec, parameters):
no_log_values = set()
for param, value in argument_spec.items():
fallback = value.get('fallback', (None,))
fallback_strategy = fallback[0]
fallback_args = []
fallback_kwargs = {}
if param not in parameters and fallback_strategy is not None:
for item in fallback[1:]:
if isinstance(item, dict):
fallback_kwargs = item
else:
fallback_args = item
try:
fallback_value = fallback_strategy(*fallback_args, **fallback_kwargs)
except AnsibleFallbackNotFound:
continue
else:
if value.get('no_log', False) and fallback_value:
no_log_values.add(fallback_value)
parameters[param] = fallback_value
return no_log_values
def set_defaults(argument_spec, parameters, set_default=True):
"""Set default values for parameters when no value is supplied.
Modifies parameters directly.
:param argument_spec: Argument spec
:type argument_spec: dict
:param parameters: Parameters to evaluate
:type parameters: dict
:param set_default: Whether or not to set the default values
:type set_default: bool
:returns: Set of strings that should not be logged.
:rtype: set
"""
no_log_values = set()
for param, value in argument_spec.items():
# TODO: Change the default value from None to Sentinel to differentiate between
# user supplied None and a default value set by this function.
default = value.get('default', None)
# This prevents setting defaults on required items on the 1st run,
# otherwise will set things without a default to None on the 2nd.
if param not in parameters and (default is not None or set_default):
# Make sure any default value for no_log fields are masked.
if value.get('no_log', False) and default:
no_log_values.add(default)
parameters[param] = default
return no_log_values
def list_no_log_values(argument_spec, params):
"""Return set of no log values
:arg argument_spec: An argument spec dictionary from a module
:arg params: Dictionary of all parameters
:returns: Set of strings that should be hidden from output::
{'secret_dict_value', 'secret_list_item_one', 'secret_list_item_two', 'secret_string'}
"""
no_log_values = set()
for arg_name, arg_opts in argument_spec.items():
if arg_opts.get('no_log', False):
# Find the value for the no_log'd param
no_log_object = params.get(arg_name, None)
if no_log_object:
try:
no_log_values.update(_return_datastructure_name(no_log_object))
except TypeError as e:
raise TypeError('Failed to convert "%s": %s' % (arg_name, to_native(e)))
# Get no_log values from suboptions
sub_argument_spec = arg_opts.get('options')
if sub_argument_spec is not None:
wanted_type = arg_opts.get('type')
sub_parameters = params.get(arg_name)
if sub_parameters is not None:
if wanted_type == 'dict' or (wanted_type == 'list' and arg_opts.get('elements', '') == 'dict'):
# Sub parameters can be a dict or list of dicts. Ensure parameters are always a list.
if not isinstance(sub_parameters, list):
sub_parameters = [sub_parameters]
for sub_param in sub_parameters:
# Validate dict fields in case they came in as strings
if isinstance(sub_param, string_types):
sub_param = check_type_dict(sub_param)
if not isinstance(sub_param, Mapping):
raise TypeError("Value '{1}' in the sub parameter field '{0}' must by a {2}, "
"not '{1.__class__.__name__}'".format(arg_name, sub_param, wanted_type))
no_log_values.update(list_no_log_values(sub_argument_spec, sub_param))
return no_log_values
def list_deprecations(argument_spec, parameters, prefix=''):
"""Return a list of deprecations
:arg argument_spec: An argument spec dictionary from a module
:arg parameters: Dictionary of parameters
:returns: List of dictionaries containing a message and version in which
the deprecated parameter will be removed, or an empty list::
[{'msg': "Param 'deptest' is deprecated. See the module docs for more information", 'version': '2.9'}]
"""
deprecations = []
for arg_name, arg_opts in argument_spec.items():
if arg_name in parameters:
if prefix:
sub_prefix = '%s["%s"]' % (prefix, arg_name)
else:
sub_prefix = arg_name
if arg_opts.get('removed_at_date') is not None:
deprecations.append({
'msg': "Param '%s' is deprecated. See the module docs for more information" % sub_prefix,
'date': arg_opts.get('removed_at_date'),
'collection_name': arg_opts.get('removed_from_collection'),
})
elif arg_opts.get('removed_in_version') is not None:
deprecations.append({
'msg': "Param '%s' is deprecated. See the module docs for more information" % sub_prefix,
'version': arg_opts.get('removed_in_version'),
'collection_name': arg_opts.get('removed_from_collection'),
})
# Check sub-argument spec
sub_argument_spec = arg_opts.get('options')
if sub_argument_spec is not None:
sub_arguments = parameters[arg_name]
if isinstance(sub_arguments, Mapping):
sub_arguments = [sub_arguments]
if isinstance(sub_arguments, list):
for sub_params in sub_arguments:
if isinstance(sub_params, Mapping):
deprecations.extend(list_deprecations(sub_argument_spec, sub_params, prefix=sub_prefix))
return deprecations
def sanitize_keys(obj, no_log_strings, ignore_keys=frozenset()):
""" Sanitize the keys in a container object by removing no_log values from key names.
This is a companion function to the `remove_values()` function. Similar to that function,
we make use of deferred_removals to avoid hitting maximum recursion depth in cases of
large data structures.
:param obj: The container object to sanitize. Non-container objects are returned unmodified.
:param no_log_strings: A set of string values we do not want logged.
:param ignore_keys: A set of string values of keys to not sanitize.
:returns: An object with sanitized keys.
"""
deferred_removals = deque()
no_log_strings = [to_native(s, errors='surrogate_or_strict') for s in no_log_strings]
new_value = _sanitize_keys_conditions(obj, no_log_strings, ignore_keys, deferred_removals)
while deferred_removals:
old_data, new_data = deferred_removals.popleft()
if isinstance(new_data, Mapping):
for old_key, old_elem in old_data.items():
if old_key in ignore_keys or old_key.startswith('_ansible'):
new_data[old_key] = _sanitize_keys_conditions(old_elem, no_log_strings, ignore_keys, deferred_removals)
else:
# Sanitize the old key. We take advantage of the sanitizing code in
# _remove_values_conditions() rather than recreating it here.
new_key = _remove_values_conditions(old_key, no_log_strings, None)
new_data[new_key] = _sanitize_keys_conditions(old_elem, no_log_strings, ignore_keys, deferred_removals)
else:
for elem in old_data:
new_elem = _sanitize_keys_conditions(elem, no_log_strings, ignore_keys, deferred_removals)
if isinstance(new_data, MutableSequence):
new_data.append(new_elem)
elif isinstance(new_data, MutableSet):
new_data.add(new_elem)
else:
raise TypeError('Unknown container type encountered when removing private values from keys')
return new_value
def remove_values(value, no_log_strings):
""" Remove strings in no_log_strings from value. If value is a container
type, then remove a lot more.
Use of deferred_removals exists, rather than a pure recursive solution,
because of the potential to hit the maximum recursion depth when dealing with
large amounts of data (see issue #24560).
"""
deferred_removals = deque()
no_log_strings = [to_native(s, errors='surrogate_or_strict') for s in no_log_strings]
new_value = _remove_values_conditions(value, no_log_strings, deferred_removals)
while deferred_removals:
old_data, new_data = deferred_removals.popleft()
if isinstance(new_data, Mapping):
for old_key, old_elem in old_data.items():
new_elem = _remove_values_conditions(old_elem, no_log_strings, deferred_removals)
new_data[old_key] = new_elem
else:
for elem in old_data:
new_elem = _remove_values_conditions(elem, no_log_strings, deferred_removals)
if isinstance(new_data, MutableSequence):
new_data.append(new_elem)
elif isinstance(new_data, MutableSet):
new_data.add(new_elem)
else:
raise TypeError('Unknown container type encountered when removing private values from output')
return new_value
def handle_aliases(argument_spec, parameters, alias_warnings=None, alias_deprecations=None):
"""Return a two item tuple. The first is a dictionary of aliases, the second is
a list of legal inputs.
Modify supplied parameters by adding a new key for each alias.
If a list is provided to the alias_warnings parameter, it will be filled with tuples
(option, alias) in every case where both an option and its alias are specified.
If a list is provided to alias_deprecations, it will be populated with dictionaries,
each containing deprecation information for each alias found in argument_spec.
"""
legal_inputs = ['_ansible_%s' % k for k in PASS_VARS]
aliases_results = {} # alias:canon
for (k, v) in argument_spec.items():
legal_inputs.append(k)
aliases = v.get('aliases', None)
default = v.get('default', None)
required = v.get('required', False)
if alias_deprecations is not None:
for alias in argument_spec[k].get('deprecated_aliases', []):
if alias.get('name') in parameters:
alias_deprecations.append(alias)
if default is not None and required:
# not alias specific but this is a good place to check this
raise ValueError("internal error: required and default are mutually exclusive for %s" % k)
if aliases is None:
continue
if not is_iterable(aliases) or isinstance(aliases, (binary_type, text_type)):
raise TypeError('internal error: aliases must be a list or tuple')
for alias in aliases:
legal_inputs.append(alias)
aliases_results[alias] = k
if alias in parameters:
if k in parameters and alias_warnings is not None:
alias_warnings.append((k, alias))
parameters[k] = parameters[alias]
return aliases_results, legal_inputs
def get_unsupported_parameters(argument_spec, parameters, legal_inputs=None):
"""Check keys in parameters against those provided in legal_inputs
to ensure they contain legal values. If legal_inputs are not supplied,
they will be generated using the argument_spec.
:arg argument_spec: Dictionary of parameters, their type, and valid values.
:arg parameters: Dictionary of parameters.
    :arg legal_inputs: List of valid key names. Overrides values
in argument_spec.
:returns: Set of unsupported parameters. Empty set if no unsupported parameters
are found.
"""
if legal_inputs is None:
aliases, legal_inputs = handle_aliases(argument_spec, parameters)
unsupported_parameters = set()
for k in parameters.keys():
if k not in legal_inputs:
unsupported_parameters.add(k)
return unsupported_parameters
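
As a small illustration (names invented), anything in the parameters dict that is neither a spec key, an alias, nor an '_ansible_*' pass-through ends up in the returned set:

argument_spec = {'name': {'type': 'str'}}
parameters = {'name': 'foo', 'colour': 'blue'}

unsupported = get_unsupported_parameters(argument_spec, parameters)
# unsupported == set(['colour'])
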
def get_type_validator(wanted):
"""Returns the callable used to validate a wanted type and the type name.
:arg wanted: String or callable. If a string, get the corresponding
validation function from DEFAULT_TYPE_VALIDATORS. If callable,
get the name of the custom callable and return that for the type_checker.
:returns: Tuple of callable function or None, and a string that is the name
of the wanted type.
"""
    # Use one of our builtin validators.
if not callable(wanted):
if wanted is None:
# Default type for parameters
wanted = 'str'
type_checker = DEFAULT_TYPE_VALIDATORS.get(wanted)
# Use the custom callable for validation.
else:
type_checker = wanted
wanted = getattr(wanted, '__name__', to_native(type(wanted)))
return type_checker, wanted
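
A short sketch of the two paths through get_type_validator(); the custom callable below is a made-up example:

checker, name = get_type_validator('int')   # one of the DEFAULT_TYPE_VALIDATORS, name == 'int'
checker, name = get_type_validator(None)    # falls back to the default 'str' checker

def comma_list(value):
    return str(value).split(',')

checker, name = get_type_validator(comma_list)  # checker is comma_list itself, name == 'comma_list'
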
def validate_elements(wanted_type, parameter, values, options_context=None, errors=None):
if errors is None:
errors = []
type_checker, wanted_element_type = get_type_validator(wanted_type)
validated_parameters = []
# Get param name for strings so we can later display this value in a useful error message if needed
# Only pass 'kwargs' to our checkers and ignore custom callable checkers
kwargs = {}
if wanted_element_type == 'str' and isinstance(wanted_type, string_types):
if isinstance(parameter, string_types):
kwargs['param'] = parameter
elif isinstance(parameter, dict):
kwargs['param'] = list(parameter.keys())[0]
for value in values:
        try:
            validated_parameters.append(type_checker(value, **kwargs))
        except (TypeError, ValueError) as e:
@@ -622,11 +552,11 @@ def validate_elements(wanted_type, parameter, values, options_context=None, errors=None):
            if options_context:
                msg += " found in '%s'" % " -> ".join(options_context)
            msg += " is of type %s and we were unable to convert to %s: %s" % (type(value), wanted_element_type, to_native(e))
-            errors.append(msg)
+            errors.append(ElementError(msg))

    return validated_parameters
-def validate_argument_types(argument_spec, parameters, prefix='', options_context=None, errors=None):
+def _validate_argument_types(argument_spec, parameters, prefix='', options_context=None, errors=None):
    """Validate that parameter types match the type in the argument spec.

    Determine the appropriate type checker function and run each
@@ -637,7 +567,7 @@ def validate_argument_types(argument_spec, parameters, prefix='', options_context=None, errors=None):
    :param argument_spec: Argument spec
    :type argument_spec: dict

-    :param parameters: Parameters passed to module
+    :param parameters: Parameters
    :type parameters: dict

    :param prefix: Name of the parent key that contains the spec. Used in the error message
@@ -653,7 +583,7 @@ def validate_argument_types(argument_spec, parameters, prefix='', options_context=None, errors=None):
    """

    if errors is None:
-        errors = []
+        errors = AnsibleValidationErrorMultiple()

    for param, spec in argument_spec.items():
        if param not in parameters:
@@ -664,7 +594,7 @@ def validate_argument_types(argument_spec, parameters, prefix='', options_context=None, errors=None):
            continue

        wanted_type = spec.get('type')
-        type_checker, wanted_name = get_type_validator(wanted_type)
+        type_checker, wanted_name = _get_type_validator(wanted_type)

        # Get param name for strings so we can later display this value in a useful error message if needed
        # Only pass 'kwargs' to our checkers and ignore custom callable checkers
        kwargs = {}
@@ -685,22 +615,22 @@ def validate_argument_types(argument_spec, parameters, prefix='', options_context=None, errors=None):
                    if options_context:
                        msg += " found in '%s'." % " -> ".join(options_context)
                    msg += ", elements value check is supported only with 'list' type"
-                    errors.append(msg)
+                    errors.append(ArgumentTypeError(msg))
-                parameters[param] = validate_elements(elements_wanted_type, param, elements, options_context, errors)
+                parameters[param] = _validate_elements(elements_wanted_type, param, elements, options_context, errors)

        except (TypeError, ValueError) as e:
            msg = "argument '%s' is of type %s" % (param, type(value))
            if options_context:
                msg += " found in '%s'." % " -> ".join(options_context)
            msg += " and we were unable to convert to %s: %s" % (wanted_name, to_native(e))
-            errors.append(msg)
+            errors.append(ArgumentTypeError(msg))


-def validate_argument_values(argument_spec, parameters, options_context=None, errors=None):
+def _validate_argument_values(argument_spec, parameters, options_context=None, errors=None):
    """Ensure all arguments have the requested values, and there are no stray arguments"""

    if errors is None:
-        errors = []
+        errors = AnsibleValidationErrorMultiple()

    for param, spec in argument_spec.items():
        choices = spec.get('choices')
@@ -716,8 +646,8 @@ def validate_argument_values(argument_spec, parameters, options_context=None, errors=None):
                    choices_str = ", ".join([to_native(c) for c in choices])
                    msg = "value of %s must be one or more of: %s. Got no match for: %s" % (param, choices_str, diff_list)
                    if options_context:
-                        msg += " found in %s" % " -> ".join(options_context)
+                        msg = "{0} found in {1}".format(msg, " -> ".join(options_context))
-                    errors.append(msg)
+                    errors.append(ArgumentValueError(msg))
            elif parameters[param] not in choices:
                # PyYaml converts certain strings to bools. If we can unambiguously convert back, do so before checking
                # the value. If we can't figure this out, module author is responsible.
@@ -740,23 +670,23 @@ def validate_argument_values(argument_spec, parameters, options_context=None, errors=None):
                    choices_str = ", ".join([to_native(c) for c in choices])
                    msg = "value of %s must be one of: %s, got: %s" % (param, choices_str, parameters[param])
                    if options_context:
-                        msg += " found in %s" % " -> ".join(options_context)
+                        msg = "{0} found in {1}".format(msg, " -> ".join(options_context))
-                    errors.append(msg)
+                    errors.append(ArgumentValueError(msg))
        else:
            msg = "internal error: choices for argument %s are not iterable: %s" % (param, choices)
            if options_context:
-                msg += " found in %s" % " -> ".join(options_context)
+                msg = "{0} found in {1}".format(msg, " -> ".join(options_context))
-            errors.append(msg)
+            errors.append(ArgumentTypeError(msg))


-def validate_sub_spec(argument_spec, parameters, prefix='', options_context=None, errors=None, no_log_values=None, unsupported_parameters=None):
+def _validate_sub_spec(argument_spec, parameters, prefix='', options_context=None, errors=None, no_log_values=None, unsupported_parameters=None):
    """Validate sub argument spec. This function is recursive."""

    if options_context is None:
        options_context = []

    if errors is None:
-        errors = []
+        errors = AnsibleValidationErrorMultiple()

    if no_log_values is None:
        no_log_values = set()
@@ -766,11 +696,11 @@ def validate_sub_spec(argument_spec, parameters, prefix='', options_context=None, errors=None, no_log_values=None, unsupported_parameters=None):
    for param, value in argument_spec.items():
        wanted = value.get('type')
-        if wanted == 'dict' or (wanted == 'list' and value.get('elements', '') == dict):
+        if wanted == 'dict' or (wanted == 'list' and value.get('elements', '') == 'dict'):
            sub_spec = value.get('options')
            if value.get('apply_defaults', False):
                if sub_spec is not None:
-                    if parameters.get(value) is None:
+                    if parameters.get(param) is None:
                        parameters[param] = {}
                else:
                    continue
@@ -788,7 +718,7 @@ def validate_sub_spec(argument_spec, parameters, prefix='', options_context=None, errors=None, no_log_values=None, unsupported_parameters=None):
            for idx, sub_parameters in enumerate(elements):
                if not isinstance(sub_parameters, dict):
-                    errors.append("value of '%s' must be of type dict or list of dicts" % param)
+                    errors.append(SubParameterTypeError("value of '%s' must be of type dict or list of dicts" % param))

                # Set prefix for warning messages
                new_prefix = prefix + param
@@ -799,53 +729,159 @@ def validate_sub_spec(argument_spec, parameters, prefix='', options_context=None, errors=None, no_log_values=None, unsupported_parameters=None):
                no_log_values.update(set_fallbacks(sub_spec, sub_parameters))

                alias_warnings = []
+                alias_deprecations = []
                try:
-                    options_aliases, legal_inputs = handle_aliases(sub_spec, sub_parameters, alias_warnings)
+                    options_aliases = _handle_aliases(sub_spec, sub_parameters, alias_warnings, alias_deprecations)
                except (TypeError, ValueError) as e:
                    options_aliases = {}
-                    legal_inputs = None
-                    errors.append(to_native(e))
+                    errors.append(AliasError(to_native(e)))

                for option, alias in alias_warnings:
                    warn('Both option %s and its alias %s are set.' % (option, alias))

-                no_log_values.update(list_no_log_values(sub_spec, sub_parameters))
+                try:
+                    no_log_values.update(_list_no_log_values(sub_spec, sub_parameters))
+                except TypeError as te:
+                    errors.append(NoLogError(to_native(te)))

-                if legal_inputs is None:
-                    legal_inputs = list(options_aliases.keys()) + list(sub_spec.keys())
-                unsupported_parameters.update(get_unsupported_parameters(sub_spec, sub_parameters, legal_inputs))
+                legal_inputs = _get_legal_inputs(sub_spec, sub_parameters, options_aliases)
+                unsupported_parameters.update(_get_unsupported_parameters(sub_spec, sub_parameters, legal_inputs, options_context))

                try:
-                    check_mutually_exclusive(value.get('mutually_exclusive'), sub_parameters)
+                    check_mutually_exclusive(value.get('mutually_exclusive'), sub_parameters, options_context)
                except TypeError as e:
-                    errors.append(to_native(e))
+                    errors.append(MutuallyExclusiveError(to_native(e)))

-                no_log_values.update(set_defaults(sub_spec, sub_parameters, False))
+                no_log_values.update(_set_defaults(sub_spec, sub_parameters, False))

                try:
-                    check_required_arguments(sub_spec, sub_parameters)
+                    check_required_arguments(sub_spec, sub_parameters, options_context)
                except TypeError as e:
-                    errors.append(to_native(e))
+                    errors.append(RequiredError(to_native(e)))

-                validate_argument_types(sub_spec, sub_parameters, new_prefix, options_context, errors=errors)
-                validate_argument_values(sub_spec, sub_parameters, options_context, errors=errors)
+                _validate_argument_types(sub_spec, sub_parameters, new_prefix, options_context, errors=errors)
+                _validate_argument_values(sub_spec, sub_parameters, options_context, errors=errors)

-                checks = [
-                    (check_required_together, 'required_together'),
-                    (check_required_one_of, 'required_one_of'),
-                    (check_required_if, 'required_if'),
-                    (check_required_by, 'required_by'),
-                ]
-                for check in checks:
+                for check in _ADDITIONAL_CHECKS:
                    try:
-                        check[0](value.get(check[1]), parameters)
+                        check['func'](value.get(check['attr']), sub_parameters, options_context)
                    except TypeError as e:
-                        errors.append(to_native(e))
+                        errors.append(check['err'](to_native(e)))

-                no_log_values.update(set_defaults(sub_spec, sub_parameters))
+                no_log_values.update(_set_defaults(sub_spec, sub_parameters))

                # Handle nested specs
-                validate_sub_spec(sub_spec, sub_parameters, new_prefix, options_context, errors, no_log_values, unsupported_parameters)
+                _validate_sub_spec(sub_spec, sub_parameters, new_prefix, options_context, errors, no_log_values, unsupported_parameters)

            options_context.pop()
def env_fallback(*args, **kwargs):
"""Load value from environment variable"""
for arg in args:
if arg in os.environ:
return os.environ[arg]
raise AnsibleFallbackNotFound
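
Called directly (it is normally referenced from a 'fallback' entry in an argument spec), env_fallback() returns the first environment variable that is set and raises AnsibleFallbackNotFound otherwise. A quick sketch with hypothetical variable names:

import os

os.environ['EXAMPLE_API_TOKEN'] = 'secret'
try:
    token = env_fallback('MISSING_VAR', 'EXAMPLE_API_TOKEN')  # -> 'secret'
except AnsibleFallbackNotFound:
    token = None
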
def set_fallbacks(argument_spec, parameters):
no_log_values = set()
for param, value in argument_spec.items():
fallback = value.get('fallback', (None,))
fallback_strategy = fallback[0]
fallback_args = []
fallback_kwargs = {}
if param not in parameters and fallback_strategy is not None:
for item in fallback[1:]:
if isinstance(item, dict):
fallback_kwargs = item
else:
fallback_args = item
try:
fallback_value = fallback_strategy(*fallback_args, **fallback_kwargs)
except AnsibleFallbackNotFound:
continue
else:
if value.get('no_log', False) and fallback_value:
no_log_values.add(fallback_value)
parameters[param] = fallback_value
return no_log_values
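
A minimal sketch (option and variable names invented) of the fallback tuple format set_fallbacks() expects: the first item is the strategy callable, the remaining items are positional arguments (a list) or keyword arguments (a dict). The fallback only fires when the parameter is absent:

argument_spec = {
    'token': {
        'type': 'str',
        'no_log': True,
        'fallback': (env_fallback, ['EXAMPLE_API_TOKEN']),
    },
}
parameters = {}

no_log_values = set_fallbacks(argument_spec, parameters)
# If EXAMPLE_API_TOKEN is set, parameters['token'] now holds its value and,
# because the option is no_log, that value is also returned in no_log_values.
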
def sanitize_keys(obj, no_log_strings, ignore_keys=frozenset()):
""" Sanitize the keys in a container object by removing no_log values from key names.
This is a companion function to the `remove_values()` function. Similar to that function,
we make use of deferred_removals to avoid hitting maximum recursion depth in cases of
large data structures.
:param obj: The container object to sanitize. Non-container objects are returned unmodified.
:param no_log_strings: A set of string values we do not want logged.
:param ignore_keys: A set of string values of keys to not sanitize.
:returns: An object with sanitized keys.
"""
deferred_removals = deque()
no_log_strings = [to_native(s, errors='surrogate_or_strict') for s in no_log_strings]
new_value = _sanitize_keys_conditions(obj, no_log_strings, ignore_keys, deferred_removals)
while deferred_removals:
old_data, new_data = deferred_removals.popleft()
if isinstance(new_data, Mapping):
for old_key, old_elem in old_data.items():
if old_key in ignore_keys or old_key.startswith('_ansible'):
new_data[old_key] = _sanitize_keys_conditions(old_elem, no_log_strings, ignore_keys, deferred_removals)
else:
# Sanitize the old key. We take advantage of the sanitizing code in
# _remove_values_conditions() rather than recreating it here.
new_key = _remove_values_conditions(old_key, no_log_strings, None)
new_data[new_key] = _sanitize_keys_conditions(old_elem, no_log_strings, ignore_keys, deferred_removals)
else:
for elem in old_data:
new_elem = _sanitize_keys_conditions(elem, no_log_strings, ignore_keys, deferred_removals)
if isinstance(new_data, MutableSequence):
new_data.append(new_elem)
elif isinstance(new_data, MutableSet):
new_data.add(new_elem)
else:
raise TypeError('Unknown container type encountered when removing private values from keys')
return new_value
def remove_values(value, no_log_strings):
""" Remove strings in no_log_strings from value. If value is a container
type, then remove a lot more.
Use of deferred_removals exists, rather than a pure recursive solution,
because of the potential to hit the maximum recursion depth when dealing with
large amounts of data (see issue #24560).
"""
deferred_removals = deque()
no_log_strings = [to_native(s, errors='surrogate_or_strict') for s in no_log_strings]
new_value = _remove_values_conditions(value, no_log_strings, deferred_removals)
while deferred_removals:
old_data, new_data = deferred_removals.popleft()
if isinstance(new_data, Mapping):
for old_key, old_elem in old_data.items():
new_elem = _remove_values_conditions(old_elem, no_log_strings, deferred_removals)
new_data[old_key] = new_elem
else:
for elem in old_data:
new_elem = _remove_values_conditions(elem, no_log_strings, deferred_removals)
if isinstance(new_data, MutableSequence):
new_data.append(new_elem)
elif isinstance(new_data, MutableSet):
new_data.add(new_elem)
else:
raise TypeError('Unknown container type encountered when removing private values from output')
return new_value
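
Roughly, the two functions split the work between values and keys; a small sketch with made-up data:

no_log_strings = set(['hunter2'])

remove_values({'password': 'hunter2', 'port': 22}, no_log_strings)
# -> {'password': 'VALUE_SPECIFIED_IN_NO_LOG_PARAMETER', 'port': 22}

sanitize_keys({'hunter2': 'some value'}, no_log_strings)
# -> {'VALUE_SPECIFIED_IN_NO_LOG_PARAMETER': 'some value'}
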

@@ -39,7 +39,35 @@ def count_terms(terms, parameters):
    return len(set(terms).intersection(parameters))


-def check_mutually_exclusive(terms, parameters):
+def safe_eval(value, locals=None, include_exceptions=False):
+    # do not allow method calls to modules
+    if not isinstance(value, string_types):
+        # already templated to a datavaluestructure, perhaps?
+        if include_exceptions:
+            return (value, None)
+        return value
+
+    if re.search(r'\w\.\w+\(', value):
+        if include_exceptions:
+            return (value, None)
+        return value
+
+    # do not allow imports
+    if re.search(r'import \w+', value):
+        if include_exceptions:
+            return (value, None)
+        return value
+
+    try:
+        result = literal_eval(value)
+        if include_exceptions:
+            return (result, None)
+        else:
+            return result
+    except Exception as e:
+        if include_exceptions:
+            return (value, e)
+        return value
+
+
+def check_mutually_exclusive(terms, parameters, options_context=None):
    """Check mutually exclusive terms against argument parameters

    Accepts a single list or list of lists that are groups of terms that should be
@@ -63,12 +91,14 @@ def check_mutually_exclusive(terms, parameters):
    if results:
        full_list = ['|'.join(check) for check in results]
        msg = "parameters are mutually exclusive: %s" % ', '.join(full_list)
+        if options_context:
+            msg = "{0} found in {1}".format(msg, " -> ".join(options_context))
        raise TypeError(to_native(msg))

    return results


-def check_required_one_of(terms, parameters):
+def check_required_one_of(terms, parameters, options_context=None):
    """Check each list of terms to ensure at least one exists in the given module
    parameters

@@ -93,12 +123,14 @@ def check_required_one_of(terms, parameters):
    if results:
        for term in results:
            msg = "one of the following is required: %s" % ', '.join(term)
+            if options_context:
+                msg = "{0} found in {1}".format(msg, " -> ".join(options_context))
            raise TypeError(to_native(msg))

    return results


-def check_required_together(terms, parameters):
+def check_required_together(terms, parameters, options_context=None):
    """Check each list of terms to ensure every parameter in each list exists
    in the given parameters

@@ -125,12 +157,14 @@ def check_required_together(terms, parameters):
    if results:
        for term in results:
            msg = "parameters are required together: %s" % ', '.join(term)
+            if options_context:
+                msg = "{0} found in {1}".format(msg, " -> ".join(options_context))
            raise TypeError(to_native(msg))

    return results


-def check_required_by(requirements, parameters):
+def check_required_by(requirements, parameters, options_context=None):
    """For each key in requirements, check the corresponding list to see if they
    exist in parameters

@@ -161,12 +195,14 @@ def check_required_by(requirements, parameters):
    for key, missing in result.items():
        if len(missing) > 0:
            msg = "missing parameter(s) required by '%s': %s" % (key, ', '.join(missing))
+            if options_context:
+                msg = "{0} found in {1}".format(msg, " -> ".join(options_context))
            raise TypeError(to_native(msg))

    return result


-def check_required_arguments(argument_spec, parameters):
+def check_required_arguments(argument_spec, parameters, options_context=None):
    """Check all paramaters in argument_spec and return a list of parameters
    that are required but not present in parameters

@@ -190,12 +226,14 @@ def check_required_arguments(argument_spec, parameters):
    if missing:
        msg = "missing required arguments: %s" % ", ".join(sorted(missing))
+        if options_context:
+            msg = "{0} found in {1}".format(msg, " -> ".join(options_context))
        raise TypeError(to_native(msg))

    return missing


-def check_required_if(requirements, parameters):
+def check_required_if(requirements, parameters, options_context=None):
    """Check parameters that are conditionally required

    Raises TypeError if the check fails

@@ -272,6 +310,8 @@ def check_required_if(requirements, parameters):
    for missing in results:
        msg = "%s is %s but %s of the following are missing: %s" % (
            missing['parameter'], missing['value'], missing['requires'], ', '.join(missing['missing']))
+        if options_context:
+            msg = "{0} found in {1}".format(msg, " -> ".join(options_context))
        raise TypeError(to_native(msg))

    return results

@@ -304,34 +344,6 @@ def check_missing_parameters(parameters, required_parameters=None):
    return missing_params


-def safe_eval(value, locals=None, include_exceptions=False):
-    # do not allow method calls to modules
-    if not isinstance(value, string_types):
-        # already templated to a datavaluestructure, perhaps?
-        if include_exceptions:
-            return (value, None)
-        return value
-
-    if re.search(r'\w\.\w+\(', value):
-        if include_exceptions:
-            return (value, None)
-        return value
-
-    # do not allow imports
-    if re.search(r'import \w+', value):
-        if include_exceptions:
-            return (value, None)
-        return value
-
-    try:
-        result = literal_eval(value)
-        if include_exceptions:
-            return (result, None)
-        else:
-            return result
-    except Exception as e:
-        if include_exceptions:
-            return (value, e)
-        return value
-
-
# FIXME: The param and prefix parameters here are coming from AnsibleModule._check_type_string()
# which is using those for the warning messaged based on string conversion warning settings.
# Not sure how to deal with that here since we don't have config state to query.
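
To show how the new options_context parameter surfaces in the raised message (a sketch, option names invented):

try:
    check_mutually_exclusive([['classic', 'unified']], {'classic': True, 'unified': True}, options_context=['display'])
except TypeError as e:
    print(e)  # parameters are mutually exclusive: classic|unified found in display
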

@@ -0,0 +1,108 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2021 Ansible Project
# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
class AnsibleFallbackNotFound(Exception):
"""Fallback validator was not found"""
class AnsibleValidationError(Exception):
"""Single argument spec validation error"""
def __init__(self, message):
super(AnsibleValidationError, self).__init__(message)
self.error_message = message
@property
def msg(self):
return self.args[0]
class AnsibleValidationErrorMultiple(AnsibleValidationError):
"""Multiple argument spec validation errors"""
def __init__(self, errors=None):
self.errors = errors[:] if errors else []
def __getitem__(self, key):
return self.errors[key]
def __setitem__(self, key, value):
self.errors[key] = value
def __delitem__(self, key):
del self.errors[key]
@property
def msg(self):
return self.errors[0].args[0]
@property
def messages(self):
return [err.msg for err in self.errors]
def append(self, error):
self.errors.append(error)
def extend(self, errors):
self.errors.extend(errors)
class AliasError(AnsibleValidationError):
"""Error handling aliases"""
class ArgumentTypeError(AnsibleValidationError):
"""Error with parameter type"""
class ArgumentValueError(AnsibleValidationError):
"""Error with parameter value"""
class ElementError(AnsibleValidationError):
"""Error when validating elements"""
class MutuallyExclusiveError(AnsibleValidationError):
"""Mutually exclusive parameters were supplied"""
class NoLogError(AnsibleValidationError):
"""Error converting no_log values"""
class RequiredByError(AnsibleValidationError):
"""Error with parameters that are required by other parameters"""
class RequiredDefaultError(AnsibleValidationError):
"""A required parameter was assigned a default value"""
class RequiredError(AnsibleValidationError):
"""Missing a required parameter"""
class RequiredIfError(AnsibleValidationError):
"""Error with conditionally required parameters"""
class RequiredOneOfError(AnsibleValidationError):
"""Error with parameters where at least one is required"""
class RequiredTogetherError(AnsibleValidationError):
"""Error with parameters that are required together"""
class SubParameterTypeError(AnsibleValidationError):
"""Incorrect type for subparameter"""
class UnsupportedError(AnsibleValidationError):
"""Unsupported parameters were supplied"""

@@ -8,6 +8,7 @@ from ansible.errors import AnsibleError
from ansible.plugins.action import ActionBase
from ansible.module_utils.six import iteritems, string_types
from ansible.module_utils.common.arg_spec import ArgumentSpecValidator
+from ansible.module_utils.errors import AnsibleValidationErrorMultiple


class ActionModule(ActionBase):
@@ -82,13 +83,14 @@ class ActionModule(ActionBase):
        args_from_vars = self.get_args_from_task_vars(argument_spec_data, task_vars)
        provided_arguments.update(args_from_vars)

-        validator = ArgumentSpecValidator(argument_spec_data, provided_arguments)
+        validator = ArgumentSpecValidator(argument_spec_data)
+        validation_result = validator.validate(provided_arguments)

-        if not validator.validate():
+        if validation_result.error_messages:
            result['failed'] = True
-            result['msg'] = 'Validation of arguments failed:\n%s' % '\n'.join(validator.error_messages)
+            result['msg'] = 'Validation of arguments failed:\n%s' % '\n'.join(validation_result.error_messages)
            result['argument_spec_data'] = argument_spec_data
-            result['argument_errors'] = validator.error_messages
+            result['argument_errors'] = validation_result.error_messages
            return result

        result['changed'] = False
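
The calling pattern the plugin switches to, sketched with a made-up spec (the spec is bound at construction time and each validate() call returns a ValidationResult):

validator = ArgumentSpecValidator({'name': {'type': 'str', 'required': True}})
result = validator.validate({'name': 'example'})

if result.error_messages:
    raise AnsibleError('Validation of arguments failed:\n%s' % '\n'.join(result.error_messages))

validated = result.validated_parameters
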

@@ -29,7 +29,6 @@ from io import BytesIO

import ansible.errors

from ansible.executor.module_common import recursive_finder
-from ansible.module_utils.six import PY2


# These are the modules that are brought in by module_utils/basic.py This may need to be updated
@@ -58,12 +57,14 @@ MODULE_UTILS_BASIC_FILES = frozenset(('ansible/__init__.py',
                                      'ansible/module_utils/common/text/formatters.py',
                                      'ansible/module_utils/common/validation.py',
                                      'ansible/module_utils/common/_utils.py',
+                                     'ansible/module_utils/common/arg_spec.py',
                                      'ansible/module_utils/compat/__init__.py',
                                      'ansible/module_utils/compat/_selectors2.py',
                                      'ansible/module_utils/compat/selectors.py',
                                      'ansible/module_utils/compat/selinux.py',
                                      'ansible/module_utils/distro/__init__.py',
                                      'ansible/module_utils/distro/_distro.py',
+                                     'ansible/module_utils/errors.py',
                                      'ansible/module_utils/parsing/__init__.py',
                                      'ansible/module_utils/parsing/convert_bool.py',
                                      'ansible/module_utils/pycompat24.py',

@@ -84,9 +84,9 @@ INVALID_SPECS = (
    ({'arg': {'type': 'list', 'elements': MOCK_VALIDATOR_FAIL}}, {'arg': [1, "bad"]}, "bad conversion"),
    # unknown parameter
    ({'arg': {'type': 'int'}}, {'other': 'bad', '_ansible_module_name': 'ansible_unittest'},
-     'Unsupported parameters for (ansible_unittest) module: other Supported parameters include: arg'),
+     'Unsupported parameters for (ansible_unittest) module: other. Supported parameters include: arg.'),
    ({'arg': {'type': 'int', 'aliases': ['argument']}}, {'other': 'bad', '_ansible_module_name': 'ansible_unittest'},
-     'Unsupported parameters for (ansible_unittest) module: other Supported parameters include: arg (argument)'),
+     'Unsupported parameters for (ansible_unittest) module: other. Supported parameters include: arg (argument).'),
    # parameter is required
    ({'arg': {'required': True}}, {}, 'missing required arguments: arg'),
)

@@ -496,7 +496,7 @@ class TestComplexOptions:
        # Missing required option
        ({'foobar': [{}]}, 'missing required arguments: foo found in foobar'),
        # Invalid option
-        ({'foobar': [{"foo": "hello", "bam": "good", "invalid": "bad"}]}, 'module: invalid found in foobar. Supported parameters include'),
+        ({'foobar': [{"foo": "hello", "bam": "good", "invalid": "bad"}]}, 'module: foobar.invalid. Supported parameters include'),
        # Mutually exclusive options found
        ({'foobar': [{"foo": "test", "bam": "bad", "bam1": "bad", "baz": "req_to"}]},
         'parameters are mutually exclusive: bam|bam1 found in foobar'),

@@ -520,7 +520,7 @@ class TestComplexOptions:
        ({'foobar': {}}, 'missing required arguments: foo found in foobar'),
        # Invalid option
        ({'foobar': {"foo": "hello", "bam": "good", "invalid": "bad"}},
-         'module: invalid found in foobar. Supported parameters include'),
+         'module: foobar.invalid. Supported parameters include'),
        # Mutually exclusive options found
        ({'foobar': {"foo": "test", "bam": "bad", "bam1": "bad", "baz": "req_to"}},
         'parameters are mutually exclusive: bam|bam1 found in foobar'),

@@ -1,28 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2021 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import pytest
from ansible.module_utils.common.arg_spec import ArgumentSpecValidator
def test_add_sequence():
v = ArgumentSpecValidator({}, {})
errors = [
'one error',
'another error',
]
v._add_error(errors)
assert v.error_messages == errors
def test_invalid_error_message():
v = ArgumentSpecValidator({}, {})
with pytest.raises(ValueError, match="Error messages must be a string or sequence not a"):
v._add_error(None)

@@ -7,10 +7,11 @@ __metaclass__ = type

import pytest

-from ansible.module_utils.common.arg_spec import ArgumentSpecValidator
+from ansible.module_utils.errors import AnsibleValidationError, AnsibleValidationErrorMultiple
+from ansible.module_utils.common.arg_spec import ArgumentSpecValidator, ValidationResult
from ansible.module_utils.common.warnings import get_deprecation_messages, get_warning_messages

-# id, argument spec, parameters, expected parameters, expected pass/fail, error, deprecation, warning
+# id, argument spec, parameters, expected parameters, deprecation, warning
ALIAS_TEST_CASES = [
    (
        "alias",
@@ -20,29 +21,6 @@ ALIAS_TEST_CASES = [
            'dir': '/tmp',
            'path': '/tmp',
        },
-        True,
-        "",
-        "",
-        "",
-    ),
-    (
-        "alias-invalid",
-        {'path': {'aliases': 'bad'}},
-        {},
-        {'path': None},
-        False,
-        "internal error: aliases must be a list or tuple",
-        "",
-        "",
-    ),
-    (
-        # This isn't related to aliases, but it exists in the alias handling code
-        "default-and-required",
-        {'name': {'default': 'ray', 'required': True}},
-        {},
-        {'name': 'ray'},
-        False,
-        "internal error: required and default are mutually exclusive for name",
        "",
        "",
    ),
@@ -58,10 +36,8 @@ ALIAS_TEST_CASES = [
            'directory': '/tmp',
            'path': '/tmp',
        },
-        True,
-        "",
        "",
-        "Both option path and its alias directory are set",
+        {'alias': 'directory', 'option': 'path'},
    ),
    (
        "deprecated-alias",
@@ -81,39 +57,66 @@ ALIAS_TEST_CASES = [
            'path': '/tmp',
            'not_yo_path': '/tmp',
        },
-        True,
-        "",
-        "Alias 'not_yo_path' is deprecated.",
+        {'version': '1.7', 'date': None, 'collection_name': None, 'name': 'not_yo_path'},
        "",
    )
]

+# id, argument spec, parameters, expected parameters, error
+ALIAS_TEST_CASES_INVALID = [
+    (
+        "alias-invalid",
+        {'path': {'aliases': 'bad'}},
+        {},
+        {'path': None},
+        "internal error: aliases must be a list or tuple",
+    ),
+    (
+        # This isn't related to aliases, but it exists in the alias handling code
+        "default-and-required",
+        {'name': {'default': 'ray', 'required': True}},
+        {},
+        {'name': 'ray'},
+        "internal error: required and default are mutually exclusive for name",
+    ),
+]


@pytest.mark.parametrize(
-    ('arg_spec', 'parameters', 'expected', 'passfail', 'error', 'deprecation', 'warning'),
-    ((i[1], i[2], i[3], i[4], i[5], i[6], i[7]) for i in ALIAS_TEST_CASES),
+    ('arg_spec', 'parameters', 'expected', 'deprecation', 'warning'),
+    ((i[1:]) for i in ALIAS_TEST_CASES),
    ids=[i[0] for i in ALIAS_TEST_CASES]
)
-def test_aliases(arg_spec, parameters, expected, passfail, error, deprecation, warning):
-    v = ArgumentSpecValidator(arg_spec, parameters)
-    passed = v.validate()
+def test_aliases(arg_spec, parameters, expected, deprecation, warning):
+    v = ArgumentSpecValidator(arg_spec)
+    result = v.validate(parameters)

-    assert passed is passfail
-    assert v.validated_parameters == expected
+    assert isinstance(result, ValidationResult)
+    assert result.validated_parameters == expected
+    assert result.error_messages == []

-    if not error:
-        assert v.error_messages == []
+    if deprecation:
+        assert deprecation == result._deprecations[0]
    else:
-        assert error in v.error_messages[0]
+        assert result._deprecations == []

-    deprecations = get_deprecation_messages()
-    if not deprecations:
-        assert deprecations == ()
+    if warning:
+        assert warning == result._warnings[0]
    else:
-        assert deprecation in get_deprecation_messages()[0]['msg']
+        assert result._warnings == []

-    warnings = get_warning_messages()
-    if not warning:
-        assert warnings == ()
-    else:
-        assert warning in warnings[0]
+
+@pytest.mark.parametrize(
+    ('arg_spec', 'parameters', 'expected', 'error'),
+    ((i[1:]) for i in ALIAS_TEST_CASES_INVALID),
+    ids=[i[0] for i in ALIAS_TEST_CASES_INVALID]
+)
+def test_aliases_invalid(arg_spec, parameters, expected, error):
+    v = ArgumentSpecValidator(arg_spec)
+    result = v.validate(parameters)
+
+    assert isinstance(result, ValidationResult)
+    assert error in result.error_messages
+    assert isinstance(result.errors.errors[0], AnsibleValidationError)
+    assert isinstance(result.errors, AnsibleValidationErrorMultiple)

@@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2021 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import ansible.module_utils.common.warnings as warnings
from ansible.module_utils.common.arg_spec import ModuleArgumentSpecValidator, ValidationResult
def test_module_validate():
arg_spec = {'name': {}}
parameters = {'name': 'larry'}
expected = {'name': 'larry'}
v = ModuleArgumentSpecValidator(arg_spec)
result = v.validate(parameters)
assert isinstance(result, ValidationResult)
assert result.error_messages == []
assert result._deprecations == []
assert result._warnings == []
assert result.validated_parameters == expected
def test_module_alias_deprecations_warnings():
arg_spec = {
'path': {
'aliases': ['source', 'src', 'flamethrower'],
'deprecated_aliases': [{'name': 'flamethrower', 'date': '2020-03-04'}],
},
}
parameters = {'flamethrower': '/tmp', 'source': '/tmp'}
expected = {
'path': '/tmp',
'flamethrower': '/tmp',
'source': '/tmp',
}
v = ModuleArgumentSpecValidator(arg_spec)
result = v.validate(parameters)
assert result.validated_parameters == expected
assert result._deprecations == [
{
'collection_name': None,
'date': '2020-03-04',
'name': 'flamethrower',
'version': None,
}
]
assert "Alias 'flamethrower' is deprecated" in warnings._global_deprecations[0]['msg']
assert result._warnings == [{'alias': 'flamethrower', 'option': 'path'}]
assert "Both option path and its alias flamethrower are set" in warnings._global_warnings[0]

@@ -5,7 +5,7 @@
from __future__ import absolute_import, division, print_function
__metaclass__ = type

-from ansible.module_utils.common.arg_spec import ArgumentSpecValidator
+from ansible.module_utils.common.arg_spec import ArgumentSpecValidator, ValidationResult


def test_sub_spec():
@@ -39,12 +39,12 @@ def test_sub_spec():
        }
    }

-    v = ArgumentSpecValidator(arg_spec, parameters)
-    passed = v.validate()
+    v = ArgumentSpecValidator(arg_spec)
+    result = v.validate(parameters)

-    assert passed is True
-    assert v.error_messages == []
-    assert v.validated_parameters == expected
+    assert isinstance(result, ValidationResult)
+    assert result.validated_parameters == expected
+    assert result.error_messages == []


def test_nested_sub_spec():
@@ -98,9 +98,9 @@ def test_nested_sub_spec():
        }
    }

-    v = ArgumentSpecValidator(arg_spec, parameters)
-    passed = v.validate()
+    v = ArgumentSpecValidator(arg_spec)
+    result = v.validate(parameters)

-    assert passed is True
-    assert v.error_messages == []
-    assert v.validated_parameters == expected
+    assert isinstance(result, ValidationResult)
+    assert result.validated_parameters == expected
+    assert result.error_messages == []

@@ -7,17 +7,19 @@ __metaclass__ = type

import pytest

-from ansible.module_utils.common.arg_spec import ArgumentSpecValidator
+from ansible.module_utils.common.arg_spec import ArgumentSpecValidator, ValidationResult
+from ansible.module_utils.errors import AnsibleValidationErrorMultiple
from ansible.module_utils.six import PY2

-# Each item is id, argument_spec, parameters, expected, error test string
+# Each item is id, argument_spec, parameters, expected, unsupported parameters, error test string
INVALID_SPECS = [
    (
        'invalid-list',
        {'packages': {'type': 'list'}},
        {'packages': {'key': 'value'}},
        {'packages': {'key': 'value'}},
+        set(),
        "unable to convert to list: <class 'dict'> cannot be converted to a list",
    ),
    (
@@ -25,6 +27,7 @@ INVALID_SPECS = [
        {'users': {'type': 'dict'}},
        {'users': ['one', 'two']},
        {'users': ['one', 'two']},
+        set(),
        "unable to convert to dict: <class 'list'> cannot be converted to a dict",
    ),
    (
@@ -32,6 +35,7 @@ INVALID_SPECS = [
        {'bool': {'type': 'bool'}},
        {'bool': {'k': 'v'}},
        {'bool': {'k': 'v'}},
+        set(),
        "unable to convert to bool: <class 'dict'> cannot be converted to a bool",
    ),
    (
@@ -39,6 +43,7 @@ INVALID_SPECS = [
        {'float': {'type': 'float'}},
        {'float': 'hello'},
        {'float': 'hello'},
+        set(),
        "unable to convert to float: <class 'str'> cannot be converted to a float",
    ),
    (
@@ -46,6 +51,7 @@ INVALID_SPECS = [
        {'bytes': {'type': 'bytes'}},
        {'bytes': 'one'},
        {'bytes': 'one'},
+        set(),
        "unable to convert to bytes: <class 'str'> cannot be converted to a Byte value",
    ),
    (
@@ -53,6 +59,7 @@ INVALID_SPECS = [
        {'bits': {'type': 'bits'}},
        {'bits': 'one'},
        {'bits': 'one'},
+        set(),
        "unable to convert to bits: <class 'str'> cannot be converted to a Bit value",
    ),
    (
@@ -60,6 +67,7 @@ INVALID_SPECS = [
        {'some_json': {'type': 'jsonarg'}},
        {'some_json': set()},
        {'some_json': set()},
+        set(),
        "unable to convert to jsonarg: <class 'set'> cannot be converted to a json string",
    ),
    (
@@ -74,13 +82,15 @@ INVALID_SPECS = [
            'badparam': '',
            'another': '',
        },
-        "Unsupported parameters: another, badparam",
+        set(('another', 'badparam')),
+        "another, badparam. Supported parameters include: name.",
    ),
    (
        'invalid-elements',
        {'numbers': {'type': 'list', 'elements': 'int'}},
        {'numbers': [55, 33, 34, {'key': 'value'}]},
        {'numbers': [55, 33, 34]},
+        set(),
        "Elements value for option 'numbers' is of type <class 'dict'> and we were unable to convert to int: <class 'dict'> cannot be converted to an int"
    ),
    (
@@ -88,23 +98,29 @@ INVALID_SPECS = [
        {'req': {'required': True}},
        {},
        {'req': None},
+        set(),
        "missing required arguments: req"
    )
]


@pytest.mark.parametrize(
-    ('arg_spec', 'parameters', 'expected', 'error'),
-    ((i[1], i[2], i[3], i[4]) for i in INVALID_SPECS),
+    ('arg_spec', 'parameters', 'expected', 'unsupported', 'error'),
+    (i[1:] for i in INVALID_SPECS),
    ids=[i[0] for i in INVALID_SPECS]
)
-def test_invalid_spec(arg_spec, parameters, expected, error):
-    v = ArgumentSpecValidator(arg_spec, parameters)
-    passed = v.validate()
+def test_invalid_spec(arg_spec, parameters, expected, unsupported, error):
+    v = ArgumentSpecValidator(arg_spec)
+    result = v.validate(parameters)
+
+    with pytest.raises(AnsibleValidationErrorMultiple) as exc_info:
+        raise result.errors

    if PY2:
        error = error.replace('class', 'type')

-    assert error in v.error_messages[0]
-    assert v.validated_parameters == expected
-    assert passed is False
+    assert isinstance(result, ValidationResult)
+    assert error in exc_info.value.msg
+    assert error in result.error_messages[0]
+    assert result.unsupported_parameters == unsupported
+    assert result.validated_parameters == expected

@@ -7,45 +7,53 @@ __metaclass__ = type

import pytest

-from ansible.module_utils.common.arg_spec import ArgumentSpecValidator
+import ansible.module_utils.common.warnings as warnings
+
+from ansible.module_utils.common.arg_spec import ArgumentSpecValidator, ValidationResult

-# Each item is id, argument_spec, parameters, expected
+# Each item is id, argument_spec, parameters, expected, valid parameter names
VALID_SPECS = [
    (
        'str-no-type-specified',
        {'name': {}},
        {'name': 'rey'},
        {'name': 'rey'},
+        set(('name',)),
    ),
    (
        'str',
        {'name': {'type': 'str'}},
        {'name': 'rey'},
        {'name': 'rey'},
+        set(('name',)),
    ),
    (
        'str-convert',
        {'name': {'type': 'str'}},
        {'name': 5},
        {'name': '5'},
+        set(('name',)),
    ),
    (
        'list',
        {'packages': {'type': 'list'}},
        {'packages': ['vim', 'python']},
        {'packages': ['vim', 'python']},
+        set(('packages',)),
    ),
    (
        'list-comma-string',
        {'packages': {'type': 'list'}},
        {'packages': 'vim,python'},
        {'packages': ['vim', 'python']},
+        set(('packages',)),
    ),
    (
        'list-comma-string-space',
        {'packages': {'type': 'list'}},
        {'packages': 'vim, python'},
        {'packages': ['vim', ' python']},
+        set(('packages',)),
    ),
    (
        'dict',
@@ -64,6 +72,7 @@ VALID_SPECS = [
            'last': 'skywalker',
            }
        },
+        set(('user',)),
    ),
    (
        'dict-k=v',
@@ -76,6 +85,7 @@ VALID_SPECS = [
            'last': 'skywalker',
            }
        },
+        set(('user',)),
    ),
    (
        'dict-k=v-spaces',
@@ -88,6 +98,7 @@ VALID_SPECS = [
            'last': 'skywalker',
            }
        },
+        set(('user',)),
    ),
    (
        'bool',
@@ -103,6 +114,7 @@ VALID_SPECS = [
            'enabled': True,
            'disabled': False,
        },
+        set(('enabled', 'disabled')),
    ),
    (
        'bool-ints',
@@ -118,6 +130,7 @@ VALID_SPECS = [
            'enabled': True,
            'disabled': False,
        },
+        set(('enabled', 'disabled')),
    ),
    (
        'bool-true-false',
@@ -133,6 +146,7 @@ VALID_SPECS = [
            'enabled': True,
            'disabled': False,
        },
+        set(('enabled', 'disabled')),
    ),
    (
        'bool-yes-no',
@@ -148,6 +162,7 @@ VALID_SPECS = [
            'enabled': True,
            'disabled': False,
        },
+        set(('enabled', 'disabled')),
    ),
    (
        'bool-y-n',
@@ -163,6 +178,7 @@ VALID_SPECS = [
            'enabled': True,
            'disabled': False,
        },
+        set(('enabled', 'disabled')),
    ),
    (
        'bool-on-off',
@@ -178,6 +194,7 @@ VALID_SPECS = [
            'enabled': True,
            'disabled': False,
        },
+        set(('enabled', 'disabled')),
    ),
    (
        'bool-1-0',
@@ -193,6 +210,7 @@ VALID_SPECS = [
            'enabled': True,
            'disabled': False,
        },
+        set(('enabled', 'disabled')),
    ),
    (
        'bool-float',
@@ -208,89 +226,112 @@ VALID_SPECS = [
            'enabled': True,
            'disabled': False,
        },
+        set(('enabled', 'disabled')),
    ),
    (
        'float',
        {'digit': {'type': 'float'}},
        {'digit': 3.14159},
        {'digit': 3.14159},
+        set(('digit',)),
    ),
    (
        'float-str',
        {'digit': {'type': 'float'}},
        {'digit': '3.14159'},
        {'digit': 3.14159},
+        set(('digit',)),
    ),
    (
        'path',
        {'path': {'type': 'path'}},
        {'path': '~/bin'},
        {'path': '/home/ansible/bin'},
+        set(('path',)),
    ),
    (
        'raw',
        {'raw': {'type': 'raw'}},
        {'raw': 0x644},
        {'raw': 0x644},
+        set(('raw',)),
    ),
    (
        'bytes',
        {'bytes': {'type': 'bytes'}},
        {'bytes': '2K'},
        {'bytes': 2048},
+        set(('bytes',)),
    ),
    (
        'bits',
        {'bits': {'type': 'bits'}},
        {'bits': '1Mb'},
        {'bits': 1048576},
+        set(('bits',)),
    ),
    (
        'jsonarg',
        {'some_json': {'type': 'jsonarg'}},
        {'some_json': '{"users": {"bob": {"role": "accountant"}}}'},
        {'some_json': '{"users": {"bob": {"role": "accountant"}}}'},
+        set(('some_json',)),
    ),
    (
        'jsonarg-list',
        {'some_json': {'type': 'jsonarg'}},
        {'some_json': ['one', 'two']},
        {'some_json': '["one", "two"]'},
+        set(('some_json',)),
    ),
    (
        'jsonarg-dict',
        {'some_json': {'type': 'jsonarg'}},
        {'some_json': {"users": {"bob": {"role": "accountant"}}}},
        {'some_json': '{"users": {"bob": {"role": "accountant"}}}'},
+        set(('some_json',)),
    ),
    (
        'defaults',
        {'param': {'default': 'DEFAULT'}},
        {},
        {'param': 'DEFAULT'},
+        set(('param',)),
    ),
    (
        'elements',
        {'numbers': {'type': 'list', 'elements': 'int'}},
        {'numbers': [55, 33, 34, '22']},
        {'numbers': [55, 33, 34, 22]},
+        set(('numbers',)),
    ),
+    (
+        'aliases',
+        {'src': {'aliases': ['path', 'source']}},
+        {'src': '/tmp'},
+        {'src': '/tmp'},
+        set(('src (path, source)',)),
+    )
]


@pytest.mark.parametrize(
-    ('arg_spec', 'parameters', 'expected'),
-    ((i[1], i[2], i[3]) for i in VALID_SPECS),
+    ('arg_spec', 'parameters', 'expected', 'valid_params'),
+    (i[1:] for i in VALID_SPECS),
    ids=[i[0] for i in VALID_SPECS]
)
-def test_valid_spec(arg_spec, parameters, expected, mocker):
+def test_valid_spec(arg_spec, parameters, expected, valid_params, mocker):
    mocker.patch('ansible.module_utils.common.validation.os.path.expanduser', return_value='/home/ansible/bin')
    mocker.patch('ansible.module_utils.common.validation.os.path.expandvars', return_value='/home/ansible/bin')

-    v = ArgumentSpecValidator(arg_spec, parameters)
-    passed = v.validate()
+    v = ArgumentSpecValidator(arg_spec)
+    result = v.validate(parameters)

-    assert v.validated_parameters == expected
-    assert v.error_messages == []
-    assert passed is True
+    assert isinstance(result, ValidationResult)
+    assert result.validated_parameters == expected
+    assert result.unsupported_parameters == set()
+    assert result.error_messages == []
+    assert v._valid_parameter_names == valid_params
+
+    # Again to check caching
+    assert v._valid_parameter_names == valid_params

@@ -8,7 +8,7 @@ __metaclass__ = type

import pytest

-from ansible.module_utils.common.parameters import get_unsupported_parameters
+from ansible.module_utils.common.parameters import _get_unsupported_parameters


@pytest.fixture
@@ -19,32 +19,6 @@ def argument_spec():
    }


-def mock_handle_aliases(*args):
-    aliases = {}
-    legal_inputs = [
-        '_ansible_check_mode',
-        '_ansible_debug',
-        '_ansible_diff',
-        '_ansible_keep_remote_files',
-        '_ansible_module_name',
-        '_ansible_no_log',
-        '_ansible_remote_tmp',
-        '_ansible_selinux_special_fs',
-        '_ansible_shell_executable',
-        '_ansible_socket',
-        '_ansible_string_conversion_action',
-        '_ansible_syslog_facility',
-        '_ansible_tmpdir',
-        '_ansible_verbosity',
-        '_ansible_version',
-        'state',
-        'status',
-        'enabled',
-    ]
-
-    return aliases, legal_inputs
-
-
@pytest.mark.parametrize(
    ('module_parameters', 'legal_inputs', 'expected'),
    (
@@ -59,7 +33,6 @@ def mock_handle_aliases(*args):
    )
)
def test_check_arguments(argument_spec, module_parameters, legal_inputs, expected, mocker):
-    mocker.patch('ansible.module_utils.common.parameters.handle_aliases', side_effect=mock_handle_aliases)
-    result = get_unsupported_parameters(argument_spec, module_parameters, legal_inputs)
+    result = _get_unsupported_parameters(argument_spec, module_parameters, legal_inputs)

    assert result == expected

@@ -8,27 +8,9 @@ __metaclass__ = type
 
 import pytest
 
-from ansible.module_utils.common.parameters import handle_aliases
+from ansible.module_utils.common.parameters import _handle_aliases
 from ansible.module_utils._text import to_native
 
-DEFAULT_LEGAL_INPUTS = [
-    '_ansible_check_mode',
-    '_ansible_debug',
-    '_ansible_diff',
-    '_ansible_keep_remote_files',
-    '_ansible_module_name',
-    '_ansible_no_log',
-    '_ansible_remote_tmp',
-    '_ansible_selinux_special_fs',
-    '_ansible_shell_executable',
-    '_ansible_socket',
-    '_ansible_string_conversion_action',
-    '_ansible_syslog_facility',
-    '_ansible_tmpdir',
-    '_ansible_verbosity',
-    '_ansible_version',
-]
-
 
 def test_handle_aliases_no_aliases():
     argument_spec = {
@@ -40,14 +22,9 @@ def test_handle_aliases_no_aliases():
         'path': 'bar'
     }
 
-    expected = (
-        {},
-        DEFAULT_LEGAL_INPUTS + ['name'],
-    )
-    expected[1].sort()
-    result = handle_aliases(argument_spec, params)
-    result[1].sort()
+    expected = {}
+    result = _handle_aliases(argument_spec, params)
 
     assert expected == result
@@ -63,14 +40,9 @@ def test_handle_aliases_basic():
         'nick': 'foo',
     }
 
-    expected = (
-        {'surname': 'name', 'nick': 'name'},
-        DEFAULT_LEGAL_INPUTS + ['name', 'surname', 'nick'],
-    )
-    expected[1].sort()
-    result = handle_aliases(argument_spec, params)
-    result[1].sort()
+    expected = {'surname': 'name', 'nick': 'name'}
+    result = _handle_aliases(argument_spec, params)
 
     assert expected == result
@@ -84,7 +56,7 @@ def test_handle_aliases_value_error():
     }
 
     with pytest.raises(ValueError) as ve:
-        handle_aliases(argument_spec, params)
+        _handle_aliases(argument_spec, params)
 
     assert 'internal error: aliases must be a list or tuple' == to_native(ve.error)
@@ -98,5 +70,5 @@ def test_handle_aliases_type_error():
     }
 
     with pytest.raises(TypeError) as te:
-        handle_aliases(argument_spec, params)
+        _handle_aliases(argument_spec, params)
 
     assert 'internal error: required and default are mutually exclusive' in to_native(te.error)
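
The simplified assertions above reflect that the helper now returns only the alias map. A small sketch (hypothetical spec and parameters, mirroring the shapes used in these tests):

    from ansible.module_utils.common.parameters import _handle_aliases

    argument_spec = {'name': {'type': 'str', 'aliases': ['surname', 'nick']}}
    params = {'nick': 'foo'}

    # Only the alias -> canonical-name mapping comes back; the legal_inputs list
    # that the old tuple return value carried is no longer part of the result.
    assert _handle_aliases(argument_spec, params) == {'surname': 'name', 'nick': 'name'}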

@@ -7,7 +7,7 @@ __metaclass__ = type
 
 import pytest
 
-from ansible.module_utils.common.parameters import list_deprecations
+from ansible.module_utils.common.parameters import _list_deprecations
 
 
 @pytest.fixture
@@ -33,7 +33,7 @@ def test_list_deprecations():
         'foo': {'old': 'value'},
         'bar': [{'old': 'value'}, {}],
     }
-    result = list_deprecations(argument_spec, params)
+    result = _list_deprecations(argument_spec, params)
     assert len(result) == 3
     result.sort(key=lambda entry: entry['msg'])
     assert result[0]['msg'] == """Param 'bar["old"]' is deprecated. See the module docs for more information"""
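
As a reminder of the call shape being renamed here (sketch only; the spec and parameters are hypothetical):

    from ansible.module_utils.common.parameters import _list_deprecations

    argument_spec = {'old': {'type': 'str', 'removed_in_version': '2.14'}}
    params = {'old': 'value'}

    # Each entry in the returned list carries at least a 'msg' describing the deprecated parameter.
    for entry in _list_deprecations(argument_spec, params):
        print(entry['msg'])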

@@ -7,7 +7,7 @@ __metaclass__ = type
 
 import pytest
 
-from ansible.module_utils.common.parameters import list_no_log_values
+from ansible.module_utils.common.parameters import _list_no_log_values
 
 
 @pytest.fixture
@@ -55,12 +55,12 @@ def test_list_no_log_values_no_secrets(module_parameters):
         'value': {'type': 'int'},
     }
     expected = set()
-    assert expected == list_no_log_values(argument_spec, module_parameters)
+    assert expected == _list_no_log_values(argument_spec, module_parameters)
 
 
 def test_list_no_log_values(argument_spec, module_parameters):
     expected = set(('under', 'makeshift'))
-    assert expected == list_no_log_values(argument_spec(), module_parameters())
+    assert expected == _list_no_log_values(argument_spec(), module_parameters())
 
 
 @pytest.mark.parametrize('extra_params', [
@@ -81,7 +81,7 @@ def test_list_no_log_values_invalid_suboptions(argument_spec, module_parameters,
     with pytest.raises(TypeError, match=r"(Value '.*?' in the sub parameter field '.*?' must by a dict, not '.*?')"
                                         r"|(dictionary requested, could not parse JSON or key=value)"):
-        list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params))
+        _list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params))
 
 
 def test_list_no_log_values_suboptions(argument_spec, module_parameters):
@@ -103,7 +103,7 @@ def test_list_no_log_values_suboptions(argument_spec, module_parameters):
     }
 
     expected = set(('under', 'makeshift', 'bagel'))
-    assert expected == list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params))
+    assert expected == _list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params))
 
 
 def test_list_no_log_values_sub_suboptions(argument_spec, module_parameters):
@@ -136,7 +136,7 @@ def test_list_no_log_values_sub_suboptions(argument_spec, module_parameters):
     }
 
     expected = set(('under', 'makeshift', 'saucy', 'corporate'))
-    assert expected == list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params))
+    assert expected == _list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params))
 
 
 def test_list_no_log_values_suboptions_list(argument_spec, module_parameters):
@@ -164,7 +164,7 @@ def test_list_no_log_values_suboptions_list(argument_spec, module_parameters):
     }
 
     expected = set(('under', 'makeshift', 'playroom', 'luxury'))
-    assert expected == list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params))
+    assert expected == _list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params))
 
 
 def test_list_no_log_values_sub_suboptions_list(argument_spec, module_parameters):
@@ -204,7 +204,7 @@ def test_list_no_log_values_sub_suboptions_list(argument_spec, module_parameters
     }
 
     expected = set(('under', 'makeshift', 'playroom', 'luxury', 'basis', 'gave', 'composure', 'thumping'))
-    assert expected == list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params))
+    assert expected == _list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params))
 
 
 @pytest.mark.parametrize('extra_params, expected', (
@@ -225,4 +225,4 @@ def test_string_suboptions_as_string(argument_spec, module_parameters, extra_par
     result = set(('under', 'makeshift'))
     result.update(expected)
-    assert result == list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params))
+    assert result == _list_no_log_values(argument_spec(extra_opts), module_parameters(extra_params))
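
And the corresponding sketch for the no_log helper (hypothetical spec and parameters; values chosen to mirror the tests above):

    from ansible.module_utils.common.parameters import _list_no_log_values

    argument_spec = {'secret': {'type': 'str', 'no_log': True}, 'other': {'type': 'str'}}
    module_parameters = {'secret': 'under', 'other': 'makeshift'}

    # Values of options marked no_log=True are collected so they can be censored from output.
    assert _list_no_log_values(argument_spec, module_parameters) == {'under'}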
