Ansible vault: a framework for encrypting any playbook or var file.

pull/6058/head
James Tanner 10 years ago
parent 30611eaac5
commit 427b8dc78d

@ -62,6 +62,8 @@ def main(args):
check_opts=True,
diff_opts=True
)
#parser.add_option('--vault-password', dest="vault_password",
# help="password for vault encrypted files")
parser.add_option('-e', '--extra-vars', dest="extra_vars", action="append",
help="set additional variables as key=value or YAML/JSON", default=[])
parser.add_option('-t', '--tags', dest='tags', default='all',
@ -100,12 +102,13 @@ def main(args):
su_pass = None
if not options.listhosts and not options.syntax and not options.listtasks:
options.ask_pass = options.ask_pass or C.DEFAULT_ASK_PASS
options.ask_vault_pass = options.ask_vault_pass or C.DEFAULT_ASK_VAULT_PASS
# Never ask for an SSH password when we run with local connection
if options.connection == "local":
options.ask_pass = False
options.ask_sudo_pass = options.ask_sudo_pass or C.DEFAULT_ASK_SUDO_PASS
options.ask_su_pass = options.ask_su_pass or C.DEFAULT_ASK_SU_PASS
(sshpass, sudopass, su_pass) = utils.ask_passwords(ask_pass=options.ask_pass, ask_sudo_pass=options.ask_sudo_pass, ask_su_pass=options.ask_su_pass)
(sshpass, sudopass, su_pass, vault_pass) = utils.ask_passwords(ask_pass=options.ask_pass, ask_sudo_pass=options.ask_sudo_pass, ask_su_pass=options.ask_su_pass, ask_vault_pass=options.ask_vault_pass)
options.sudo_user = options.sudo_user or C.DEFAULT_SUDO_USER
options.su_user = options.su_user or C.DEFAULT_SU_USER
@ -170,7 +173,8 @@ def main(args):
diff=options.diff,
su=options.su,
su_pass=su_pass,
su_user=options.su_user
su_user=options.su_user,
vault_password=vault_pass
)
if options.listhosts or options.listtasks or options.syntax:

@ -0,0 +1,187 @@
#!/usr/bin/env python
# (c) 2014, James Tanner <tanner.jc@gmail.com>
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
#
# ansible-pull is a script that runs ansible in local mode
# after checking out a playbooks directory from source repo. There is an
# example playbook to bootstrap this script in the examples/ dir which
# installs ansible and sets it up to run on cron.
import sys
import traceback
from ansible import utils
from ansible import errors
from ansible.utils.vault import *
from ansible.utils.vault import Vault
from optparse import OptionParser
#-------------------------------------------------------------------------------------
# Utility functions for parsing actions/options
#-------------------------------------------------------------------------------------
VALID_ACTIONS = ("create", "decrypt", "edit", "encrypt", "rekey")
def build_option_parser(action):
"""
Builds an option parser object based on the action
the user wants to execute.
"""
usage = "usage: %%prog [%s] [--help] [options] file_name" % "|".join(VALID_ACTIONS)
epilog = "\nSee '%s <command> --help' for more information on a specific command.\n\n" % os.path.basename(sys.argv[0])
OptionParser.format_epilog = lambda self, formatter: self.epilog
parser = OptionParser(usage=usage, epilog=epilog)
if not action:
parser.print_help()
sys.exit()
# options for all actions
#parser.add_option('-p', '--password', help="encryption key")
#parser.add_option('-c', '--cipher', dest='cipher', default="AES", help="cipher to use")
parser.add_option('-d', '--debug', dest='debug', action="store_true", help="debug")
# options specific to actions
if action == "create":
parser.set_usage("usage: %prog create [options] file_name")
elif action == "decrypt":
parser.set_usage("usage: %prog decrypt [options] file_name")
elif action == "edit":
parser.set_usage("usage: %prog edit [options] file_name")
elif action == "encrypt":
parser.set_usage("usage: %prog encrypt [options] file_name")
elif action == "rekey":
parser.set_usage("usage: %prog rekey [options] file_name")
# done, return the parser
return parser
def get_action(args):
"""
Get the action the user wants to execute from the
sys argv list.
"""
for i in range(0,len(args)):
arg = args[i]
if arg in VALID_ACTIONS:
del args[i]
return arg
return None
def get_opt(options, k, defval=""):
"""
Returns an option from an Optparse values instance.
"""
try:
data = getattr(options, k)
except:
return defval
if k == "roles_path":
if os.pathsep in data:
data = data.split(os.pathsep)[0]
return data
#-------------------------------------------------------------------------------------
# Command functions
#-------------------------------------------------------------------------------------
def _get_vault(filename, options, password):
this_vault = Vault()
this_vault.filename = filename
this_vault.vault_password = password
this_vault.password = password
return this_vault
def execute_create(args, options, parser):
if len(args) > 1:
raise errors.AnsibleError("create does not accept more than one filename")
password, new_password = utils.ask_vaultpasswords(ask_vault_pass=True, confirm_vault=True)
this_vault = _get_vault(args[0], options, password)
if not hasattr(options, 'cipher'):
this_vault.cipher = 'AES'
this_vault.create()
def execute_decrypt(args, options, parser):
password, new_password = utils.ask_vaultpasswords(ask_vault_pass=True)
for f in args:
this_vault = _get_vault(f, options, password)
this_vault.decrypt()
print "Decryption successful"
def execute_edit(args, options, parser):
if len(args) > 1:
raise errors.AnsibleError("create does not accept more than one filename")
password, new_password = utils.ask_vaultpasswords(ask_vault_pass=True)
for f in args:
this_vault = _get_vault(f, options, password)
this_vault.edit()
def execute_encrypt(args, options, parser):
password, new_password = utils.ask_vaultpasswords(ask_vault_pass=True, confirm_vault=True)
for f in args:
this_vault = _get_vault(f, options, password)
if not hasattr(options, 'cipher'):
this_vault.cipher = 'AES'
this_vault.encrypt()
print "Encryption successful"
def execute_rekey(args, options, parser):
password, new_password = utils.ask_vaultpasswords(ask_vault_pass=True, ask_new_vault_pass=True, confirm_new=True)
for f in args:
this_vault = _get_vault(f, options, password)
this_vault.rekey(new_password)
print "Rekey successful"
#-------------------------------------------------------------------------------------
# MAIN
#-------------------------------------------------------------------------------------
def main():
action = get_action(sys.argv)
parser = build_option_parser(action)
(options, args) = parser.parse_args()
# execute the desired action
try:
fn = globals()["execute_%s" % action]
fn(args, options, parser)
except Exception, err:
if options.debug:
print traceback.format_exc()
print "ERROR:",err
sys.exit(1)
if __name__ == "__main__":
main()

@ -117,6 +117,7 @@ DEFAULT_PRIVATE_KEY_FILE = shell_expand_path(get_config(p, DEFAULTS, 'private_k
DEFAULT_SUDO_USER = get_config(p, DEFAULTS, 'sudo_user', 'ANSIBLE_SUDO_USER', 'root')
DEFAULT_ASK_SUDO_PASS = get_config(p, DEFAULTS, 'ask_sudo_pass', 'ANSIBLE_ASK_SUDO_PASS', False, boolean=True)
DEFAULT_REMOTE_PORT = get_config(p, DEFAULTS, 'remote_port', 'ANSIBLE_REMOTE_PORT', None, integer=True)
DEFAULT_ASK_VAULT_PASS = get_config(p, DEFAULTS, 'ask_vault_pass', 'ANSIBLE_ASK_VAULT_PASS', False, boolean=True)
DEFAULT_TRANSPORT = get_config(p, DEFAULTS, 'transport', 'ANSIBLE_TRANSPORT', 'smart')
DEFAULT_SCP_IF_SSH = get_config(p, 'ssh_connection', 'scp_if_ssh', 'ANSIBLE_SCP_IF_SSH', False, boolean=True)
DEFAULT_MANAGED_STR = get_config(p, DEFAULTS, 'ansible_managed', None, 'Ansible managed: {file} modified on %Y-%m-%d %H:%M:%S by {uid} on {host}')
@ -172,4 +173,5 @@ DEFAULT_SUDO_PASS = None
DEFAULT_REMOTE_PASS = None
DEFAULT_SUBSET = None
DEFAULT_SU_PASS = None
VAULT_VERSION_MIN = 1.0
VAULT_VERSION_MAX = 1.0

@ -347,19 +347,19 @@ class Inventory(object):
raise Exception("group not found: %s" % groupname)
return group.get_variables()
def get_variables(self, hostname):
def get_variables(self, hostname, vault_password=None):
if hostname not in self._vars_per_host:
self._vars_per_host[hostname] = self._get_variables(hostname)
self._vars_per_host[hostname] = self._get_variables(hostname, vault_password=vault_password)
return self._vars_per_host[hostname]
def _get_variables(self, hostname):
def _get_variables(self, hostname, vault_password=None):
host = self.get_host(hostname)
if host is None:
raise errors.AnsibleError("host not found: %s" % hostname)
vars = {}
vars_results = [ plugin.run(host) for plugin in self._vars_plugins ]
vars_results = [ plugin.run(host, vault_password=vault_password) for plugin in self._vars_plugins ]
for updated in vars_results:
if updated is not None:
vars.update(updated)

@ -23,7 +23,7 @@ from ansible import errors
from ansible import utils
import ansible.constants as C
def _load_vars(basepath, results):
def _load_vars(basepath, results, vault_password=None):
"""
Load variables from any potential yaml filename combinations of basepath,
returning result.
@ -35,7 +35,7 @@ def _load_vars(basepath, results):
found_paths = []
for path in paths_to_check:
found, results = _load_vars_from_path(path, results)
found, results = _load_vars_from_path(path, results, vault_password=vault_password)
if found:
found_paths.append(path)
@ -49,7 +49,7 @@ def _load_vars(basepath, results):
return results
def _load_vars_from_path(path, results):
def _load_vars_from_path(path, results, vault_password=None):
"""
Robustly access the file at path and load variables, carefully reporting
errors in a friendly/informative way.
@ -90,7 +90,7 @@ def _load_vars_from_path(path, results):
# regular file
elif stat.S_ISREG(pathstat.st_mode):
data = utils.parse_yaml_from_file(path)
data = utils.parse_yaml_from_file(path, vault_password=vault_password)
if type(data) != dict:
raise errors.AnsibleError(
"%s must be stored as a dictionary/hash" % path)
@ -143,7 +143,7 @@ class VarsModule(object):
self.inventory = inventory
def run(self, host):
def run(self, host, vault_password=None):
""" main body of the plugin, does actual loading """
@ -183,11 +183,11 @@ class VarsModule(object):
# load vars in dir/group_vars/name_of_group
for group in groups:
base_path = os.path.join(basedir, "group_vars/%s" % group)
results = _load_vars(base_path, results)
results = _load_vars(base_path, results, vault_password=vault_password)
# same for hostvars in dir/host_vars/name_of_host
base_path = os.path.join(basedir, "host_vars/%s" % host.name)
results = _load_vars(base_path, results)
results = _load_vars(base_path, results, vault_password=vault_password)
# all done, results is a dictionary of variables for this particular host.
return results

@ -72,6 +72,7 @@ class PlayBook(object):
su = False,
su_user = False,
su_pass = False,
vault_password = False,
):
"""
@ -138,6 +139,7 @@ class PlayBook(object):
self.su = su
self.su_user = su_user
self.su_pass = su_pass
self.vault_password = vault_password
self.callbacks.playbook = self
self.runner_callbacks.playbook = self
@ -172,7 +174,7 @@ class PlayBook(object):
run top level error checking on playbooks and allow them to include other playbooks.
'''
playbook_data = utils.parse_yaml_from_file(path)
playbook_data = utils.parse_yaml_from_file(path, vault_password=self.vault_password)
accumulated_plays = []
play_basedirs = []
@ -242,7 +244,7 @@ class PlayBook(object):
# loop through all patterns and run them
self.callbacks.on_start()
for (play_ds, play_basedir) in zip(self.playbook, self.play_basedirs):
play = Play(self, play_ds, play_basedir)
play = Play(self, play_ds, play_basedir, vault_password=self.vault_password)
assert play is not None
matched_tags, unmatched_tags = play.compare_tags(self.only_tags)
@ -352,6 +354,7 @@ class PlayBook(object):
su=task.su,
su_user=task.su_user,
su_pass=task.su_pass,
vault_pass = self.vault_password,
run_hosts=hosts,
no_log=task.no_log,
)
@ -504,6 +507,7 @@ class PlayBook(object):
su=play.su,
su_user=play.su_user,
su_pass=self.su_pass,
vault_pass=self.vault_password,
transport=play.transport,
is_playbook=True,
module_vars=play.vars,
@ -569,9 +573,8 @@ class PlayBook(object):
self._do_setup_step(play)
# now with that data, handle contentional variable file imports!
all_hosts = self._trim_unavailable_hosts(play._play_hosts)
play.update_vars_files(all_hosts)
play.update_vars_files(all_hosts, vault_password=self.vault_password)
hosts_count = len(all_hosts)
serialized_batch = []

@ -34,7 +34,7 @@ class Play(object):
'handlers', 'remote_user', 'remote_port', 'included_roles', 'accelerate',
'accelerate_port', 'accelerate_ipv6', 'sudo', 'sudo_user', 'transport', 'playbook',
'tags', 'gather_facts', 'serial', '_ds', '_handlers', '_tasks',
'basedir', 'any_errors_fatal', 'roles', 'max_fail_pct', '_play_hosts', 'su', 'su_user'
'basedir', 'any_errors_fatal', 'roles', 'max_fail_pct', '_play_hosts', 'su', 'su_user', 'vault_password'
]
# to catch typos and so forth -- these are userland names
@ -44,12 +44,12 @@ class Play(object):
'tasks', 'handlers', 'remote_user', 'user', 'port', 'include', 'accelerate', 'accelerate_port', 'accelerate_ipv6',
'sudo', 'sudo_user', 'connection', 'tags', 'gather_facts', 'serial',
'any_errors_fatal', 'roles', 'pre_tasks', 'post_tasks', 'max_fail_percentage',
'su', 'su_user'
'su', 'su_user', 'vault_password'
]
# *************************************************
def __init__(self, playbook, ds, basedir):
def __init__(self, playbook, ds, basedir, vault_password=None):
''' constructor loads from a play datastructure '''
for x in ds.keys():
@ -64,6 +64,7 @@ class Play(object):
self.basedir = basedir
self.roles = ds.get('roles', None)
self.tags = ds.get('tags', None)
self.vault_password = vault_password
if self.tags is None:
self.tags = []
@ -88,6 +89,7 @@ class Play(object):
self.vars_files = ds.get('vars_files', [])
if not isinstance(self.vars_files, list):
raise errors.AnsibleError('vars_files must be a list')
self._update_vars_files_for_host(None)
# template everything to be efficient, but do not pre-mature template
@ -124,6 +126,7 @@ class Play(object):
self.max_fail_pct = int(ds.get('max_fail_percentage', 100))
self.su = ds.get('su', self.playbook.su)
self.su_user = ds.get('su_user', self.playbook.su_user)
#self.vault_password = vault_password
# Fail out if user specifies a sudo param with a su param in a given play
if (ds.get('sudo') or ds.get('sudo_user')) and (ds.get('su') or ds.get('su_user')):
@ -540,7 +543,7 @@ class Play(object):
dirname = os.path.dirname(original_file)
include_file = template(dirname, tokens[0], mv)
include_filename = utils.path_dwim(dirname, include_file)
data = utils.parse_yaml_from_file(include_filename)
data = utils.parse_yaml_from_file(include_filename, vault_password=self.vault_password)
if 'role_name' in x and data is not None:
for x in data:
if 'include' in x:
@ -652,12 +655,12 @@ class Play(object):
# *************************************************
def update_vars_files(self, hosts):
def update_vars_files(self, hosts, vault_password=None):
''' calculate vars_files, which requires that setup runs first so ansible facts can be mixed in '''
# now loop through all the hosts...
for h in hosts:
self._update_vars_files_for_host(h)
self._update_vars_files_for_host(h, vault_password=vault_password)
# *************************************************
@ -689,14 +692,14 @@ class Play(object):
# *************************************************
def _update_vars_files_for_host(self, host):
def _update_vars_files_for_host(self, host, vault_password=None):
if type(self.vars_files) != list:
self.vars_files = [ self.vars_files ]
if host is not None:
inject = {}
inject.update(self.playbook.inventory.get_variables(host))
inject.update(self.playbook.inventory.get_variables(host, vault_password=vault_password))
inject.update(self.playbook.SETUP_CACHE[host])
for filename in self.vars_files:
@ -747,7 +750,7 @@ class Play(object):
filename4 = utils.path_dwim(self.basedir, filename3)
if self._has_vars_in(filename4):
continue
new_vars = utils.parse_yaml_from_file(filename4)
new_vars = utils.parse_yaml_from_file(filename4, vault_password=self.vault_password)
if new_vars:
if type(new_vars) != dict:
raise errors.AnsibleError("%s must be stored as dictionary/hash: %s" % (filename4, type(new_vars)))

@ -144,6 +144,7 @@ class Runner(object):
su=False, # Are we running our command via su?
su_user=None, # User to su to when running command, ex: 'root'
su_pass=C.DEFAULT_SU_PASS,
vault_pass=None,
run_hosts=None, # an optional list of pre-calculated hosts to run on
no_log=False, # option to enable/disable logging for a given task
):
@ -197,6 +198,7 @@ class Runner(object):
self.su_user_var = su_user
self.su_user = None
self.su_pass = su_pass
self.vault_pass = vault_pass
self.no_log = no_log
if self.transport == 'smart':
@ -534,7 +536,7 @@ class Runner(object):
def _executor_internal(self, host, new_stdin):
''' executes any module one or more times '''
host_variables = self.inventory.get_variables(host)
host_variables = self.inventory.get_variables(host, vault_password=self.vault_pass)
host_connection = host_variables.get('ansible_connection', self.transport)
if host_connection in [ 'paramiko', 'paramiko_alt', 'ssh', 'ssh_old', 'accelerate' ]:
port = host_variables.get('ansible_ssh_port', self.remote_port)

@ -43,6 +43,8 @@ import getpass
import sys
import textwrap
import vault
VERBOSITY=0
# list of all deprecation messages to prevent duplicate display
@ -494,14 +496,22 @@ Should be written as:
raise errors.AnsibleYAMLValidationFailed(msg)
def parse_yaml_from_file(path):
def parse_yaml_from_file(path, vault_password=None):
''' convert a yaml file to a data structure '''
data = None
#VAULT
if vault.is_encrypted(path):
data = vault.decrypt(path, vault_password)
else:
try:
data = open(path).read()
except IOError:
raise errors.AnsibleError("file could not read: %s" % path)
try:
data = file(path).read()
return parse_yaml(data)
except IOError:
raise errors.AnsibleError("file not found: %s" % path)
except yaml.YAMLError, exc:
process_yaml_error(exc, data, path)
@ -693,6 +703,8 @@ def base_parser(constants=C, usage="", output_opts=False, runas_opts=False,
help='ask for sudo password')
parser.add_option('--ask-su-pass', default=False, dest='ask_su_pass',
action='store_true', help='ask for su password')
parser.add_option('--ask-vault-pass', default=False, dest='ask_vault_pass',
action='store_true', help='ask for vault password')
parser.add_option('--list-hosts', dest='listhosts', action='store_true',
help='outputs a list of matching hosts; does not execute anything else')
parser.add_option('-M', '--module-path', dest='module_path',
@ -751,10 +763,34 @@ def base_parser(constants=C, usage="", output_opts=False, runas_opts=False,
return parser
def ask_passwords(ask_pass=False, ask_sudo_pass=False, ask_su_pass=False):
def ask_vaultpasswords(ask_vault_pass=False, ask_new_vault_pass=False, confirm_vault=False, confirm_new=False):
vault_pass = None
new_vault_pass = None
if ask_vault_pass:
vault_pass = getpass.getpass(prompt="Vault password: ")
if ask_vault_pass and confirm_vault:
vault_pass2 = getpass.getpass(prompt="Retype Vault password: ")
if vault_pass != vault_pass2:
raise errors.AnsibleError("Passwords do not match")
if ask_new_vault_pass:
new_vault_pass = getpass.getpass(prompt="New Vault password: ")
if ask_new_vault_pass and confirm_new:
new_vault_pass2 = getpass.getpass(prompt="Retype New Vault password: ")
if new_vault_pass != new_vault_pass2:
raise errors.AnsibleError("Passwords do not match")
return vault_pass, new_vault_pass
def ask_passwords(ask_pass=False, ask_sudo_pass=False, ask_su_pass=False, ask_vault_pass=False):
sshpass = None
sudopass = None
su_pass = None
vault_pass = None
sudo_prompt = "sudo password: "
su_prompt = "su password: "
@ -770,7 +806,10 @@ def ask_passwords(ask_pass=False, ask_sudo_pass=False, ask_su_pass=False):
if ask_su_pass:
su_pass = getpass.getpass(prompt=su_prompt)
return (sshpass, sudopass, su_pass)
if ask_vault_pass:
vault_pass = getpass.getpass(prompt="Vault password: ")
return (sshpass, sudopass, su_pass, vault_pass)
def do_encrypt(result, encrypt, salt_size=None, salt=None):
if PASSLIB_AVAILABLE:

@ -0,0 +1,450 @@
# (c) 2014, James Tanner <tanner.jc@gmail.com>
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
#
# ansible-pull is a script that runs ansible in local mode
# after checking out a playbooks directory from source repo. There is an
# example playbook to bootstrap this script in the examples/ dir which
# installs ansible and sets it up to run on cron.
import os
import shutil
import tempfile
from io import BytesIO
from subprocess import call
from ansible import errors
from hashlib import sha256
from hashlib import md5
from binascii import hexlify
from binascii import unhexlify
from ansible import constants as C
# AES IMPORTS
try:
from Crypto.Cipher import AES as AES_
HAS_AES = True
except ImportError:
HAS_AES = False
HEADER='$ANSIBLE_VAULT'
def is_encrypted(filename):
'''
Check a file for the encrypted header and return True or False
The first line should start with the header
defined by the global HEADER. If true, we
assume this is a properly encrypted file.
'''
# read first line of the file
with open(filename) as f:
head = f.next()
if head.startswith(HEADER):
return True
else:
return False
def decrypt(filename, password):
'''
Return a decrypted string of the contents in an encrypted file
This is used by the yaml loading code in ansible
to automatically determine the encryption type
and return a plaintext string of the unencrypted
data.
'''
if password is None:
raise errors.AnsibleError("A vault password must be specified to decrypt %s" % filename)
V = Vault(filename=filename, vault_password=password)
return_data = V._decrypt_to_string()
if not V._verify_decryption(return_data):
raise errors.AnsibleError("Decryption of %s failed" % filename)
this_sha, return_data = V._strip_sha(return_data)
return return_data.strip()
class Vault(object):
def __init__(self, filename=None, cipher=None, vault_password=None):
self.filename = filename
self.vault_password = vault_password
self.cipher = cipher
self.version = '1.0'
###############
# PUBLIC
###############
def eval_header(self):
""" Read first line of the file and parse header """
# read first line
with open(self.filename) as f:
#head=[f.next() for x in xrange(1)]
head = f.next()
this_version = None
this_cipher = None
# split segments
if len(head.split(';')) == 3:
this_version = head.split(';')[1].strip()
this_cipher = head.split(';')[2].strip()
else:
raise errors.AnsibleError("%s has an invalid header" % self.filename)
# validate acceptable version
this_version = float(this_version)
if this_version < C.VAULT_VERSION_MIN or this_version > C.VAULT_VERSION_MAX:
raise errors.AnsibleError("%s must have a version between %s and %s " % (self.filename,
C.VAULT_VERSION_MIN,
C.VAULT_VERSION_MAX))
# set properties
self.cipher = this_cipher
self.version = this_version
def create(self):
""" create a new encrypted file """
if os.path.isfile(self.filename):
raise errors.AnsibleError("%s exists, please use 'edit' instead" % self.filename)
# drop the user into vim on file
EDITOR = os.environ.get('EDITOR','vim')
call([EDITOR, self.filename])
self.encrypt()
def decrypt(self):
""" unencrypt a file inplace """
if not is_encrypted(self.filename):
raise errors.AnsibleError("%s is not encrypted" % self.filename)
# set cipher based on file header
self.eval_header()
# decrypt it
data = self._decrypt_to_string()
# verify sha and then strip it out
if not self._verify_decryption(data):
raise errors.AnsibleError("decryption of %s failed" % self.filename)
this_sha, clean_data = self._strip_sha(data)
# write back to original file
f = open(self.filename, "wb")
f.write(clean_data)
f.close()
def edit(self, filename=None, password=None, cipher=None, version=None):
if not is_encrypted(self.filename):
raise errors.AnsibleError("%s is not encrypted" % self.filename)
#decrypt to string
data = self._decrypt_to_string()
# verify sha and then strip it out
if not self._verify_decryption(data):
raise errors.AnsibleError("decryption of %s failed" % self.filename)
this_sha, clean_data = self._strip_sha(data)
# rewrite file without sha
_, in_path = tempfile.mkstemp()
f = open(in_path, "wb")
tmpdata = f.write(clean_data)
f.close()
# drop the user into vim on the unencrypted tmp file
EDITOR = os.environ.get('EDITOR','vim')
call([EDITOR, in_path])
f = open(in_path, "rb")
tmpdata = f.read()
f.close()
self._string_to_encrypted_file(tmpdata, self.filename)
def encrypt(self):
""" encrypt a file inplace """
if is_encrypted(self.filename):
raise errors.AnsibleError("%s is already encrypted" % self.filename)
#self.eval_header()
self.__load_cipher()
# read data
f = open(self.filename, "rb")
tmpdata = f.read()
f.close()
self._string_to_encrypted_file(tmpdata, self.filename)
def rekey(self, newpassword):
""" unencrypt file then encrypt with new password """
if not is_encrypted(self.filename):
raise errors.AnsibleError("%s is not encrypted" % self.filename)
# unencrypt to string with old password
data = self._decrypt_to_string()
# verify sha and then strip it out
if not self._verify_decryption(data):
raise errors.AnsibleError("decryption of %s failed" % self.filename)
this_sha, clean_data = self._strip_sha(data)
# set password
self.vault_password = newpassword
self._string_to_encrypted_file(clean_data, self.filename)
###############
# PRIVATE
###############
def __load_cipher(self):
"""
Load a cipher class by it's name
This is a lightweight "plugin" implementation to allow
for future support of other cipher types
"""
whitelist = ['AES']
if self.cipher in whitelist:
self.cipher_obj = None
if self.cipher in globals():
this_cipher = globals()[self.cipher]
self.cipher_obj = this_cipher()
else:
raise errors.AnsibleError("%s cipher could not be loaded" % self.cipher)
else:
raise errors.AnsibleError("%s is not an allowed encryption cipher" % self.cipher)
def _decrypt_to_string(self):
""" decrypt file to string """
if not is_encrypted(self.filename):
raise errors.AnsibleError("%s is not encrypted" % self.filename)
# figure out what this is
self.eval_header()
self.__load_cipher()
# strip out header and unhex the file
clean_stream = self._dirty_file_to_clean_file(self.filename)
# reset pointer
clean_stream.seek(0)
# create a byte stream to hold unencrypted
dst = BytesIO()
# decrypt from src stream to dst stream
self.cipher_obj.decrypt(clean_stream, dst, self.vault_password)
# read data from the unencrypted stream
data = dst.read()
return data
def _dirty_file_to_clean_file(self, dirty_filename):
""" Strip out headers from a file, unhex and write to new file"""
_, in_path = tempfile.mkstemp()
#_, out_path = tempfile.mkstemp()
# strip header from data, write rest to tmp file
f = open(dirty_filename, "rb")
tmpdata = f.readlines()
f.close()
tmpheader = tmpdata[0].strip()
tmpdata = ''.join(tmpdata[1:])
# strip out newline, join, unhex
tmpdata = [ x.strip() for x in tmpdata ]
tmpdata = unhexlify(''.join(tmpdata))
# create and return stream
clean_stream = BytesIO(tmpdata)
return clean_stream
def _clean_stream_to_dirty_stream(self, clean_stream):
# combine header and hexlified encrypted data in 80 char columns
clean_stream.seek(0)
tmpdata = clean_stream.read()
tmpdata = hexlify(tmpdata)
tmpdata = [tmpdata[i:i+80] for i in range(0, len(tmpdata), 80)]
dirty_data = HEADER + ";" + str(self.version) + ";" + self.cipher + "\n"
for l in tmpdata:
dirty_data += l + '\n'
dirty_stream = BytesIO(dirty_data)
return dirty_stream
def _string_to_encrypted_file(self, tmpdata, filename):
""" Write a string of data to a file with the format ...
HEADER;VERSION;CIPHER
HEX(ENCRYPTED(SHA256(STRING)+STRING))
"""
# sha256 the data
this_sha = sha256(tmpdata).hexdigest()
# combine sha + data to tmpfile
tmpdata = this_sha + "\n" + tmpdata
src_stream = BytesIO(tmpdata)
dst_stream = BytesIO()
# encrypt tmpfile
self.cipher_obj.encrypt(src_stream, dst_stream, self.password)
# hexlify tmpfile and combine with header
dirty_stream = self._clean_stream_to_dirty_stream(dst_stream)
if os.path.isfile(filename):
os.remove(filename)
# write back to original file
dirty_stream.seek(0)
f = open(filename, "wb")
f.write(dirty_stream.read())
f.close()
def _verify_decryption(self, data):
""" Split data to sha/data and check the sha """
# split the sha and other data
this_sha, clean_data = self._strip_sha(data)
# does the decrypted data match the sha ?
clean_sha = sha256(clean_data).hexdigest()
# compare, return result
if this_sha == clean_sha:
return True
else:
return False
def _strip_sha(self, data):
# is the first line a sha?
lines = data.split("\n")
this_sha = lines[0]
clean_data = '\n'.join(lines[1:])
return this_sha, clean_data
class AES(object):
# http://stackoverflow.com/a/16761459
def __init__(self):
if not HAS_AES:
raise errors.AnsibleError("pycrypto is not installed. Fix this with your package manager, for instance, yum-install python-crypto OR (apt equivalent)")
def aes_derive_key_and_iv(self, password, salt, key_length, iv_length):
""" Create a key and an initialization vector """
d = d_i = ''
while len(d) < key_length + iv_length:
d_i = md5(d_i + password + salt).digest()
d += d_i
key = d[:key_length]
iv = d[key_length:key_length+iv_length]
return key, iv
def encrypt(self, in_file, out_file, password, key_length=32):
""" Read plaintext data from in_file and write encrypted to out_file """
bs = AES_.block_size
# Get a block of random data. EL does not have Crypto.Random.new()
# so os.urandom is used for cross platform purposes
print "WARNING: if encryption hangs, add more entropy (suggest using mouse inputs)"
salt = os.urandom(bs - len('Salted__'))
key, iv = self.aes_derive_key_and_iv(password, salt, key_length, bs)
cipher = AES_.new(key, AES_.MODE_CBC, iv)
out_file.write('Salted__' + salt)
finished = False
while not finished:
chunk = in_file.read(1024 * bs)
if len(chunk) == 0 or len(chunk) % bs != 0:
padding_length = (bs - len(chunk) % bs) or bs
chunk += padding_length * chr(padding_length)
finished = True
out_file.write(cipher.encrypt(chunk))
def decrypt(self, in_file, out_file, password, key_length=32):
""" Read encrypted data from in_file and write decrypted to out_file """
# http://stackoverflow.com/a/14989032
bs = AES_.block_size
salt = in_file.read(bs)[len('Salted__'):]
key, iv = self.aes_derive_key_and_iv(password, salt, key_length, bs)
cipher = AES_.new(key, AES_.MODE_CBC, iv)
next_chunk = ''
finished = False
out_data = ''
while not finished:
chunk, next_chunk = next_chunk, cipher.decrypt(in_file.read(1024 * bs))
if len(next_chunk) == 0:
padding_length = ord(chunk[-1])
chunk = chunk[:-padding_length]
finished = True
out_data += chunk
# write decrypted data to out stream
out_file.write(out_data)
# reset the stream pointer to the beginning
if hasattr(out_file, 'seek'):
out_file.seek(0)
Loading…
Cancel
Save