diff --git a/v2/ansible/parsing/vault/__init__.py b/v2/ansible/parsing/vault/__init__.py index ddb92e4e7d3..80c48a3b69c 100644 --- a/v2/ansible/parsing/vault/__init__.py +++ b/v2/ansible/parsing/vault/__init__.py @@ -73,6 +73,7 @@ CRYPTO_UPGRADE = "ansible-vault requires a newer version of pycrypto than the on HEADER=u'$ANSIBLE_VAULT' CIPHER_WHITELIST=['AES', 'AES256'] + class VaultLib(object): def __init__(self, password): @@ -334,7 +335,7 @@ class VaultEditor(object): if os.path.isfile(filename): os.remove(filename) f = open(filename, "wb") - f.write(data) + f.write(to_bytes(data)) f.close() def shuffle_files(self, src, dest): @@ -410,7 +411,6 @@ class VaultAES(object): cipher = AES.new(key, AES.MODE_CBC, iv) full = to_bytes(b'Salted__' + salt) out_file.write(full) - print(repr(full)) finished = False while not finished: chunk = in_file.read(1024 * bs) @@ -422,10 +422,8 @@ class VaultAES(object): out_file.seek(0) enc_data = out_file.read() - #print(enc_data) tmp_data = hexlify(enc_data) - assert isinstance(tmp_data, binary_type) return tmp_data @@ -444,7 +442,6 @@ class VaultAES(object): bs = AES.block_size tmpsalt = in_file.read(bs) - print(repr(tmpsalt)) salt = tmpsalt[len('Salted__'):] key, iv = self.aes_derive_key_and_iv(password, salt, key_length, bs) cipher = AES.new(key, AES.MODE_CBC, iv) @@ -461,11 +458,15 @@ class VaultAES(object): chunk = chunk[:-padding_length] finished = True + out_file.write(chunk) + out_file.flush() # reset the stream pointer to the beginning out_file.seek(0) - new_data = to_unicode(out_file.read()) + out_data = out_file.read() + out_file.close() + new_data = to_unicode(out_data) # split out sha and verify decryption split_data = new_data.split("\n") @@ -476,7 +477,6 @@ class VaultAES(object): if this_sha != test_sha: raise errors.AnsibleError("Decryption failed") - #return out_file.read() return this_data diff --git a/v2/test/parsing/vault/test_vault_editor.py b/v2/test/parsing/vault/test_vault_editor.py index c788df54ae5..fd52ca2490e 100644 --- a/v2/test/parsing/vault/test_vault_editor.py +++ b/v2/test/parsing/vault/test_vault_editor.py @@ -21,6 +21,7 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type #!/usr/bin/env python +import sys import getpass import os import shutil @@ -32,6 +33,7 @@ from nose.plugins.skip import SkipTest from ansible.compat.tests import unittest from ansible.compat.tests.mock import patch +from ansible.utils.unicode import to_bytes, to_unicode from ansible import errors from ansible.parsing.vault import VaultLib @@ -88,12 +90,12 @@ class TestVaultEditor(unittest.TestCase): 'read_data', 'write_data', 'shuffle_files'] - for slot in slots: + for slot in slots: assert hasattr(v, slot), "VaultLib is missing the %s method" % slot @patch.object(VaultEditor, '_editor_shell_command') def test_create_file(self, mock_editor_shell_command): - + def sc_side_effect(filename): return ['touch', filename] mock_editor_shell_command.side_effect = sc_side_effect @@ -107,12 +109,16 @@ class TestVaultEditor(unittest.TestCase): self.assertTrue(os.path.exists(tmp_file.name)) def test_decrypt_1_0(self): - if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2: + """ + Skip testing decrypting 1.0 files if we don't have access to AES, KDF or + Counter, or we are running on python3 since VaultAES hasn't been backported. + """ + if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2 or sys.version > '3': raise SkipTest v10_file = tempfile.NamedTemporaryFile(delete=False) with v10_file as f: - f.write(v10_data) + f.write(to_bytes(v10_data)) ve = VaultEditor(None, "ansible", v10_file.name) @@ -125,13 +131,13 @@ class TestVaultEditor(unittest.TestCase): # verify decrypted content f = open(v10_file.name, "rb") - fdata = f.read() - f.close() + fdata = to_unicode(f.read()) + f.cloes() os.unlink(v10_file.name) - assert error_hit == False, "error decrypting 1.0 file" - assert fdata.strip() == "foo", "incorrect decryption of 1.0 file: %s" % fdata.strip() + assert error_hit == False, "error decrypting 1.0 file" + assert fdata.strip() == "foo", "incorrect decryption of 1.0 file: %s" % fdata.strip() def test_decrypt_1_1(self): @@ -140,7 +146,7 @@ class TestVaultEditor(unittest.TestCase): v11_file = tempfile.NamedTemporaryFile(delete=False) with v11_file as f: - f.write(v11_data) + f.write(to_bytes(v11_data)) ve = VaultEditor(None, "ansible", v11_file.name) @@ -153,28 +159,32 @@ class TestVaultEditor(unittest.TestCase): # verify decrypted content f = open(v11_file.name, "rb") - fdata = f.read() + fdata = to_unicode(f.read()) f.close() os.unlink(v11_file.name) - assert error_hit == False, "error decrypting 1.0 file" - assert fdata.strip() == "foo", "incorrect decryption of 1.0 file: %s" % fdata.strip() + assert error_hit == False, "error decrypting 1.0 file" + assert fdata.strip() == "foo", "incorrect decryption of 1.0 file: %s" % fdata.strip() def test_rekey_migration(self): - if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2: + """ + Skip testing rekeying files if we don't have access to AES, KDF or + Counter, or we are running on python3 since VaultAES hasn't been backported. + """ + if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2 or sys.version > '3': raise SkipTest v10_file = tempfile.NamedTemporaryFile(delete=False) with v10_file as f: - f.write(v10_data) + f.write(to_bytes(v10_data)) ve = VaultEditor(None, "ansible", v10_file.name) # make sure the password functions for the cipher error_hit = False - try: + try: ve.rekey_file('ansible2') except errors.AnsibleError as e: error_hit = True @@ -184,7 +194,7 @@ class TestVaultEditor(unittest.TestCase): fdata = f.read() f.close() - assert error_hit == False, "error rekeying 1.0 file to 1.1" + assert error_hit == False, "error rekeying 1.0 file to 1.1" # ensure filedata can be decrypted, is 1.1 and is AES256 vl = VaultLib("ansible2") @@ -198,7 +208,7 @@ class TestVaultEditor(unittest.TestCase): os.unlink(v10_file.name) assert vl.cipher_name == "AES256", "wrong cipher name set after rekey: %s" % vl.cipher_name - assert error_hit == False, "error decrypting migrated 1.0 file" + assert error_hit == False, "error decrypting migrated 1.0 file" assert dec_data.strip() == "foo", "incorrect decryption of rekeyed/migrated file: %s" % dec_data