recursive_diff: handle error when parameters are not dict (#74801)

Co-authored-by: Sam Doran <sdoran@redhat.com>
pull/75224/head
Abhijeet Kasurde 3 years ago committed by GitHub
parent ac151e5ad0
commit e7a3715a90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,2 @@
bugfixes:
- recursive_diff - handle condition when parameters are not dict (https://github.com/ansible/ansible/issues/56249).

@ -10,6 +10,8 @@ __metaclass__ = type
import re import re
from copy import deepcopy from copy import deepcopy
from ansible.module_utils.common._collections_compat import MutableMapping
def camel_dict_to_snake_dict(camel_dict, reversible=False, ignore_list=()): def camel_dict_to_snake_dict(camel_dict, reversible=False, ignore_list=()):
""" """
@ -123,6 +125,19 @@ def dict_merge(a, b):
def recursive_diff(dict1, dict2): def recursive_diff(dict1, dict2):
"""Recursively diff two dictionaries
Raises ``TypeError`` for incorrect argument type.
:arg dict1: Dictionary to compare against.
:arg dict2: Dictionary to compare with ``dict1``.
:return: Tuple of dictionaries of differences or ``None`` if there are no differences.
"""
if not all((isinstance(item, MutableMapping) for item in (dict1, dict2))):
raise TypeError("Unable to diff 'dict1' %s and 'dict2' %s. "
"Both must be a dictionary." % (type(dict1), type(dict2)))
left = dict((k, v) for (k, v) in dict1.items() if k not in dict2) left = dict((k, v) for (k, v) in dict1.items() if k not in dict2)
right = dict((k, v) for (k, v) in dict2.items() if k not in dict1) right = dict((k, v) for (k, v) in dict2.items() if k not in dict1)
for k in (set(dict1.keys()) & set(dict2.keys())): for k in (set(dict1.keys()) & set(dict2.keys())):
@ -136,5 +151,4 @@ def recursive_diff(dict1, dict2):
right[k] = dict2[k] right[k] = dict2[k]
if left or right: if left or right:
return left, right return left, right
else: return None
return None

@ -1,26 +1,20 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# (c) 2017, Will Thames <will.thames@xvt.com.au> # Copyright: (c) 2017, Will Thames <will.thames@xvt.com.au>
# # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
from units.compat import unittest import pytest
from ansible.module_utils.common.dict_transformations import _camel_to_snake, _snake_to_camel, camel_dict_to_snake_dict, dict_merge
from ansible.module_utils.common.dict_transformations import (
_camel_to_snake,
_snake_to_camel,
camel_dict_to_snake_dict,
dict_merge,
recursive_diff,
)
EXPECTED_SNAKIFICATION = { EXPECTED_SNAKIFICATION = {
'alllower': 'alllower', 'alllower': 'alllower',
@ -42,39 +36,39 @@ EXPECTED_REVERSIBLE = {
} }
class CamelToSnakeTestCase(unittest.TestCase): class TestCaseCamelToSnake:
def test_camel_to_snake(self): def test_camel_to_snake(self):
for (k, v) in EXPECTED_SNAKIFICATION.items(): for (k, v) in EXPECTED_SNAKIFICATION.items():
self.assertEqual(_camel_to_snake(k), v) assert _camel_to_snake(k) == v
def test_reversible_camel_to_snake(self): def test_reversible_camel_to_snake(self):
for (k, v) in EXPECTED_REVERSIBLE.items(): for (k, v) in EXPECTED_REVERSIBLE.items():
self.assertEqual(_camel_to_snake(k, reversible=True), v) assert _camel_to_snake(k, reversible=True) == v
class SnakeToCamelTestCase(unittest.TestCase): class TestCaseSnakeToCamel:
def test_snake_to_camel_reversed(self): def test_snake_to_camel_reversed(self):
for (k, v) in EXPECTED_REVERSIBLE.items(): for (k, v) in EXPECTED_REVERSIBLE.items():
self.assertEqual(_snake_to_camel(v, capitalize_first=True), k) assert _snake_to_camel(v, capitalize_first=True) == k
class CamelToSnakeAndBackTestCase(unittest.TestCase): class TestCaseCamelToSnakeAndBack:
def test_camel_to_snake_and_back(self): def test_camel_to_snake_and_back(self):
for (k, v) in EXPECTED_REVERSIBLE.items(): for (k, v) in EXPECTED_REVERSIBLE.items():
self.assertEqual(_snake_to_camel(_camel_to_snake(k, reversible=True), capitalize_first=True), k) assert _snake_to_camel(_camel_to_snake(k, reversible=True), capitalize_first=True) == k
class CamelDictToSnakeDictTestCase(unittest.TestCase): class TestCaseCamelDictToSnakeDict:
def test_ignore_list(self): def test_ignore_list(self):
camel_dict = dict(Hello=dict(One='one', Two='two'), World=dict(Three='three', Four='four')) camel_dict = dict(Hello=dict(One='one', Two='two'), World=dict(Three='three', Four='four'))
snake_dict = camel_dict_to_snake_dict(camel_dict, ignore_list='World') snake_dict = camel_dict_to_snake_dict(camel_dict, ignore_list='World')
self.assertEqual(snake_dict['hello'], dict(one='one', two='two')) assert snake_dict['hello'] == dict(one='one', two='two')
self.assertEqual(snake_dict['world'], dict(Three='three', Four='four')) assert snake_dict['world'] == dict(Three='three', Four='four')
class DictMergeTestCase(unittest.TestCase): class TestCaseDictMerge:
def test_dict_merge(self): def test_dict_merge(self):
base = dict(obj2=dict(), b1=True, b2=False, b3=False, base = dict(obj2=dict(), b1=True, b2=False, b3=False,
one=1, two=2, three=3, obj1=dict(key1=1, key2=2), one=1, two=2, three=3, obj1=dict(key1=1, key2=2),
@ -89,42 +83,42 @@ class DictMergeTestCase(unittest.TestCase):
result = dict_merge(base, other) result = dict_merge(base, other)
# string assertions # string assertions
self.assertTrue('one' in result) assert 'one' in result
self.assertTrue('two' in result) assert 'two' in result
self.assertEqual(result['three'], 4) assert result['three'] == 4
self.assertEqual(result['four'], 4) assert result['four'] == 4
# dict assertions # dict assertions
self.assertTrue('obj1' in result) assert 'obj1' in result
self.assertTrue('key1' in result['obj1']) assert 'key1' in result['obj1']
self.assertTrue('key2' in result['obj1']) assert 'key2' in result['obj1']
# list assertions # list assertions
# this line differs from the network_utils/common test of the function of the # this line differs from the network_utils/common test of the function of the
# same name as this method does not merge lists # same name as this method does not merge lists
self.assertEqual(result['l1'], [2, 1]) assert result['l1'], [2, 1]
self.assertTrue('l2' in result) assert 'l2' in result
self.assertEqual(result['l3'], [1]) assert result['l3'], [1]
self.assertTrue('l4' in result) assert 'l4' in result
# nested assertions # nested assertions
self.assertTrue('obj1' in result) assert 'obj1' in result
self.assertEqual(result['obj1']['key1'], 2) assert result['obj1']['key1'], 2
self.assertTrue('key2' in result['obj1']) assert 'key2' in result['obj1']
# bool assertions # bool assertions
self.assertTrue('b1' in result) assert 'b1' in result
self.assertTrue('b2' in result) assert 'b2' in result
self.assertTrue(result['b3']) assert result['b3']
self.assertTrue(result['b4']) assert result['b4']
class AzureIncidentalTestCase(unittest.TestCase): class TestCaseAzureIncidental:
def test_dict_merge_invalid_dict(self): def test_dict_merge_invalid_dict(self):
''' if b is not a dict, return b ''' ''' if b is not a dict, return b '''
res = dict_merge({}, None) res = dict_merge({}, None)
self.assertEqual(res, None) assert res is None
def test_merge_sub_dicts(self): def test_merge_sub_dicts(self):
'''merge sub dicts ''' '''merge sub dicts '''
@ -132,4 +126,28 @@ class AzureIncidentalTestCase(unittest.TestCase):
b = {'a': {'b1': 2}} b = {'a': {'b1': 2}}
c = {'a': {'a1': 1, 'b1': 2}} c = {'a': {'a1': 1, 'b1': 2}}
res = dict_merge(a, b) res = dict_merge(a, b)
self.assertEqual(res, c) assert res == c
class TestCaseRecursiveDiff:
def test_recursive_diff(self):
a = {'foo': {'bar': [{'baz': {'qux': 'ham_sandwich'}}]}}
c = {'foo': {'bar': [{'baz': {'qux': 'ham_sandwich'}}]}}
b = {'foo': {'bar': [{'baz': {'qux': 'turkey_sandwich'}}]}}
assert recursive_diff(a, b) is not None
assert len(recursive_diff(a, b)) == 2
assert recursive_diff(a, c) is None
@pytest.mark.parametrize(
'p1, p2', (
([1, 2], [2, 3]),
({1: 2}, [2, 3]),
([1, 2], {2: 3}),
({2: 3}, 'notadict'),
('notadict', {2: 3}),
)
)
def test_recursive_diff_negative(self, p1, p2):
with pytest.raises(TypeError, match="Unable to diff"):
recursive_diff(p1, p2)

Loading…
Cancel
Save