Improve code coverage of unit tests (#81119)

- Remove unused code
- Remove unnecessary code
- Ignore coverage for unreachable code
- Use previously unused code to increase coverage
pull/81121/head
Matt Clay 2 years ago committed by GitHub
parent da2cd157f1
commit 82b5544b09
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5,21 +5,10 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
__metaclass__ = type __metaclass__ = type
import pytest
from ansible.module_utils.common.parameters import _list_deprecations from ansible.module_utils.common.parameters import _list_deprecations
@pytest.fixture
def params():
return {
'name': 'bob',
'dest': '/etc/hosts',
'state': 'present',
'value': 5,
}
def test_list_deprecations(): def test_list_deprecations():
argument_spec = { argument_spec = {
'old': {'type': 'str', 'removed_in_version': '2.5'}, 'old': {'type': 'str', 'removed_in_version': '2.5'},

@ -8,7 +8,6 @@ __metaclass__ = type
import pytest import pytest
from ansible.module_utils.six import Iterator
from ansible.module_utils.six.moves.collections_abc import Sequence from ansible.module_utils.six.moves.collections_abc import Sequence
from ansible.module_utils.common.collections import ImmutableDict, is_iterable, is_sequence from ansible.module_utils.common.collections import ImmutableDict, is_iterable, is_sequence
@ -25,16 +24,6 @@ class SeqStub:
Sequence.register(SeqStub) Sequence.register(SeqStub)
class IteratorStub(Iterator):
def __next__(self):
raise StopIteration
class IterableStub:
def __iter__(self):
return IteratorStub()
class FakeAnsibleVaultEncryptedUnicode(Sequence): class FakeAnsibleVaultEncryptedUnicode(Sequence):
__ENCRYPTED__ = True __ENCRYPTED__ = True
@ -42,10 +31,10 @@ class FakeAnsibleVaultEncryptedUnicode(Sequence):
self.data = data self.data = data
def __getitem__(self, index): def __getitem__(self, index):
return self.data[index] raise NotImplementedError() # pragma: nocover
def __len__(self): def __len__(self):
return len(self.data) raise NotImplementedError() # pragma: nocover
TEST_STRINGS = u'he', u'Україна', u'Česká republika' TEST_STRINGS = u'he', u'Україна', u'Česká republika'
@ -93,14 +82,14 @@ def test_sequence_string_types_without_strings(string_input):
@pytest.mark.parametrize( @pytest.mark.parametrize(
'seq', 'seq',
([], (), {}, set(), frozenset(), IterableStub()), ([], (), {}, set(), frozenset()),
) )
def test_iterable_positive(seq): def test_iterable_positive(seq):
assert is_iterable(seq) assert is_iterable(seq)
@pytest.mark.parametrize( @pytest.mark.parametrize(
'seq', (IteratorStub(), object(), 5, 9.) 'seq', (object(), 5, 9.)
) )
def test_iterable_negative(seq): def test_iterable_negative(seq):
assert not is_iterable(seq) assert not is_iterable(seq)

@ -20,12 +20,6 @@ class timezone(tzinfo):
def utcoffset(self, dt): def utcoffset(self, dt):
return self._offset return self._offset
def dst(self, dt):
return timedelta(0)
def tzname(self, dt):
return None
@pytest.mark.parametrize( @pytest.mark.parametrize(
'test_input,expected', 'test_input,expected',

@ -12,11 +12,6 @@ from ansible.module_utils.common.text.converters import to_native
from ansible.module_utils.common.validation import check_missing_parameters from ansible.module_utils.common.validation import check_missing_parameters
@pytest.fixture
def arguments_terms():
return {"path": ""}
def test_check_missing_parameters(): def test_check_missing_parameters():
assert check_missing_parameters([], {}) == [] assert check_missing_parameters([], {}) == []

@ -8,24 +8,15 @@ import json
import pytest import pytest
from ansible.module_utils.six import string_types
from ansible.module_utils.common.text.converters import to_bytes from ansible.module_utils.common.text.converters import to_bytes
from ansible.module_utils.six.moves.collections_abc import MutableMapping
@pytest.fixture @pytest.fixture
def patch_ansible_module(request, mocker): def patch_ansible_module(request, mocker):
if isinstance(request.param, string_types): request.param = {'ANSIBLE_MODULE_ARGS': request.param}
args = request.param request.param['ANSIBLE_MODULE_ARGS']['_ansible_remote_tmp'] = '/tmp'
elif isinstance(request.param, MutableMapping): request.param['ANSIBLE_MODULE_ARGS']['_ansible_keep_remote_files'] = False
if 'ANSIBLE_MODULE_ARGS' not in request.param:
request.param = {'ANSIBLE_MODULE_ARGS': request.param} args = json.dumps(request.param)
if '_ansible_remote_tmp' not in request.param['ANSIBLE_MODULE_ARGS']:
request.param['ANSIBLE_MODULE_ARGS']['_ansible_remote_tmp'] = '/tmp'
if '_ansible_keep_remote_files' not in request.param['ANSIBLE_MODULE_ARGS']:
request.param['ANSIBLE_MODULE_ARGS']['_ansible_keep_remote_files'] = False
args = json.dumps(request.param)
else:
raise Exception('Malformed data to the patch_ansible_module pytest fixture')
mocker.patch('ansible.module_utils.basic._ANSIBLE_ARGS', to_bytes(args)) mocker.patch('ansible.module_utils.basic._ANSIBLE_ARGS', to_bytes(args))

@ -2,20 +2,13 @@ from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
import collections import collections
import sys
from units.compat.mock import Mock from units.compat.mock import Mock
from units.compat import unittest from units.compat import unittest
try: from ansible.modules.apt import (
from ansible.modules.apt import ( expand_pkgspec_from_fnmatches,
expand_pkgspec_from_fnmatches, )
)
except Exception:
# Need some more module_utils work (porting urls.py) before we can test
# modules. So don't error out in this case.
if sys.version_info[0] >= 3:
pass
class AptExpandPkgspecTestCase(unittest.TestCase): class AptExpandPkgspecTestCase(unittest.TestCase):

@ -43,12 +43,9 @@ class TestHostname(ModuleTestCase):
classname = "%sStrategy" % prefix classname = "%sStrategy" % prefix
cls = getattr(hostname, classname, None) cls = getattr(hostname, classname, None)
if cls is None: assert cls is not None
self.assertFalse(
cls is None, "%s is None, should be a subclass" % classname self.assertTrue(issubclass(cls, hostname.BaseStrategy))
)
else:
self.assertTrue(issubclass(cls, hostname.BaseStrategy))
class TestRedhatStrategy(ModuleTestCase): class TestRedhatStrategy(ModuleTestCase):

@ -8,20 +8,6 @@ import pytest
from ansible.modules.unarchive import ZipArchive, TgzArchive from ansible.modules.unarchive import ZipArchive, TgzArchive
class AnsibleModuleExit(Exception):
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
class ExitJson(AnsibleModuleExit):
pass
class FailJson(AnsibleModuleExit):
pass
@pytest.fixture @pytest.fixture
def fake_ansible_module(): def fake_ansible_module():
return FakeAnsibleModule() return FakeAnsibleModule()
@ -32,12 +18,6 @@ class FakeAnsibleModule:
self.params = {} self.params = {}
self.tmpdir = None self.tmpdir = None
def exit_json(self, *args, **kwargs):
raise ExitJson(*args, **kwargs)
def fail_json(self, *args, **kwargs):
raise FailJson(*args, **kwargs)
class TestCaseZipArchive: class TestCaseZipArchive:
@pytest.mark.parametrize( @pytest.mark.parametrize(

@ -10,10 +10,8 @@ from ansible.module_utils.common.text.converters import to_bytes
def set_module_args(args): def set_module_args(args):
if '_ansible_remote_tmp' not in args: args['_ansible_remote_tmp'] = '/tmp'
args['_ansible_remote_tmp'] = '/tmp' args['_ansible_keep_remote_files'] = False
if '_ansible_keep_remote_files' not in args:
args['_ansible_keep_remote_files'] = False
args = json.dumps({'ANSIBLE_MODULE_ARGS': args}) args = json.dumps({'ANSIBLE_MODULE_ARGS': args})
basic._ANSIBLE_ARGS = to_bytes(args) basic._ANSIBLE_ARGS = to_bytes(args)
@ -28,8 +26,6 @@ class AnsibleFailJson(Exception):
def exit_json(*args, **kwargs): def exit_json(*args, **kwargs):
if 'changed' not in kwargs:
kwargs['changed'] = False
raise AnsibleExitJson(kwargs) raise AnsibleExitJson(kwargs)

@ -109,7 +109,11 @@ class TestAnsibleJSONEncoder:
def __len__(self): def __len__(self):
return len(self.__dict__) return len(self.__dict__)
return M(request.param) mapping = M(request.param)
assert isinstance(len(mapping), int) # ensure coverage of __len__
return mapping
@pytest.fixture @pytest.fixture
def ansible_json_encoder(self): def ansible_json_encoder(self):

@ -26,7 +26,6 @@ from unittest.mock import patch, mock_open
from ansible.errors import AnsibleParserError, yaml_strings, AnsibleFileNotFound from ansible.errors import AnsibleParserError, yaml_strings, AnsibleFileNotFound
from ansible.parsing.vault import AnsibleVaultError from ansible.parsing.vault import AnsibleVaultError
from ansible.module_utils.common.text.converters import to_text from ansible.module_utils.common.text.converters import to_text
from ansible.module_utils.six import PY3
from units.mock.vault_helper import TextVaultSecret from units.mock.vault_helper import TextVaultSecret
from ansible.parsing.dataloader import DataLoader from ansible.parsing.dataloader import DataLoader
@ -229,11 +228,7 @@ class TestDataLoaderWithVault(unittest.TestCase):
3135306561356164310a343937653834643433343734653137383339323330626437313562306630 3135306561356164310a343937653834643433343734653137383339323330626437313562306630
3035 3035
""" """
if PY3:
builtins_name = 'builtins'
else:
builtins_name = '__builtin__'
with patch(builtins_name + '.open', mock_open(read_data=vaulted_data.encode('utf-8'))): with patch('builtins.open', mock_open(read_data=vaulted_data.encode('utf-8'))):
output = self._loader.load_from_file('dummy_vault.txt') output = self._loader.load_from_file('dummy_vault.txt')
self.assertEqual(output, dict(foo='bar')) self.assertEqual(output, dict(foo='bar'))

@ -21,7 +21,6 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
import binascii
import io import io
import os import os
import tempfile import tempfile
@ -606,9 +605,6 @@ class TestVaultLib(unittest.TestCase):
('test_id', text_secret)] ('test_id', text_secret)]
self.v = vault.VaultLib(self.vault_secrets) self.v = vault.VaultLib(self.vault_secrets)
def _vault_secrets(self, vault_id, secret):
return [(vault_id, secret)]
def _vault_secrets_from_password(self, vault_id, password): def _vault_secrets_from_password(self, vault_id, password):
return [(vault_id, TextVaultSecret(password))] return [(vault_id, TextVaultSecret(password))]
@ -779,43 +775,6 @@ class TestVaultLib(unittest.TestCase):
b_plaintext = self.v.decrypt(b_vaulttext) b_plaintext = self.v.decrypt(b_vaulttext)
self.assertEqual(b_plaintext, b_orig_plaintext, msg="decryption failed") self.assertEqual(b_plaintext, b_orig_plaintext, msg="decryption failed")
# FIXME This test isn't working quite yet.
@pytest.mark.skip(reason='This test is not ready yet')
def test_encrypt_decrypt_aes256_bad_hmac(self):
self.v.cipher_name = 'AES256'
# plaintext = "Setec Astronomy"
enc_data = '''$ANSIBLE_VAULT;1.1;AES256
33363965326261303234626463623963633531343539616138316433353830356566396130353436
3562643163366231316662386565383735653432386435610a306664636137376132643732393835
63383038383730306639353234326630666539346233376330303938323639306661313032396437
6233623062366136310a633866373936313238333730653739323461656662303864663666653563
3138'''
b_data = to_bytes(enc_data, errors='strict', encoding='utf-8')
b_data = self.v._split_header(b_data)
unhex_data = binascii.unhexlify(b_data)
lines = unhex_data.splitlines()
# line 0 is salt, line 1 is hmac, line 2+ is ciphertext
b_salt = lines[0]
b_hmac = lines[1]
b_ciphertext_data = b'\n'.join(lines[2:])
b_ciphertext = binascii.unhexlify(b_ciphertext_data)
# b_orig_ciphertext = b_ciphertext[:]
# now muck with the text
# b_munged_ciphertext = b_ciphertext[:10] + b'\x00' + b_ciphertext[11:]
# b_munged_ciphertext = b_ciphertext
# assert b_orig_ciphertext != b_munged_ciphertext
b_ciphertext_data = binascii.hexlify(b_ciphertext)
b_payload = b'\n'.join([b_salt, b_hmac, b_ciphertext_data])
# reformat
b_invalid_ciphertext = self.v._format_output(b_payload)
# assert we throw an error
self.v.decrypt(b_invalid_ciphertext)
def test_decrypt_and_get_vault_id(self): def test_decrypt_and_get_vault_id(self):
b_expected_plaintext = to_bytes('foo bar\n') b_expected_plaintext = to_bytes('foo bar\n')
vaulttext = '''$ANSIBLE_VAULT;1.2;AES256;ansible_devel vaulttext = '''$ANSIBLE_VAULT;1.2;AES256;ansible_devel

@ -33,7 +33,6 @@ from ansible import errors
from ansible.parsing import vault from ansible.parsing import vault
from ansible.parsing.vault import VaultLib, VaultEditor, match_encrypt_secret from ansible.parsing.vault import VaultLib, VaultEditor, match_encrypt_secret
from ansible.module_utils.six import PY3
from ansible.module_utils.common.text.converters import to_bytes, to_text from ansible.module_utils.common.text.converters import to_bytes, to_text
from units.mock.vault_helper import TextVaultSecret from units.mock.vault_helper import TextVaultSecret
@ -88,11 +87,10 @@ class TestVaultEditor(unittest.TestCase):
suffix = '_ansible_unit_test_%s_' % (self.__class__.__name__) suffix = '_ansible_unit_test_%s_' % (self.__class__.__name__)
return tempfile.mkdtemp(suffix=suffix) return tempfile.mkdtemp(suffix=suffix)
def _create_file(self, test_dir, name, content=None, symlink=False): def _create_file(self, test_dir, name, content, symlink=False):
file_path = os.path.join(test_dir, name) file_path = os.path.join(test_dir, name)
with open(file_path, 'wb') as opened_file: with open(file_path, 'wb') as opened_file:
if content: opened_file.write(content)
opened_file.write(content)
return file_path return file_path
def _vault_editor(self, vault_secrets=None): def _vault_editor(self, vault_secrets=None):
@ -117,11 +115,8 @@ class TestVaultEditor(unittest.TestCase):
def test_stdin_binary(self): def test_stdin_binary(self):
stdin_data = '\0' stdin_data = '\0'
if PY3: fake_stream = StringIO(stdin_data)
fake_stream = StringIO(stdin_data) fake_stream.buffer = BytesIO(to_bytes(stdin_data))
fake_stream.buffer = BytesIO(to_bytes(stdin_data))
else:
fake_stream = BytesIO(to_bytes(stdin_data))
with patch('sys.stdin', fake_stream): with patch('sys.stdin', fake_stream):
ve = self._vault_editor() ve = self._vault_editor()
@ -166,7 +161,7 @@ class TestVaultEditor(unittest.TestCase):
self.assertNotEqual(src_file_contents, b_ciphertext, self.assertNotEqual(src_file_contents, b_ciphertext,
'b_ciphertext should be encrypted and not equal to src_contents') 'b_ciphertext should be encrypted and not equal to src_contents')
def _faux_editor(self, editor_args, new_src_contents=None): def _faux_editor(self, editor_args, new_src_contents):
if editor_args[0] == 'shred': if editor_args[0] == 'shred':
return return
@ -174,8 +169,7 @@ class TestVaultEditor(unittest.TestCase):
# simulate the tmp file being editted # simulate the tmp file being editted
with open(tmp_path, 'wb') as tmp_file: with open(tmp_path, 'wb') as tmp_file:
if new_src_contents: tmp_file.write(new_src_contents)
tmp_file.write(new_src_contents)
def _faux_command(self, tmp_path): def _faux_command(self, tmp_path):
pass pass
@ -416,13 +410,6 @@ class TestVaultEditor(unittest.TestCase):
src_file_path = self._create_file(self._test_dir, 'src_file', content=src_contents) src_file_path = self._create_file(self._test_dir, 'src_file', content=src_contents)
new_src_contents = to_bytes("The info is different now.")
def faux_editor(editor_args):
self._faux_editor(editor_args, new_src_contents)
mock_sp_call.side_effect = faux_editor
ve = self._vault_editor() ve = self._vault_editor()
self.assertRaisesRegex(errors.AnsibleError, self.assertRaisesRegex(errors.AnsibleError,
'input is not vault encrypted data', 'input is not vault encrypted data',
@ -476,11 +463,7 @@ class TestVaultEditor(unittest.TestCase):
ve = self._vault_editor(self._secrets("ansible")) ve = self._vault_editor(self._secrets("ansible"))
# make sure the password functions for the cipher # make sure the password functions for the cipher
error_hit = False ve.decrypt_file(v11_file.name)
try:
ve.decrypt_file(v11_file.name)
except errors.AnsibleError:
error_hit = True
# verify decrypted content # verify decrypted content
with open(v11_file.name, "rb") as f: with open(v11_file.name, "rb") as f:
@ -488,7 +471,6 @@ class TestVaultEditor(unittest.TestCase):
os.unlink(v11_file.name) os.unlink(v11_file.name)
assert error_hit is False, "error decrypting 1.1 file"
assert fdata.strip() == "foo", "incorrect decryption of 1.1 file: %s" % fdata.strip() assert fdata.strip() == "foo", "incorrect decryption of 1.1 file: %s" % fdata.strip()
def test_real_path_dash(self): def test_real_path_dash(self):

@ -19,7 +19,6 @@ from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
import io import io
import yaml
from jinja2.exceptions import UndefinedError from jinja2.exceptions import UndefinedError
@ -27,7 +26,6 @@ from units.compat import unittest
from ansible.parsing import vault from ansible.parsing import vault
from ansible.parsing.yaml import dumper, objects from ansible.parsing.yaml import dumper, objects
from ansible.parsing.yaml.loader import AnsibleLoader from ansible.parsing.yaml.loader import AnsibleLoader
from ansible.module_utils.six import PY2
from ansible.template import AnsibleUndefined from ansible.template import AnsibleUndefined
from ansible.utils.unsafe_proxy import AnsibleUnsafeText, AnsibleUnsafeBytes from ansible.utils.unsafe_proxy import AnsibleUnsafeText, AnsibleUnsafeBytes
@ -78,20 +76,6 @@ class TestAnsibleDumper(unittest.TestCase, YamlTestUtils):
data_from_yaml = loader.get_single_data() data_from_yaml = loader.get_single_data()
result = b_text result = b_text
if PY2:
# https://pyyaml.org/wiki/PyYAMLDocumentation#string-conversion-python-2-only
# pyyaml on Python 2 can return either unicode or bytes when given byte strings.
# We normalize that to always return unicode on Python2 as that's right most of the
# time. However, this means byte strings can round trip through yaml on Python3 but
# not on Python2. To make this code work the same on Python2 and Python3 (we want
# the Python3 behaviour) we need to change the methods in Ansible to:
# (1) Let byte strings pass through yaml without being converted on Python2
# (2) Convert byte strings to text strings before being given to pyyaml (Without this,
# strings would end up as byte strings most of the time which would mostly be wrong)
# In practice, we mostly read bytes in from files and then pass that to pyyaml, for which
# the present behavior is correct.
# This is a workaround for the current behavior.
result = u'tr\xe9ma'
self.assertEqual(result, data_from_yaml) self.assertEqual(result, data_from_yaml)
@ -108,10 +92,7 @@ class TestAnsibleDumper(unittest.TestCase, YamlTestUtils):
self.assertEqual(u_text, data_from_yaml) self.assertEqual(u_text, data_from_yaml)
def test_vars_with_sources(self): def test_vars_with_sources(self):
try: self._dump_string(VarsWithSources(), dumper=self.dumper)
self._dump_string(VarsWithSources(), dumper=self.dumper)
except yaml.representer.RepresenterError:
self.fail("Dump VarsWithSources raised RepresenterError unexpectedly!")
def test_undefined(self): def test_undefined(self):
undefined_object = AnsibleUndefined() undefined_object = AnsibleUndefined()

@ -105,11 +105,6 @@ class TestAnsibleVaultEncryptedUnicode(unittest.TestCase, YamlTestUtils):
id_secret = vault.match_encrypt_secret(self.good_vault_secrets) id_secret = vault.match_encrypt_secret(self.good_vault_secrets)
return objects.AnsibleVaultEncryptedUnicode.from_plaintext(seq, vault=self.vault, secret=id_secret[1]) return objects.AnsibleVaultEncryptedUnicode.from_plaintext(seq, vault=self.vault, secret=id_secret[1])
def _from_ciphertext(self, ciphertext):
avu = objects.AnsibleVaultEncryptedUnicode(ciphertext)
avu.vault = self.vault
return avu
def test_empty_init(self): def test_empty_init(self):
self.assertRaises(TypeError, objects.AnsibleVaultEncryptedUnicode) self.assertRaises(TypeError, objects.AnsibleVaultEncryptedUnicode)

Loading…
Cancel
Save