# -*- coding: utf-8 -*- # (c) 2012-2014, Michael DeHaan # (c) 2016, Toshio Kuratomi # # This file is part of Ansible # # Ansible is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # Ansible is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with Ansible. If not, see . from __future__ import annotations import io import os import pathlib import tempfile import typing as t from binascii import hexlify import pytest import unittest from unittest.mock import patch, MagicMock from ansible import errors from ansible.module_utils.common.text.converters import to_bytes, to_text from ansible.module_utils._internal import _messages from ansible.module_utils._internal._datatag import AnsibleTagHelper from ansible.parsing import vault from ansible.parsing.vault import EncryptedString, VaultSecretsContext, AnsibleVaultError, VaultHelper from ansible._internal._templating._jinja_common import VaultExceptionMarker, TruncationMarker, Marker from ansible._internal._templating._engine import TemplateEngine, TemplateOptions from ansible._internal._templating._utils import TemplateContext from ansible._internal._datatag._tags import VaultedValue, Origin, TrustedAsTemplate from ansible.utils.collection_loader import _EncryptedStringProtocol from units.mock.loader import DictDataLoader from units.mock.vault_helper import TextVaultSecret, VaultTestHelper class TestUnhexlify(unittest.TestCase): def test(self): b_plain_data = b'some text to hexlify' b_data = hexlify(b_plain_data) res = vault._unhexlify(b_data) self.assertEqual(res, b_plain_data) def test_odd_length(self): b_data = b'123456789abcdefghijklmnopqrstuvwxyz' self.assertRaisesRegex(vault.AnsibleVaultFormatError, '.*Vault format unhexlify error.*', vault._unhexlify, b_data) def test_nonhex(self): b_data = b'6z36316566653264333665333637623064303639353237620a636366633565663263336335656532' self.assertRaisesRegex(vault.AnsibleVaultFormatError, '.*Vault format unhexlify error.*Non-hexadecimal digit found', vault._unhexlify, b_data) class TestParseVaulttext(unittest.TestCase): def test(self): vaulttext_envelope = u"""$ANSIBLE_VAULT;1.1;AES256 33363965326261303234626463623963633531343539616138316433353830356566396130353436 3562643163366231316662386565383735653432386435610a306664636137376132643732393835 63383038383730306639353234326630666539346233376330303938323639306661313032396437 6233623062366136310a633866373936313238333730653739323461656662303864663666653563 3138""" b_vaulttext_envelope = to_bytes(vaulttext_envelope, errors='strict', encoding='utf-8') b_vaulttext, b_version, cipher_name, vault_id = vault.parse_vaulttext_envelope(b_vaulttext_envelope) res = vault.parse_vaulttext(b_vaulttext) self.assertIsInstance(res[0], bytes) self.assertIsInstance(res[1], bytes) self.assertIsInstance(res[2], bytes) def test_non_hex(self): vaulttext_envelope = u"""$ANSIBLE_VAULT;1.1;AES256 3336396J326261303234626463623963633531343539616138316433353830356566396130353436 3562643163366231316662386565383735653432386435610a306664636137376132643732393835 63383038383730306639353234326630666539346233376330303938323639306661313032396437 6233623062366136310a633866373936313238333730653739323461656662303864663666653563 3138""" b_vaulttext_envelope = to_bytes(vaulttext_envelope, errors='strict', encoding='utf-8') b_vaulttext, b_version, cipher_name, vault_id = vault.parse_vaulttext_envelope(b_vaulttext_envelope) self.assertRaisesRegex(vault.AnsibleVaultFormatError, '.*Vault format unhexlify error.*Non-hexadecimal digit found', vault.parse_vaulttext, b_vaulttext_envelope) class TestVaultSecret(unittest.TestCase): def test(self): secret = vault.VaultSecret() secret.load() self.assertIsNone(secret._bytes) def test_bytes(self): some_text = u'私はガラスを食べられます。それは私を傷つけません。' _bytes = to_bytes(some_text) secret = vault.VaultSecret(_bytes) secret.load() self.assertEqual(secret.bytes, _bytes) class TestPromptVaultSecret(unittest.TestCase): def test_empty_prompt_formats(self): secret = vault.PromptVaultSecret(vault_id='test_id', prompt_formats=[]) secret.load() self.assertIsNone(secret._bytes) @patch('ansible.parsing.vault.display.prompt', return_value='the_password') def test_prompt_formats_none(self, mock_display_prompt): secret = vault.PromptVaultSecret(vault_id='test_id') secret.load() self.assertEqual(secret._bytes, b'the_password') @patch('ansible.parsing.vault.display.prompt', return_value='the_password') def test_custom_prompt(self, mock_display_prompt): secret = vault.PromptVaultSecret(vault_id='test_id', prompt_formats=['The cow flies at midnight: ']) secret.load() self.assertEqual(secret._bytes, b'the_password') @patch('ansible.parsing.vault.display.prompt', side_effect=EOFError) def test_prompt_eoferror(self, mock_display_prompt): secret = vault.PromptVaultSecret(vault_id='test_id') self.assertRaisesRegex(vault.AnsibleVaultError, 'EOFError.*test_id', secret.load) @patch('ansible.parsing.vault.display.prompt', side_effect=['first_password', 'second_password']) def test_prompt_passwords_dont_match(self, mock_display_prompt): secret = vault.PromptVaultSecret(vault_id='test_id', prompt_formats=['Vault password: ', 'Confirm Vault password: ']) self.assertRaisesRegex(errors.AnsibleError, 'Passwords do not match', secret.load) class TestFileVaultSecret(unittest.TestCase): def setUp(self): self.vault_password = "test-vault-password" text_secret = TextVaultSecret(self.vault_password) self.vault_secrets = [('foo', text_secret)] def test(self): secret = vault.FileVaultSecret() self.assertIsNone(secret._bytes) self.assertIsNone(secret._text) def test_repr_empty(self): secret = vault.FileVaultSecret() self.assertEqual(repr(secret), "FileVaultSecret()") def test_repr(self): tmp_file = tempfile.NamedTemporaryFile(delete=False) fake_loader = DictDataLoader({tmp_file.name: 'sdfadf'}) secret = vault.FileVaultSecret(loader=fake_loader, filename=tmp_file.name) filename = tmp_file.name tmp_file.close() self.assertEqual(repr(secret), "FileVaultSecret(filename='%s')" % filename) def test_empty_bytes(self): secret = vault.FileVaultSecret() self.assertIsNone(secret.bytes) def test_file(self): password = 'some password' tmp_file = tempfile.NamedTemporaryFile(delete=False) tmp_file.write(to_bytes(password)) tmp_file.close() fake_loader = DictDataLoader({tmp_file.name: 'sdfadf'}) secret = vault.FileVaultSecret(loader=fake_loader, filename=tmp_file.name) secret.load() os.unlink(tmp_file.name) self.assertEqual(secret.bytes, to_bytes(password)) def test_file_empty(self): tmp_file = tempfile.NamedTemporaryFile(delete=False) tmp_file.write(to_bytes('')) tmp_file.close() fake_loader = DictDataLoader({tmp_file.name: ''}) secret = vault.FileVaultSecret(loader=fake_loader, filename=tmp_file.name) self.assertRaisesRegex(vault.AnsibleVaultPasswordError, 'Invalid vault password was provided from file.*%s' % tmp_file.name, secret.load) os.unlink(tmp_file.name) def test_file_encrypted(self): vault_password = "test-vault-password" text_secret = TextVaultSecret(vault_password) vault_secrets = [('foo', text_secret)] password = 'some password' # 'some password' encrypted with 'test-ansible-password' password_file_content = """$ANSIBLE_VAULT;1.1;AES256 61393863643638653437313566313632306462383837303132346434616433313438353634613762 3334363431623364386164616163326537366333353663650a663634306232363432626162353665 39623061353266373631636331643761306665343731376633623439313138396330346237653930 6432643864346136640a653364386634666461306231353765636662316335613235383565306437 3737 """ tmp_file = tempfile.NamedTemporaryFile(delete=False) tmp_file.write(to_bytes(password_file_content)) tmp_file.close() fake_loader = DictDataLoader({tmp_file.name: 'sdfadf'}) fake_loader._vault.secrets = vault_secrets secret = vault.FileVaultSecret(loader=fake_loader, filename=tmp_file.name) secret.load() os.unlink(tmp_file.name) self.assertEqual(secret.bytes, to_bytes(password)) def test_file_not_a_directory(self): filename = '/dev/null/foobar' fake_loader = DictDataLoader({filename: 'sdfadf'}) secret = vault.FileVaultSecret(loader=fake_loader, filename=filename) self.assertRaisesRegex(errors.AnsibleError, '.*Could not read vault password file.*/dev/null/foobar.*Not a directory', secret.load) def test_file_not_found(self): tmp_file = tempfile.NamedTemporaryFile() filename = os.path.realpath(tmp_file.name) tmp_file.close() fake_loader = DictDataLoader({filename: 'sdfadf'}) secret = vault.FileVaultSecret(loader=fake_loader, filename=filename) self.assertRaisesRegex(errors.AnsibleError, '.*Could not read vault password file.*%s.*' % filename, secret.load) class TestScriptVaultSecret(unittest.TestCase): def test(self): secret = vault.ScriptVaultSecret() self.assertIsNone(secret._bytes) self.assertIsNone(secret._text) def _mock_popen(self, mock_popen, return_code=0, stdout=b'', stderr=b''): def communicate(): return stdout, stderr mock_popen.return_value = MagicMock(returncode=return_code) mock_popen_instance = mock_popen.return_value mock_popen_instance.communicate = communicate @patch('ansible.parsing.vault.subprocess.Popen') def test_read_file(self, mock_popen): self._mock_popen(mock_popen, stdout=b'some_password') secret = vault.ScriptVaultSecret() with patch.object(secret, 'loader') as mock_loader: mock_loader.is_executable = MagicMock(return_value=True) secret.load() @patch('ansible.parsing.vault.subprocess.Popen') def test_read_file_empty(self, mock_popen): self._mock_popen(mock_popen, stdout=b'') secret = vault.ScriptVaultSecret() with patch.object(secret, 'loader') as mock_loader: mock_loader.is_executable = MagicMock(return_value=True) self.assertRaisesRegex(vault.AnsibleVaultPasswordError, 'Invalid vault password was provided from script', secret.load) @patch('ansible.parsing.vault.subprocess.Popen') def test_read_file_os_error(self, mock_popen): self._mock_popen(mock_popen) mock_popen.side_effect = OSError('That is not an executable') secret = vault.ScriptVaultSecret() with patch.object(secret, 'loader') as mock_loader: mock_loader.is_executable = MagicMock(return_value=True) self.assertRaisesRegex(errors.AnsibleError, 'Problem running vault password script.*', secret.load) @patch('ansible.parsing.vault.subprocess.Popen') def test_read_file_not_executable(self, mock_popen): self._mock_popen(mock_popen) secret = vault.ScriptVaultSecret() with patch.object(secret, 'loader') as mock_loader: mock_loader.is_executable = MagicMock(return_value=False) self.assertRaisesRegex(vault.AnsibleVaultError, 'The vault password script .* was not executable', secret.load) @patch('ansible.parsing.vault.subprocess.Popen') def test_read_file_non_zero_return_code(self, mock_popen): stderr = b'That did not work for a random reason' rc = 37 self._mock_popen(mock_popen, return_code=rc, stderr=stderr) secret = vault.ScriptVaultSecret(filename='/dev/null/some_vault_secret') with patch.object(secret, 'loader') as mock_loader: mock_loader.is_executable = MagicMock(return_value=True) self.assertRaisesRegex(errors.AnsibleError, r'Vault password script.*returned non-zero \(%s\): %s' % (rc, stderr), secret.load) class TestScriptIsClient(unittest.TestCase): def test_randomname(self): filename = 'randomname' res = vault.script_is_client(filename) self.assertFalse(res) def test_something_dash_client(self): filename = 'something-client' res = vault.script_is_client(filename) self.assertTrue(res) def test_something_dash_client_somethingelse(self): filename = 'something-client-somethingelse' res = vault.script_is_client(filename) self.assertFalse(res) def test_something_dash_client_py(self): filename = 'something-client.py' res = vault.script_is_client(filename) self.assertTrue(res) def test_full_path_something_dash_client_py(self): filename = '/foo/bar/something-client.py' res = vault.script_is_client(filename) self.assertTrue(res) def test_full_path_something_dash_client(self): filename = '/foo/bar/something-client' res = vault.script_is_client(filename) self.assertTrue(res) def test_full_path_something_dash_client_in_dir(self): filename = '/foo/bar/something-client/but/not/filename' res = vault.script_is_client(filename) self.assertFalse(res) class TestGetFileVaultSecret(unittest.TestCase): def test_file(self): password = 'some password' tmp_file = tempfile.NamedTemporaryFile(delete=False) tmp_file.write(to_bytes(password)) tmp_file.close() fake_loader = DictDataLoader({tmp_file.name: 'sdfadf'}) secret = vault.get_file_vault_secret(filename=tmp_file.name, loader=fake_loader) secret.load() os.unlink(tmp_file.name) self.assertEqual(secret.bytes, to_bytes(password)) def test_file_not_a_directory(self): filename = '/dev/null/foobar' fake_loader = DictDataLoader({filename: 'sdfadf'}) self.assertRaisesRegex(errors.AnsibleError, '.*The vault password file %s was not found.*' % filename, vault.get_file_vault_secret, filename=filename, loader=fake_loader) def test_file_not_found(self): tmp_file = tempfile.NamedTemporaryFile() filename = os.path.realpath(tmp_file.name) tmp_file.close() fake_loader = DictDataLoader({filename: 'sdfadf'}) self.assertRaisesRegex(errors.AnsibleError, '.*The vault password file %s was not found.*' % filename, vault.get_file_vault_secret, filename=filename, loader=fake_loader) class TestVaultIsEncrypted(unittest.TestCase): def test_bytes_not_encrypted(self): b_data = b"foobar" self.assertFalse(vault.is_encrypted(b_data)) def test_bytes_encrypted(self): b_data = b"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible") self.assertTrue(vault.is_encrypted(b_data)) def test_text_not_encrypted(self): b_data = to_text(b"foobar") self.assertFalse(vault.is_encrypted(b_data)) def test_text_encrypted(self): b_data = to_text(b"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible")) self.assertTrue(vault.is_encrypted(b_data)) def test_invalid_text_not_ascii(self): data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % u"ァ ア ィ イ ゥ ウ ェ エ ォ オ カ ガ キ ギ ク グ ケ " self.assertFalse(vault.is_encrypted(data)) def test_invalid_bytes_not_ascii(self): data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % u"ァ ア ィ イ ゥ ウ ェ エ ォ オ カ ガ キ ギ ク グ ケ " b_data = to_bytes(data, encoding='utf-8') self.assertFalse(vault.is_encrypted(b_data)) class TestVaultIsEncryptedFile(unittest.TestCase): def test_binary_file_handle_not_encrypted(self): b_data = b"foobar" b_data_fo = io.BytesIO(b_data) self.assertFalse(vault.is_encrypted_file(b_data_fo)) def test_text_file_handle_not_encrypted(self): data = u"foobar" data_fo = io.StringIO(data) self.assertFalse(vault.is_encrypted_file(data_fo)) def test_binary_file_handle_encrypted(self): b_data = b"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible") b_data_fo = io.BytesIO(b_data) self.assertTrue(vault.is_encrypted_file(b_data_fo)) def test_text_file_handle_encrypted(self): data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % to_text(hexlify(b"ansible")) data_fo = io.StringIO(data) self.assertTrue(vault.is_encrypted_file(data_fo)) def test_binary_file_handle_invalid(self): data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % u"ァ ア ィ イ ゥ ウ ェ エ ォ オ カ ガ キ ギ ク グ ケ " b_data = to_bytes(data) b_data_fo = io.BytesIO(b_data) self.assertFalse(vault.is_encrypted_file(b_data_fo, count=-1)) def test_text_file_handle_invalid(self): data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % u"ァ ア ィ イ ゥ ウ ェ エ ォ オ カ ガ キ ギ ク グ ケ " data_fo = io.StringIO(data) self.assertFalse(vault.is_encrypted_file(data_fo, count=-1)) def test_file_already_read_from_finds_header(self): b_data = b"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible\ntesting\nfile pos") b_data_fo = io.BytesIO(b_data) b_data_fo.read(42) # Arbitrary number self.assertTrue(vault.is_encrypted_file(b_data_fo)) def test_file_already_read_from_saves_file_pos(self): b_data = b"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible\ntesting\nfile pos") b_data_fo = io.BytesIO(b_data) b_data_fo.read(69) # Arbitrary number vault.is_encrypted_file(b_data_fo) self.assertEqual(b_data_fo.tell(), 69) def test_file_with_offset(self): b_data = b"JUNK$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible\ntesting\nfile pos") b_data_fo = io.BytesIO(b_data) self.assertTrue(vault.is_encrypted_file(b_data_fo, start_pos=4)) def test_file_with_count(self): b_data = b"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible\ntesting\nfile pos") vault_length = len(b_data) b_data = b_data + u'ァ ア'.encode('utf-8') b_data_fo = io.BytesIO(b_data) self.assertTrue(vault.is_encrypted_file(b_data_fo, count=vault_length)) def test_file_with_offset_and_count(self): b_data = b"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible\ntesting\nfile pos") vault_length = len(b_data) b_data = b'JUNK' + b_data + u'ァ ア'.encode('utf-8') b_data_fo = io.BytesIO(b_data) self.assertTrue(vault.is_encrypted_file(b_data_fo, start_pos=4, count=vault_length)) @pytest.mark.skipif(not vault.HAS_CRYPTOGRAPHY, reason="Skipping cryptography tests because cryptography is not installed") class TestVaultCipherAes256(unittest.TestCase): def setUp(self): self.vault_cipher = vault.VaultAES256() def test(self): self.assertIsInstance(self.vault_cipher, vault.VaultAES256) # TODO: tag these as slow tests def test_create_key_cryptography(self): b_password = b'hunter42' b_salt = os.urandom(32) b_key_cryptography = self.vault_cipher._create_key_cryptography(b_password, b_salt, key_length=32, iv_length=16) self.assertIsInstance(b_key_cryptography, bytes) def test_create_key_known_cryptography(self): b_password = b'hunter42' # A fixed salt b_salt = b'q' * 32 # q is the most random letter. b_key_1 = self.vault_cipher._create_key_cryptography(b_password, b_salt, key_length=32, iv_length=16) self.assertIsInstance(b_key_1, bytes) # verify we get the same answer # we could potentially run a few iterations of this and time it to see if it's roughly constant time # and or that it exceeds some minimal time, but that would likely cause unreliable fails, esp in CI b_key_2 = self.vault_cipher._create_key_cryptography(b_password, b_salt, key_length=32, iv_length=16) self.assertIsInstance(b_key_2, bytes) self.assertEqual(b_key_1, b_key_2) def test_is_equal_is_equal(self): self.assertTrue(self.vault_cipher._is_equal(b'abcdefghijklmnopqrstuvwxyz', b'abcdefghijklmnopqrstuvwxyz')) def test_is_equal_unequal_length(self): self.assertFalse(self.vault_cipher._is_equal(b'abcdefghijklmnopqrstuvwxyz', b'abcdefghijklmnopqrstuvwx and sometimes y')) def test_is_equal_not_equal(self): self.assertFalse(self.vault_cipher._is_equal(b'abcdefghijklmnopqrstuvwxyz', b'AbcdefghijKlmnopQrstuvwxZ')) def test_is_equal_empty(self): self.assertTrue(self.vault_cipher._is_equal(b'', b'')) def test_is_equal_non_ascii_equal(self): utf8_data = to_bytes(u'私はガラスを食べられます。それは私を傷つけません。') self.assertTrue(self.vault_cipher._is_equal(utf8_data, utf8_data)) def test_is_equal_non_ascii_unequal(self): utf8_data = to_bytes(u'私はガラスを食べられます。それは私を傷つけません。') utf8_data2 = to_bytes(u'Pot să mănânc sticlă și ea nu mă rănește.') # Test for the len optimization path self.assertFalse(self.vault_cipher._is_equal(utf8_data, utf8_data2)) # Test for the slower, char by char comparison path self.assertFalse(self.vault_cipher._is_equal(utf8_data, utf8_data[:-1] + b'P')) def test_is_equal_non_bytes(self): """ Anything not a byte string should raise a TypeError """ self.assertRaises(TypeError, self.vault_cipher._is_equal, u"One fish", b"two fish") self.assertRaises(TypeError, self.vault_cipher._is_equal, b"One fish", u"two fish") self.assertRaises(TypeError, self.vault_cipher._is_equal, 1, b"red fish") self.assertRaises(TypeError, self.vault_cipher._is_equal, b"blue fish", 2) class TestMatchSecrets(unittest.TestCase): def test_empty_tuple(self): secrets = [tuple()] vault_ids = ['vault_id_1'] self.assertRaises(ValueError, vault.match_secrets, secrets, vault_ids) def test_empty_secrets(self): matches = vault.match_secrets([], ['vault_id_1']) self.assertEqual(matches, []) def test_single_match(self): secret = TextVaultSecret('password') matches = vault.match_secrets([('default', secret)], ['default']) self.assertEqual(matches, [('default', secret)]) def test_no_matches(self): secret = TextVaultSecret('password') matches = vault.match_secrets([('default', secret)], ['not_default']) self.assertEqual(matches, []) def test_multiple_matches(self): secrets = [('vault_id1', TextVaultSecret('password1')), ('vault_id2', TextVaultSecret('password2')), ('vault_id1', TextVaultSecret('password3')), ('vault_id4', TextVaultSecret('password4'))] vault_ids = ['vault_id1', 'vault_id4'] matches = vault.match_secrets(secrets, vault_ids) self.assertEqual(len(matches), 3) expected = [('vault_id1', TextVaultSecret('password1')), ('vault_id1', TextVaultSecret('password3')), ('vault_id4', TextVaultSecret('password4'))] self.assertEqual([x for x, y in matches], [a for a, b in expected]) @pytest.mark.skipif(not vault.HAS_CRYPTOGRAPHY, reason="Skipping cryptography tests because cryptography is not installed") class TestVaultLib(unittest.TestCase): def setUp(self): self.vault_password = "test-vault-password" text_secret = TextVaultSecret(self.vault_password) self.vault_secrets = [('default', text_secret), ('test_id', text_secret)] self.v = vault.VaultLib(self.vault_secrets) def _vault_secrets_from_password(self, vault_id, password): return [(vault_id, TextVaultSecret(password))] def test_encrypt(self): plaintext = u'Some text to encrypt in a café' b_vaulttext = self.v.encrypt(plaintext) self.assertIsInstance(b_vaulttext, bytes) b_header = b'$ANSIBLE_VAULT;1.1;AES256\n' self.assertEqual(b_vaulttext[:len(b_header)], b_header) def test_encrypt_vault_id(self): plaintext = u'Some text to encrypt in a café' b_vaulttext = self.v.encrypt(plaintext, vault_id='test_id') self.assertIsInstance(b_vaulttext, bytes) b_header = b'$ANSIBLE_VAULT;1.2;AES256;test_id\n' self.assertEqual(b_vaulttext[:len(b_header)], b_header) def test_encrypt_bytes(self): plaintext = to_bytes(u'Some text to encrypt in a café') b_vaulttext = self.v.encrypt(plaintext) self.assertIsInstance(b_vaulttext, bytes) b_header = b'$ANSIBLE_VAULT;1.1;AES256\n' self.assertEqual(b_vaulttext[:len(b_header)], b_header) def test_encrypt_no_secret_empty_secrets(self): vault_secrets = [] v = vault.VaultLib(vault_secrets) plaintext = u'Some text to encrypt in a café' self.assertRaisesRegex(vault.AnsibleVaultError, '.*A vault password must be specified to encrypt data.*', v.encrypt, plaintext) def test_format_vaulttext_envelope(self): cipher_name = "TEST" b_ciphertext = b"ansible" b_vaulttext = vault.format_vaulttext_envelope(b_ciphertext, cipher_name, version=self.v.b_version, vault_id='default') b_lines = b_vaulttext.split(b'\n') self.assertGreater(len(b_lines), 1, msg="failed to properly add header") b_header = b_lines[0] # self.assertTrue(b_header.endswith(b';TEST'), msg="header does not end with cipher name") b_header_parts = b_header.split(b';') self.assertEqual(len(b_header_parts), 4, msg="header has the wrong number of parts") self.assertEqual(b_header_parts[0], b'$ANSIBLE_VAULT', msg="header does not start with $ANSIBLE_VAULT") self.assertEqual(b_header_parts[1], self.v.b_version, msg="header version is incorrect") self.assertEqual(b_header_parts[2], b'TEST', msg="header does not end with cipher name") # And just to verify, lets parse the results and compare b_ciphertext2, b_version2, cipher_name2, vault_id2 = \ vault.parse_vaulttext_envelope(b_vaulttext) self.assertEqual(b_ciphertext, b_ciphertext2) self.assertEqual(self.v.b_version, b_version2) self.assertEqual(cipher_name, cipher_name2) self.assertEqual('default', vault_id2) def test_parse_vaulttext_envelope(self): b_vaulttext = b"$ANSIBLE_VAULT;9.9;TEST\nansible" b_ciphertext, b_version, cipher_name, vault_id = vault.parse_vaulttext_envelope(b_vaulttext) b_lines = b_ciphertext.split(b'\n') self.assertEqual(b_lines[0], b"ansible", msg="Payload was not properly split from the header") self.assertEqual(cipher_name, u'TEST', msg="cipher name was not properly set") self.assertEqual(b_version, b"9.9", msg="version was not properly set") def test_parse_vaulttext_envelope_crlf(self): b_vaulttext = b"$ANSIBLE_VAULT;9.9;TEST\r\nansible" b_ciphertext, b_version, cipher_name, vault_id = vault.parse_vaulttext_envelope(b_vaulttext) b_lines = b_ciphertext.split(b'\n') self.assertEqual(b_lines[0], b"ansible", msg="Payload was not properly split from the header") self.assertEqual(cipher_name, u'TEST', msg="cipher name was not properly set") self.assertEqual(b_version, b"9.9", msg="version was not properly set") def test_encrypt_decrypt_aes256(self): self.v.cipher_name = u'AES256' plaintext = u"foobar" b_vaulttext = self.v.encrypt(plaintext) b_plaintext = self.v.decrypt(b_vaulttext) self.assertNotEqual(b_vaulttext, b"foobar", msg="encryption failed") self.assertEqual(b_plaintext, b"foobar", msg="decryption failed") def test_encrypt_decrypt_aes256_none_secrets(self): vault_secrets = self._vault_secrets_from_password('default', 'ansible') v = vault.VaultLib(vault_secrets) plaintext = u"foobar" b_vaulttext = v.encrypt(plaintext) # VaultLib will default to empty {} if secrets is None v_none = vault.VaultLib(None) # so set secrets None explicitly v_none.secrets = None self.assertRaisesRegex(vault.AnsibleVaultError, '.*A vault password must be specified to decrypt data.*', v_none.decrypt, b_vaulttext) def test_encrypt_decrypt_aes256_empty_secrets(self): vault_secrets = self._vault_secrets_from_password('default', 'ansible') v = vault.VaultLib(vault_secrets) plaintext = u"foobar" b_vaulttext = v.encrypt(plaintext) vault_secrets_empty = [] v_none = vault.VaultLib(vault_secrets_empty) self.assertRaisesRegex(vault.AnsibleVaultError, '.*Attempting to decrypt but no vault secrets found.*', v_none.decrypt, b_vaulttext) def test_encrypt_decrypt_aes256_multiple_secrets_all_wrong(self): plaintext = u'Some text to encrypt in a café' b_vaulttext = self.v.encrypt(plaintext) vault_secrets = [('default', TextVaultSecret('another-wrong-password')), ('wrong-password', TextVaultSecret('wrong-password'))] v_multi = vault.VaultLib(vault_secrets) with self.assertRaisesRegex(errors.AnsibleError, '.*Decryption failed.*'): v_multi.decrypt(b_vaulttext) def test_encrypt_decrypt_aes256_multiple_secrets_one_valid(self): plaintext = u'Some text to encrypt in a café' b_vaulttext = self.v.encrypt(plaintext) correct_secret = TextVaultSecret(self.vault_password) wrong_secret = TextVaultSecret('wrong-password') vault_secrets = [('default', wrong_secret), ('corect_secret', correct_secret), ('wrong_secret', wrong_secret)] v_multi = vault.VaultLib(vault_secrets) b_plaintext = v_multi.decrypt(b_vaulttext) self.assertNotEqual(b_vaulttext, to_bytes(plaintext), msg="encryption failed") self.assertEqual(b_plaintext, to_bytes(plaintext), msg="decryption failed") def test_encrypt_decrypt_aes256_existing_vault(self): self.v.cipher_name = u'AES256' b_orig_plaintext = b"Setec Astronomy" vaulttext = u"""$ANSIBLE_VAULT;1.1;AES256 33363965326261303234626463623963633531343539616138316433353830356566396130353436 3562643163366231316662386565383735653432386435610a306664636137376132643732393835 63383038383730306639353234326630666539346233376330303938323639306661313032396437 6233623062366136310a633866373936313238333730653739323461656662303864663666653563 3138""" b_plaintext = self.v.decrypt(vaulttext) self.assertEqual(b_plaintext, b_plaintext, msg="decryption failed") b_vaulttext = to_bytes(vaulttext, encoding='ascii', errors='strict') b_plaintext = self.v.decrypt(b_vaulttext) self.assertEqual(b_plaintext, b_orig_plaintext, msg="decryption failed") def test_decrypt_and_get_vault_id(self): b_expected_plaintext = to_bytes('foo bar\n') vaulttext = """$ANSIBLE_VAULT;1.2;AES256;ansible_devel 65616435333934613466373335363332373764363365633035303466643439313864663837393234 3330656363343637313962633731333237313636633534630a386264363438363362326132363239 39363166646664346264383934393935653933316263333838386362633534326664646166663736 6462303664383765650a356637643633366663643566353036303162386237336233393065393164 6264""" vault_secrets = self._vault_secrets_from_password('ansible_devel', 'ansible') v = vault.VaultLib(vault_secrets) b_vaulttext = to_bytes(vaulttext) b_plaintext, vault_id_used, vault_secret_used = v.decrypt_and_get_vault_id(b_vaulttext) self.assertEqual(b_expected_plaintext, b_plaintext) self.assertEqual(vault_id_used, 'ansible_devel') self.assertEqual(vault_secret_used.text, 'ansible') def test_decrypt_non_default_1_2(self): b_expected_plaintext = to_bytes('foo bar\n') vaulttext = """$ANSIBLE_VAULT;1.2;AES256;ansible_devel 65616435333934613466373335363332373764363365633035303466643439313864663837393234 3330656363343637313962633731333237313636633534630a386264363438363362326132363239 39363166646664346264383934393935653933316263333838386362633534326664646166663736 6462303664383765650a356637643633366663643566353036303162386237336233393065393164 6264""" vault_secrets = self._vault_secrets_from_password('default', 'ansible') v = vault.VaultLib(vault_secrets) b_vaulttext = to_bytes(vaulttext) b_plaintext = v.decrypt(b_vaulttext) self.assertEqual(b_expected_plaintext, b_plaintext) b_ciphertext, b_version, cipher_name, vault_id = vault.parse_vaulttext_envelope(b_vaulttext) self.assertEqual('ansible_devel', vault_id) self.assertEqual(b'1.2', b_version) def test_decrypt_decrypted(self): plaintext = u"ansible" self.assertRaises(errors.AnsibleError, self.v.decrypt, plaintext) b_plaintext = b"ansible" self.assertRaises(errors.AnsibleError, self.v.decrypt, b_plaintext) def test_cipher_not_set(self): plaintext = u"ansible" self.v.encrypt(plaintext) self.assertEqual(self.v.cipher_name, "AES256") def test_verify_encrypted_string_methods(): """Verify all relevant methods on `str` are implemented on `EncryptedString`.""" str_methods_to_not_implement = { '__dir__', '__getattribute__', # generated by Python, identical to the one on `str` in Python 3.13+, but different on earlier versions '__getnewargs__', '__getstate__', '__reduce__', '__reduce_ex__', } str_methods = set(dir('')) encrypted_string_methods = set(name for name in dir(EncryptedString) if name in str_methods and getattr(EncryptedString, name) is not getattr(str, name)) missing_methods = str_methods - encrypted_string_methods - str_methods_to_not_implement assert not missing_methods def test_vault_secrets_context_required(_zap_vault_secrets_context) -> None: with pytest.raises(ReferenceError, match=f"A required {VaultSecretsContext.__name__} context is not active."): VaultSecretsContext.current() def test_vault_secrets_context_optional(_zap_vault_secrets_context) -> None: assert VaultSecretsContext.current(optional=True) is None def test_vault_secrets_context_current(_zap_vault_secrets_context) -> None: ctx = VaultSecretsContext([]) VaultSecretsContext.initialize(ctx) assert VaultSecretsContext.current() is ctx def test_vault_secrets_context_already_initialized(_zap_vault_secrets_context) -> None: VaultSecretsContext.initialize(VaultSecretsContext([])) with pytest.raises(RuntimeError, match=f"The {VaultSecretsContext.__name__} context is already initialized."): VaultSecretsContext.initialize(VaultSecretsContext([])) def test_encrypted_string_unmanaged_access(_vault_secrets_context: VaultTestHelper) -> None: """ Validates that unmanaged access to an `EncryptedString`: * properly decrypts and caches the value when secrets are available * propagates Origin and VaultedValue tags """ plaintext = 'i am plaintext' encrypted_string = _vault_secrets_context.make_encrypted_string(plaintext) origin = Origin.get_required_tag(encrypted_string) vaulted_value_tag = VaultedValue(ciphertext=VaultHelper.get_ciphertext(encrypted_string, with_tags=False)) res1 = str(encrypted_string) res2 = str(encrypted_string) assert res1 == res2 == plaintext assert res1 is res2 # ensure the result is cached assert Origin.get_required_tag(res1) is origin assert VaultedValue.get_required_tag(res1).ciphertext == vaulted_value_tag.ciphertext def test_encrypted_string_unmanaged_access_fail(_vault_secrets_context: VaultTestHelper) -> None: """Validates that unmanaged access to an `EncryptedString` fails with AnsibleVaultError when secrets are not available.""" encrypted_string = _vault_secrets_context.make_encrypted_string("i am plaintext") VaultSecretsContext.current().secrets = [] with pytest.raises(AnsibleVaultError): str(encrypted_string) @pytest.mark.parametrize( "value, conversion_func", ( ("fourty-two", str), (42, int), (42.42, float), )) def test_encrypted_string_conversion_methods(value: t.Any, conversion_func: t.Callable, _vault_secrets_context: VaultTestHelper): """Ensure that `EncryptedString` dunder conversion methods decrypt and pass through correctly.""" encrypted_string = _vault_secrets_context.make_encrypted_string(str(value)) converted = conversion_func(encrypted_string) assert converted == value def test_radd(_vault_secrets_context: VaultTestHelper) -> None: """Ensure that the __radd__ dunder method decrypts and passes through.""" assert "plain string " + _vault_secrets_context.make_encrypted_string("secret string") == "plain string secret string" def make_marker(marker_type: type[Marker], *args, **kwargs): """Utility function to create marker values under a TemplateContext.""" with TemplateContext(template_value="blah", templar=TemplateEngine(), options=TemplateOptions.DEFAULT): return marker_type(*args, **kwargs) origin = Origin(path="/test") @pytest.mark.parametrize("value, expected_ciphertext", ( (origin.tag(EncryptedString(ciphertext="ciphertext")), "ciphertext"), (origin.tag(VaultedValue(ciphertext="ciphertext").tag("something")), "ciphertext"), (make_marker(VaultExceptionMarker, ciphertext=origin.tag("ciphertext"), event=_messages.Event(msg="")), "ciphertext"), (make_marker(TruncationMarker), None), ("not vaulted", None), )) def test_vaulthelper_get_ciphertext(value: t.Any, expected_ciphertext: str | None) -> None: """Validate `get_ciphertext` helper responses and tag preservation behavior.""" expected_tags = {origin} if expected_ciphertext is not None else set() tagged_ciphertext = VaultHelper.get_ciphertext(value, with_tags=True) untagged_ciphertext = VaultHelper.get_ciphertext(value, with_tags=False) assert untagged_ciphertext == expected_ciphertext assert tagged_ciphertext == expected_ciphertext assert AnsibleTagHelper.tags(tagged_ciphertext) == expected_tags assert not AnsibleTagHelper.tags(untagged_ciphertext) @pytest.mark.parametrize("expression, expected_expression", ( ("os.path.join(ed, efn)", "str(temp_file)"), ("os.path.exists(ef)", "True"), ("os.path.dirname(ef)", "str(temp_dir)"), ("os.path.isdir(ed)", "True"), ("os.path.basename(ef)", "temp_file.name"), ("os.listdir(ed)", "[temp_file.name]"), ("open(ef).read()", "'Ansible'"), )) def test_encrypted_string_path_fspath(_vault_secrets_context: VaultTestHelper, expression: str, expected_expression: str) -> None: """Ensure that `EncryptedString` works with `PathLike` duck-typing consumers.""" with tempfile.TemporaryDirectory() as temp_dir_path: temp_dir = pathlib.Path(temp_dir_path) temp_file = (pathlib.Path(temp_dir) / 'temp_file') temp_file.write_text('Ansible') expression_locals = dict( temp_file=temp_file, temp_dir=temp_dir, ed=_vault_secrets_context.make_encrypted_string(str(temp_dir)), edn=_vault_secrets_context.make_encrypted_string(temp_dir.name), ef=_vault_secrets_context.make_encrypted_string(str(temp_file)), efn=_vault_secrets_context.make_encrypted_string(temp_file.name), ) expected = eval(expected_expression, globals(), expression_locals) result = eval(expression, globals(), expression_locals) assert result == expected def test_protocol_conformance(_vault_secrets_context: VaultTestHelper) -> None: """Verify that the `_EncryptedStringProtocol` defined by the collection loader is implemented.""" assert isinstance(_vault_secrets_context.make_encrypted_string("hey"), _EncryptedStringProtocol) def test_encrypted_string_cannot_be_trusted() -> None: """Verify that `EncryptedString` cannot be trusted for templating.""" es = EncryptedString(ciphertext='x') still_not_trusted = TrustedAsTemplate().tag(es) assert not TrustedAsTemplate.is_tagged_on(still_not_trusted)