diff --git a/lib/ansible/module_utils/network_common.py b/lib/ansible/module_utils/network_common.py index 20aac99b04c..32d831b88ff 100644 --- a/lib/ansible/module_utils/network_common.py +++ b/lib/ansible/module_utils/network_common.py @@ -24,7 +24,10 @@ # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE # USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +from itertools import chain +from ansible.module_utils.six import iteritems from ansible.module_utils.basic import AnsibleFallbackNotFound from ansible.module_utils.six import iteritems @@ -38,7 +41,13 @@ def to_list(val): return list() -class ComplexDict(object): +def sort_list(val): + if isinstance(val, list): + return sorted(val) + return val + + +class Entity(object): """Transforms a dict to with an argument spec This class will take a dict and apply an Ansible argument spec to the @@ -52,7 +61,7 @@ class ComplexDict(object): display=dict(default='text', choices=['text', 'json']), validate=dict(type='bool') ) - transform = ComplexDict(argument_spec, module) + transform = Entity(module, argument_spec) value = dict(command='foo') result = transform(value) print result @@ -66,31 +75,42 @@ class ComplexDict(object): * fallback - implements fallback function * choices - set of valid options * default - default value - """ - def __init__(self, attrs, module): - self._attributes = attrs + def __init__(self, module, attrs=None, args=[], keys=None, from_argspec=False): + self._attributes = attrs or {} self._module = module + + for arg in args: + self._attributes[arg] = dict() + if from_argspec: + self._attributes[arg]['read_from'] = arg + if keys and arg in keys: + self._attributes[arg]['key'] = True + self.attr_names = frozenset(self._attributes.keys()) - self._has_key = False + _has_key = False + for name, attr in iteritems(self._attributes): if attr.get('read_from'): + if attr['read_from'] not in self._module.argument_spec: + module.fail_json(msg='argument %s does not exist' % attr['read_from']) spec = self._module.argument_spec.get(attr['read_from']) - if not spec: - raise ValueError('argument_spec %s does not exist' % attr['read_from']) for key, value in iteritems(spec): if key not in attr: attr[key] = value if attr.get('key'): - if self._has_key: - raise ValueError('only one key value can be specified') - self._has_key = True + if _has_key: + module.fail_json(msg='only one key value can be specified') + _has_key = True attr['required'] = True - def _dict(self, value): + def serialize(self): + return self._attributes + + def to_dict(self, value): obj = {} for name, attr in iteritems(self._attributes): if attr.get('key'): @@ -99,16 +119,17 @@ class ComplexDict(object): obj[name] = attr.get('default') return obj - def __call__(self, value): + def __call__(self, value, strict=True): if not isinstance(value, dict): - value = self._dict(value) + value = self.to_dict(value) - unknown = set(value).difference(self.attr_names) - if unknown: - raise ValueError('invalid keys: %s' % ','.join(unknown)) + if strict: + unknown = set(value).difference(self.attr_names) + if unknown: + self._module.fail_json(msg='invalid keys: %s' % ','.join(unknown)) for name, attr in iteritems(self._attributes): - if not value.get(name): + if value.get(name) is None: value[name] = attr.get('default') if attr.get('fallback') and not value.get(name): @@ -128,24 +149,135 @@ class ComplexDict(object): continue if attr.get('required') and value.get(name) is None: - raise ValueError('missing required attribute %s' % name) + self._module.fail_json(msg='missing required attribute %s' % name) if 'choices' in attr: if value[name] not in attr['choices']: - raise ValueError('%s must be one of %s, got %s' % (name, ', '.join(attr['choices']), value[name])) + self._module.fail_json(msg='%s must be one of %s, got %s' % (name, ', '.join(attr['choices']), value[name])) if value[name] is not None: value_type = attr.get('type', 'str') type_checker = self._module._CHECK_ARGUMENT_TYPES_DISPATCHER[value_type] type_checker(value[name]) + elif value.get(name): + value[name] = self._module.params[name] return value -class ComplexList(ComplexDict): - """Extends ```ComplexDict``` to handle a list of dicts """ +class EntityCollection(Entity): + """Extends ```Entity``` to handle a list of dicts """ + + def __call__(self, iterable, strict=True): + if iterable is None: + iterable = [super(EntityCollection, self).__call__(self._module.params, strict)] + + if not isinstance(iterable, (list, tuple)): + module.fail_json(msg='value must be an iterable') + + return [(super(EntityCollection, self).__call__(i, strict)) for i in iterable] + + +# these two are for backwards compatibility and can be removed once all of the +# modules that use them are updated +class ComplexDict(Entity): + def __init__(self, attrs, module, *args, **kwargs): + super(ComplexDict, self).__init__(module, attrs, *args, **kwargs) + + +class ComplexList(EntityCollection): + def __init__(self, attrs, module, *args, **kwargs): + super(ComplexList, self).__init__(module, attrs, *args, **kwargs) + + +def dict_diff(base, comparable): + """ Generate a dict object of differences + + This function will compare two dict objects and return the difference + between them as a dict object. For scalar values, the key will reflect + the updated value. If the key does not exist in `comparable`, then then no + key will be returned. For lists, the value in comparable will wholly replace + the value in base for the key. For dicts, the returned value will only + return keys that are different. + + :param base: dict object to base the diff on + :param comparable: dict object to compare against base + + :returns: new dict object with differences + """ + assert isinstance(base, dict), "`base` must be of type " + assert isinstance(comparable, dict), "`comparable` must be of type " + + updates = dict() + + for key, value in iteritems(base): + if isinstance(value, dict): + item = comparable.get(key) + if item is not None: + updates[key] = dict_diff(value, comparable[key]) + else: + comparable_value = comparable.get(key) + if comparable_value is not None: + if sort_list(base[key]) != sort_list(comparable_value): + updates[key] = comparable_value + + for key in set(comparable.keys()).difference(base.keys()): + updates[key] = comparable.get(key) + + return updates + + +def dict_combine(base, other): + """ Return a new dict object that combines base and other + + This will create a new dict object that is a combination of the key/value + pairs from base and other. When both keys exist, the value will be + selected from other. If the value is a list object, the two lists will + be combined and duplicate entries removed. + + :param base: dict object to serve as base + :param other: dict object to combine with base + + :returns: new combined dict object + """ + assert isinstance(base, dict), "`base` must be of type " + assert isinstance(other, dict), "`other` must be of type " + + combined = dict() + + for key, value in iteritems(base): + if isinstance(value, dict): + if key in other: + item = other.get(key) + if item is not None: + combined[key] = dict_combine(value, other[key]) + else: + combined[key] = item + else: + combined[key] = value + elif isinstance(value, list): + if key in other: + item = other.get(key) + if item is not None: + combined[key] = list(set(chain(value, item))) + else: + combined[key] = item + else: + combined[key] = value + else: + if key in other: + other_value = other.get(key) + if other_value is not None: + if sort_list(base[key]) != sort_list(other_value): + combined[key] = other_value + else: + combined[key] = value + else: + combined[key] = other_value + else: + combined[key] = value + + for key in set(other.keys()).difference(base.keys()): + combined[key] = other.get(key) - def __call__(self, values): - if not isinstance(values, (list, tuple)): - raise TypeError('value must be an ordered iterable') - return [(super(ComplexList, self).__call__(v)) for v in values] + return combined diff --git a/test/units/module_utils/test_network_common.py b/test/units/module_utils/test_network_common.py new file mode 100644 index 00000000000..1d3c323a7cc --- /dev/null +++ b/test/units/module_utils/test_network_common.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +# +# (c) 2017 Red Hat, Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +# Make coding more python3-ish +from __future__ import (absolute_import, division) +__metaclass__ = type + +from ansible.compat.tests import unittest + +from ansible.module_utils.network_common import to_list, sort_list +from ansible.module_utils.network_common import dict_diff, dict_combine + + +class TestModuleUtilsNetworkCommon(unittest.TestCase): + + def test_to_list(self): + for scalar in ('string', 1, True, False, None): + self.assertTrue(isinstance(to_list(scalar), list)) + + for container in ([1, 2, 3], {'one': 1}): + self.assertTrue(isinstance(to_list(container), list)) + + test_list = [1, 2, 3] + self.assertNotEqual(id(test_list), id(to_list(test_list))) + + def test_sort(self): + data = [3, 1, 2] + self.assertEqual([1, 2, 3], sort_list(data)) + + string_data = '123' + self.assertEqual(string_data, sort_list(string_data)) + + def test_dict_diff(self): + base = dict(obj2=dict(), b1=True, b2=False, b3=False, + one=1, two=2, three=3, obj1=dict(key1=1, key2=2), + l1=[1, 3], l2=[1, 2, 3], l4=[4], + nested=dict(n1=dict(n2=2))) + + other = dict(b1=True, b2=False, b3=True, b4=True, + one=1, three=4, four=4, obj1=dict(key1=2), + l1=[2, 1], l2=[3, 2, 1], l3=[1], + nested=dict(n1=dict(n2=2, n3=3))) + + result = dict_diff(base, other) + + # string assertions + self.assertNotIn('one', result) + self.assertNotIn('two', result) + self.assertEqual(result['three'], 4) + self.assertEqual(result['four'], 4) + + # dict assertions + self.assertIn('obj1', result) + self.assertIn('key1', result['obj1']) + self.assertNotIn('key2', result['obj1']) + + # list assertions + self.assertEqual(result['l1'], [2, 1]) + self.assertNotIn('l2', result) + self.assertEqual(result['l3'], [1]) + self.assertNotIn('l4', result) + + # nested assertions + self.assertIn('obj1', result) + self.assertEqual(result['obj1']['key1'], 2) + self.assertNotIn('key2', result['obj1']) + + # bool assertions + self.assertNotIn('b1', result) + self.assertNotIn('b2', result) + self.assertTrue(result['b3']) + self.assertTrue(result['b4']) + + def test_dict_combine(self): + base = dict(obj2=dict(), b1=True, b2=False, b3=False, + one=1, two=2, three=3, obj1=dict(key1=1, key2=2), + l1=[1, 3], l2=[1, 2, 3], l4=[4], + nested=dict(n1=dict(n2=2))) + + other = dict(b1=True, b2=False, b3=True, b4=True, + one=1, three=4, four=4, obj1=dict(key1=2), + l1=[2, 1], l2=[3, 2, 1], l3=[1], + nested=dict(n1=dict(n2=2, n3=3))) + + result = dict_combine(base, other) + + # string assertions + self.assertIn('one', result) + self.assertIn('two', result) + self.assertEqual(result['three'], 4) + self.assertEqual(result['four'], 4) + + # dict assertions + self.assertIn('obj1', result) + self.assertIn('key1', result['obj1']) + self.assertIn('key2', result['obj1']) + + # list assertions + self.assertEqual(result['l1'], [1, 2, 3]) + self.assertIn('l2', result) + self.assertEqual(result['l3'], [1]) + self.assertIn('l4', result) + + # nested assertions + self.assertIn('obj1', result) + self.assertEqual(result['obj1']['key1'], 2) + self.assertIn('key2', result['obj1']) + + # bool assertions + self.assertIn('b1', result) + self.assertIn('b2', result) + self.assertTrue(result['b3']) + self.assertTrue(result['b4'])