diff --git a/lib/ansible/module_utils/common/collections.py b/lib/ansible/module_utils/common/collections.py index 95d12ffb6bc..00baf8463b2 100644 --- a/lib/ansible/module_utils/common/collections.py +++ b/lib/ansible/module_utils/common/collections.py @@ -15,6 +15,18 @@ def is_string(seq): return isinstance(seq, (text_type, binary_type)) +def is_iterable(seq, include_strings=False): + """Identify whether the input is an iterable.""" + if not include_strings and is_string(seq): + return False + + try: + iter(seq) + return True + except TypeError: + return False + + def is_sequence(seq, include_strings=False): """Identify whether the input is a sequence. diff --git a/test/units/module_utils/common/collections.py b/test/units/module_utils/common/collections.py index 8bdddfcf99c..cf7be6183d0 100644 --- a/test/units/module_utils/common/collections.py +++ b/test/units/module_utils/common/collections.py @@ -9,7 +9,7 @@ __metaclass__ = type import pytest from ansible.module_utils.common._collections_compat import Sequence -from ansible.module_utils.common.collections import is_sequence +from ansible.module_utils.common.collections import is_iterable, is_sequence class SeqStub: @@ -24,6 +24,16 @@ class SeqStub: Sequence.register(SeqStub) +class IteratorStub: + def __next__(self): + raise StopIteration + + +class IterableStub: + def __iter__(self): + return IteratorStub() + + TEST_STRINGS = u'he', u'Україна', u'Česká republika' TEST_STRINGS = TEST_STRINGS + tuple(s.encode('utf-8') for s in TEST_STRINGS) @@ -65,3 +75,28 @@ def test_sequence_string_types_with_strings(string_input): def test_sequence_string_types_without_strings(string_input): """Test that ``is_sequence`` can separate string and non-string.""" assert not is_sequence(string_input, include_strings=False) + + +@pytest.mark.parametrize( + 'seq', + ([], (), {}, set(), frozenset(), IterableStub()), +) +def test_iterable_positive(seq): + assert is_iterable(seq) + + +@pytest.mark.parametrize( + 'seq', (IteratorStub(), object(), 5, 9.) +) +def test_iterable_negative(seq): + assert not is_iterable(seq) + + +@pytest.mark.parametrize('string_input', TEST_STRINGS) +def test_iterable_including_strings(string_input): + assert is_iterable(string_input, include_strings=True) + + +@pytest.mark.parametrize('string_input', TEST_STRINGS) +def test_iterable_excluding_strings(string_input): + assert not is_iterable(string_input, include_strings=False)