Merge pull request #6058 from jctanner/vault_squashed_2

Ansible vault: a framework for encrypting any playbook or var file.
pull/6825/head
Richard Isaacson 11 years ago
commit dc403eb21e

@ -84,6 +84,7 @@ class Cli(object):
pattern = args[0] pattern = args[0]
"""
inventory_manager = inventory.Inventory(options.inventory) inventory_manager = inventory.Inventory(options.inventory)
if options.subset: if options.subset:
inventory_manager.subset(options.subset) inventory_manager.subset(options.subset)
@ -101,6 +102,7 @@ class Cli(object):
and not options.module_args): and not options.module_args):
callbacks.display("No argument passed to %s module" % options.module_name, color='red', stderr=True) callbacks.display("No argument passed to %s module" % options.module_name, color='red', stderr=True)
sys.exit(1) sys.exit(1)
"""
sshpass = None sshpass = None
sudopass = None sudopass = None
@ -111,7 +113,27 @@ class Cli(object):
options.ask_pass = False options.ask_pass = False
options.ask_sudo_pass = options.ask_sudo_pass or C.DEFAULT_ASK_SUDO_PASS options.ask_sudo_pass = options.ask_sudo_pass or C.DEFAULT_ASK_SUDO_PASS
options.ask_su_pass = options.ask_su_pass or C.DEFAULT_ASK_SU_PASS options.ask_su_pass = options.ask_su_pass or C.DEFAULT_ASK_SU_PASS
(sshpass, sudopass, su_pass) = utils.ask_passwords(ask_pass=options.ask_pass, ask_sudo_pass=options.ask_sudo_pass, ask_su_pass=options.ask_su_pass) (sshpass, sudopass, su_pass, vault_pass) = utils.ask_passwords(ask_pass=options.ask_pass, ask_sudo_pass=options.ask_sudo_pass, ask_su_pass=options.ask_su_pass, ask_vault_pass=options.ask_vault_pass)
inventory_manager = inventory.Inventory(options.inventory)
if options.subset:
inventory_manager.subset(options.subset)
hosts = inventory_manager.list_hosts(pattern)
if len(hosts) == 0:
callbacks.display("No hosts matched")
sys.exit(0)
if options.listhosts:
for host in hosts:
callbacks.display(' %s' % host)
sys.exit(0)
if ((options.module_name == 'command' or options.module_name == 'shell')
and not options.module_args):
callbacks.display("No argument passed to %s module" % options.module_name, color='red', stderr=True)
sys.exit(1)
if options.su_user or options.ask_su_pass: if options.su_user or options.ask_su_pass:
options.su = True options.su = True
elif options.sudo_user or options.ask_sudo_pass: elif options.sudo_user or options.ask_sudo_pass:
@ -121,7 +143,6 @@ class Cli(object):
if options.tree: if options.tree:
utils.prepare_writeable_dir(options.tree) utils.prepare_writeable_dir(options.tree)
runner = Runner( runner = Runner(
module_name=options.module_name, module_name=options.module_name,
module_path=options.module_path, module_path=options.module_path,
@ -143,7 +164,8 @@ class Cli(object):
diff=options.check, diff=options.check,
su=options.su, su=options.su,
su_pass=su_pass, su_pass=su_pass,
su_user=options.su_user su_user=options.su_user,
vault_pass=vault_pass
) )
if options.seconds: if options.seconds:

@ -62,6 +62,8 @@ def main(args):
check_opts=True, check_opts=True,
diff_opts=True diff_opts=True
) )
#parser.add_option('--vault-password', dest="vault_password",
# help="password for vault encrypted files")
parser.add_option('-e', '--extra-vars', dest="extra_vars", action="append", parser.add_option('-e', '--extra-vars', dest="extra_vars", action="append",
help="set additional variables as key=value or YAML/JSON", default=[]) help="set additional variables as key=value or YAML/JSON", default=[])
parser.add_option('-t', '--tags', dest='tags', default='all', parser.add_option('-t', '--tags', dest='tags', default='all',
@ -100,12 +102,13 @@ def main(args):
su_pass = None su_pass = None
if not options.listhosts and not options.syntax and not options.listtasks: if not options.listhosts and not options.syntax and not options.listtasks:
options.ask_pass = options.ask_pass or C.DEFAULT_ASK_PASS options.ask_pass = options.ask_pass or C.DEFAULT_ASK_PASS
options.ask_vault_pass = options.ask_vault_pass or C.DEFAULT_ASK_VAULT_PASS
# Never ask for an SSH password when we run with local connection # Never ask for an SSH password when we run with local connection
if options.connection == "local": if options.connection == "local":
options.ask_pass = False options.ask_pass = False
options.ask_sudo_pass = options.ask_sudo_pass or C.DEFAULT_ASK_SUDO_PASS options.ask_sudo_pass = options.ask_sudo_pass or C.DEFAULT_ASK_SUDO_PASS
options.ask_su_pass = options.ask_su_pass or C.DEFAULT_ASK_SU_PASS options.ask_su_pass = options.ask_su_pass or C.DEFAULT_ASK_SU_PASS
(sshpass, sudopass, su_pass) = utils.ask_passwords(ask_pass=options.ask_pass, ask_sudo_pass=options.ask_sudo_pass, ask_su_pass=options.ask_su_pass) (sshpass, sudopass, su_pass, vault_pass) = utils.ask_passwords(ask_pass=options.ask_pass, ask_sudo_pass=options.ask_sudo_pass, ask_su_pass=options.ask_su_pass, ask_vault_pass=options.ask_vault_pass)
options.sudo_user = options.sudo_user or C.DEFAULT_SUDO_USER options.sudo_user = options.sudo_user or C.DEFAULT_SUDO_USER
options.su_user = options.su_user or C.DEFAULT_SU_USER options.su_user = options.su_user or C.DEFAULT_SU_USER
@ -170,7 +173,8 @@ def main(args):
diff=options.diff, diff=options.diff,
su=options.su, su=options.su,
su_pass=su_pass, su_pass=su_pass,
su_user=options.su_user su_user=options.su_user,
vault_password=vault_pass
) )
if options.listhosts or options.listtasks or options.syntax: if options.listhosts or options.listtasks or options.syntax:

@ -0,0 +1,187 @@
#!/usr/bin/env python
# (c) 2014, James Tanner <tanner.jc@gmail.com>
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
#
# ansible-pull is a script that runs ansible in local mode
# after checking out a playbooks directory from source repo. There is an
# example playbook to bootstrap this script in the examples/ dir which
# installs ansible and sets it up to run on cron.
import sys
import traceback
from ansible import utils
from ansible import errors
from ansible.utils.vault import *
from ansible.utils.vault import Vault
from optparse import OptionParser
#-------------------------------------------------------------------------------------
# Utility functions for parsing actions/options
#-------------------------------------------------------------------------------------
VALID_ACTIONS = ("create", "decrypt", "edit", "encrypt", "rekey")
def build_option_parser(action):
"""
Builds an option parser object based on the action
the user wants to execute.
"""
usage = "usage: %%prog [%s] [--help] [options] file_name" % "|".join(VALID_ACTIONS)
epilog = "\nSee '%s <command> --help' for more information on a specific command.\n\n" % os.path.basename(sys.argv[0])
OptionParser.format_epilog = lambda self, formatter: self.epilog
parser = OptionParser(usage=usage, epilog=epilog)
if not action:
parser.print_help()
sys.exit()
# options for all actions
#parser.add_option('-p', '--password', help="encryption key")
#parser.add_option('-c', '--cipher', dest='cipher', default="AES", help="cipher to use")
parser.add_option('-d', '--debug', dest='debug', action="store_true", help="debug")
# options specific to actions
if action == "create":
parser.set_usage("usage: %prog create [options] file_name")
elif action == "decrypt":
parser.set_usage("usage: %prog decrypt [options] file_name")
elif action == "edit":
parser.set_usage("usage: %prog edit [options] file_name")
elif action == "encrypt":
parser.set_usage("usage: %prog encrypt [options] file_name")
elif action == "rekey":
parser.set_usage("usage: %prog rekey [options] file_name")
# done, return the parser
return parser
def get_action(args):
"""
Get the action the user wants to execute from the
sys argv list.
"""
for i in range(0,len(args)):
arg = args[i]
if arg in VALID_ACTIONS:
del args[i]
return arg
return None
def get_opt(options, k, defval=""):
"""
Returns an option from an Optparse values instance.
"""
try:
data = getattr(options, k)
except:
return defval
if k == "roles_path":
if os.pathsep in data:
data = data.split(os.pathsep)[0]
return data
#-------------------------------------------------------------------------------------
# Command functions
#-------------------------------------------------------------------------------------
def _get_vault(filename, options, password):
this_vault = Vault()
this_vault.filename = filename
this_vault.vault_password = password
this_vault.password = password
return this_vault
def execute_create(args, options, parser):
if len(args) > 1:
raise errors.AnsibleError("create does not accept more than one filename")
password, new_password = utils.ask_vault_passwords(ask_vault_pass=True, confirm_vault=True)
this_vault = _get_vault(args[0], options, password)
if not hasattr(options, 'cipher'):
this_vault.cipher = 'AES'
this_vault.create()
def execute_decrypt(args, options, parser):
password, new_password = utils.ask_vault_passwords(ask_vault_pass=True)
for f in args:
this_vault = _get_vault(f, options, password)
this_vault.decrypt()
print "Decryption successful"
def execute_edit(args, options, parser):
if len(args) > 1:
raise errors.AnsibleError("create does not accept more than one filename")
password, new_password = utils.ask_vault_passwords(ask_vault_pass=True)
for f in args:
this_vault = _get_vault(f, options, password)
this_vault.edit()
def execute_encrypt(args, options, parser):
password, new_password = utils.ask_vault_passwords(ask_vault_pass=True, confirm_vault=True)
for f in args:
this_vault = _get_vault(f, options, password)
if not hasattr(options, 'cipher'):
this_vault.cipher = 'AES'
this_vault.encrypt()
print "Encryption successful"
def execute_rekey(args, options, parser):
password, new_password = utils.ask_vault_passwords(ask_vault_pass=True, ask_new_vault_pass=True, confirm_new=True)
for f in args:
this_vault = _get_vault(f, options, password)
this_vault.rekey(new_password)
print "Rekey successful"
#-------------------------------------------------------------------------------------
# MAIN
#-------------------------------------------------------------------------------------
def main():
action = get_action(sys.argv)
parser = build_option_parser(action)
(options, args) = parser.parse_args()
# execute the desired action
try:
fn = globals()["execute_%s" % action]
fn(args, options, parser)
except Exception, err:
if options.debug:
print traceback.format_exc()
print "ERROR:",err
sys.exit(1)
if __name__ == "__main__":
main()

@ -117,6 +117,7 @@ DEFAULT_PRIVATE_KEY_FILE = shell_expand_path(get_config(p, DEFAULTS, 'private_k
DEFAULT_SUDO_USER = get_config(p, DEFAULTS, 'sudo_user', 'ANSIBLE_SUDO_USER', 'root') DEFAULT_SUDO_USER = get_config(p, DEFAULTS, 'sudo_user', 'ANSIBLE_SUDO_USER', 'root')
DEFAULT_ASK_SUDO_PASS = get_config(p, DEFAULTS, 'ask_sudo_pass', 'ANSIBLE_ASK_SUDO_PASS', False, boolean=True) DEFAULT_ASK_SUDO_PASS = get_config(p, DEFAULTS, 'ask_sudo_pass', 'ANSIBLE_ASK_SUDO_PASS', False, boolean=True)
DEFAULT_REMOTE_PORT = get_config(p, DEFAULTS, 'remote_port', 'ANSIBLE_REMOTE_PORT', None, integer=True) DEFAULT_REMOTE_PORT = get_config(p, DEFAULTS, 'remote_port', 'ANSIBLE_REMOTE_PORT', None, integer=True)
DEFAULT_ASK_VAULT_PASS = get_config(p, DEFAULTS, 'ask_vault_pass', 'ANSIBLE_ASK_VAULT_PASS', False, boolean=True)
DEFAULT_TRANSPORT = get_config(p, DEFAULTS, 'transport', 'ANSIBLE_TRANSPORT', 'smart') DEFAULT_TRANSPORT = get_config(p, DEFAULTS, 'transport', 'ANSIBLE_TRANSPORT', 'smart')
DEFAULT_SCP_IF_SSH = get_config(p, 'ssh_connection', 'scp_if_ssh', 'ANSIBLE_SCP_IF_SSH', False, boolean=True) DEFAULT_SCP_IF_SSH = get_config(p, 'ssh_connection', 'scp_if_ssh', 'ANSIBLE_SCP_IF_SSH', False, boolean=True)
DEFAULT_MANAGED_STR = get_config(p, DEFAULTS, 'ansible_managed', None, 'Ansible managed: {file} modified on %Y-%m-%d %H:%M:%S by {uid} on {host}') DEFAULT_MANAGED_STR = get_config(p, DEFAULTS, 'ansible_managed', None, 'Ansible managed: {file} modified on %Y-%m-%d %H:%M:%S by {uid} on {host}')
@ -172,4 +173,5 @@ DEFAULT_SUDO_PASS = None
DEFAULT_REMOTE_PASS = None DEFAULT_REMOTE_PASS = None
DEFAULT_SUBSET = None DEFAULT_SUBSET = None
DEFAULT_SU_PASS = None DEFAULT_SU_PASS = None
VAULT_VERSION_MIN = 1.0
VAULT_VERSION_MAX = 1.0

@ -347,19 +347,19 @@ class Inventory(object):
raise Exception("group not found: %s" % groupname) raise Exception("group not found: %s" % groupname)
return group.get_variables() return group.get_variables()
def get_variables(self, hostname): def get_variables(self, hostname, vault_password=None):
if hostname not in self._vars_per_host: if hostname not in self._vars_per_host:
self._vars_per_host[hostname] = self._get_variables(hostname) self._vars_per_host[hostname] = self._get_variables(hostname, vault_password=vault_password)
return self._vars_per_host[hostname] return self._vars_per_host[hostname]
def _get_variables(self, hostname): def _get_variables(self, hostname, vault_password=None):
host = self.get_host(hostname) host = self.get_host(hostname)
if host is None: if host is None:
raise errors.AnsibleError("host not found: %s" % hostname) raise errors.AnsibleError("host not found: %s" % hostname)
vars = {} vars = {}
vars_results = [ plugin.run(host) for plugin in self._vars_plugins ] vars_results = [ plugin.run(host, vault_password=vault_password) for plugin in self._vars_plugins ]
for updated in vars_results: for updated in vars_results:
if updated is not None: if updated is not None:
vars.update(updated) vars.update(updated)

@ -23,7 +23,7 @@ from ansible import errors
from ansible import utils from ansible import utils
import ansible.constants as C import ansible.constants as C
def _load_vars(basepath, results): def _load_vars(basepath, results, vault_password=None):
""" """
Load variables from any potential yaml filename combinations of basepath, Load variables from any potential yaml filename combinations of basepath,
returning result. returning result.
@ -35,7 +35,7 @@ def _load_vars(basepath, results):
found_paths = [] found_paths = []
for path in paths_to_check: for path in paths_to_check:
found, results = _load_vars_from_path(path, results) found, results = _load_vars_from_path(path, results, vault_password=vault_password)
if found: if found:
found_paths.append(path) found_paths.append(path)
@ -49,7 +49,7 @@ def _load_vars(basepath, results):
return results return results
def _load_vars_from_path(path, results): def _load_vars_from_path(path, results, vault_password=None):
""" """
Robustly access the file at path and load variables, carefully reporting Robustly access the file at path and load variables, carefully reporting
errors in a friendly/informative way. errors in a friendly/informative way.
@ -90,7 +90,7 @@ def _load_vars_from_path(path, results):
# regular file # regular file
elif stat.S_ISREG(pathstat.st_mode): elif stat.S_ISREG(pathstat.st_mode):
data = utils.parse_yaml_from_file(path) data = utils.parse_yaml_from_file(path, vault_password=vault_password)
if type(data) != dict: if type(data) != dict:
raise errors.AnsibleError( raise errors.AnsibleError(
"%s must be stored as a dictionary/hash" % path) "%s must be stored as a dictionary/hash" % path)
@ -143,7 +143,7 @@ class VarsModule(object):
self.inventory = inventory self.inventory = inventory
def run(self, host): def run(self, host, vault_password=None):
""" main body of the plugin, does actual loading """ """ main body of the plugin, does actual loading """
@ -183,11 +183,11 @@ class VarsModule(object):
# load vars in dir/group_vars/name_of_group # load vars in dir/group_vars/name_of_group
for group in groups: for group in groups:
base_path = os.path.join(basedir, "group_vars/%s" % group) base_path = os.path.join(basedir, "group_vars/%s" % group)
results = _load_vars(base_path, results) results = _load_vars(base_path, results, vault_password=vault_password)
# same for hostvars in dir/host_vars/name_of_host # same for hostvars in dir/host_vars/name_of_host
base_path = os.path.join(basedir, "host_vars/%s" % host.name) base_path = os.path.join(basedir, "host_vars/%s" % host.name)
results = _load_vars(base_path, results) results = _load_vars(base_path, results, vault_password=vault_password)
# all done, results is a dictionary of variables for this particular host. # all done, results is a dictionary of variables for this particular host.
return results return results

@ -72,6 +72,7 @@ class PlayBook(object):
su = False, su = False,
su_user = False, su_user = False,
su_pass = False, su_pass = False,
vault_password = False,
): ):
""" """
@ -138,6 +139,7 @@ class PlayBook(object):
self.su = su self.su = su
self.su_user = su_user self.su_user = su_user
self.su_pass = su_pass self.su_pass = su_pass
self.vault_password = vault_password
self.callbacks.playbook = self self.callbacks.playbook = self
self.runner_callbacks.playbook = self self.runner_callbacks.playbook = self
@ -172,7 +174,7 @@ class PlayBook(object):
run top level error checking on playbooks and allow them to include other playbooks. run top level error checking on playbooks and allow them to include other playbooks.
''' '''
playbook_data = utils.parse_yaml_from_file(path) playbook_data = utils.parse_yaml_from_file(path, vault_password=self.vault_password)
accumulated_plays = [] accumulated_plays = []
play_basedirs = [] play_basedirs = []
@ -242,7 +244,7 @@ class PlayBook(object):
# loop through all patterns and run them # loop through all patterns and run them
self.callbacks.on_start() self.callbacks.on_start()
for (play_ds, play_basedir) in zip(self.playbook, self.play_basedirs): for (play_ds, play_basedir) in zip(self.playbook, self.play_basedirs):
play = Play(self, play_ds, play_basedir) play = Play(self, play_ds, play_basedir, vault_password=self.vault_password)
assert play is not None assert play is not None
matched_tags, unmatched_tags = play.compare_tags(self.only_tags) matched_tags, unmatched_tags = play.compare_tags(self.only_tags)
@ -352,6 +354,7 @@ class PlayBook(object):
su=task.su, su=task.su,
su_user=task.su_user, su_user=task.su_user,
su_pass=task.su_pass, su_pass=task.su_pass,
vault_pass = self.vault_password,
run_hosts=hosts, run_hosts=hosts,
no_log=task.no_log, no_log=task.no_log,
) )
@ -504,6 +507,7 @@ class PlayBook(object):
su=play.su, su=play.su,
su_user=play.su_user, su_user=play.su_user,
su_pass=self.su_pass, su_pass=self.su_pass,
vault_pass=self.vault_password,
transport=play.transport, transport=play.transport,
is_playbook=True, is_playbook=True,
module_vars=play.vars, module_vars=play.vars,
@ -569,9 +573,8 @@ class PlayBook(object):
self._do_setup_step(play) self._do_setup_step(play)
# now with that data, handle contentional variable file imports! # now with that data, handle contentional variable file imports!
all_hosts = self._trim_unavailable_hosts(play._play_hosts) all_hosts = self._trim_unavailable_hosts(play._play_hosts)
play.update_vars_files(all_hosts) play.update_vars_files(all_hosts, vault_password=self.vault_password)
hosts_count = len(all_hosts) hosts_count = len(all_hosts)
serialized_batch = [] serialized_batch = []

@ -34,7 +34,7 @@ class Play(object):
'handlers', 'remote_user', 'remote_port', 'included_roles', 'accelerate', 'handlers', 'remote_user', 'remote_port', 'included_roles', 'accelerate',
'accelerate_port', 'accelerate_ipv6', 'sudo', 'sudo_user', 'transport', 'playbook', 'accelerate_port', 'accelerate_ipv6', 'sudo', 'sudo_user', 'transport', 'playbook',
'tags', 'gather_facts', 'serial', '_ds', '_handlers', '_tasks', 'tags', 'gather_facts', 'serial', '_ds', '_handlers', '_tasks',
'basedir', 'any_errors_fatal', 'roles', 'max_fail_pct', '_play_hosts', 'su', 'su_user' 'basedir', 'any_errors_fatal', 'roles', 'max_fail_pct', '_play_hosts', 'su', 'su_user', 'vault_password'
] ]
# to catch typos and so forth -- these are userland names # to catch typos and so forth -- these are userland names
@ -44,12 +44,12 @@ class Play(object):
'tasks', 'handlers', 'remote_user', 'user', 'port', 'include', 'accelerate', 'accelerate_port', 'accelerate_ipv6', 'tasks', 'handlers', 'remote_user', 'user', 'port', 'include', 'accelerate', 'accelerate_port', 'accelerate_ipv6',
'sudo', 'sudo_user', 'connection', 'tags', 'gather_facts', 'serial', 'sudo', 'sudo_user', 'connection', 'tags', 'gather_facts', 'serial',
'any_errors_fatal', 'roles', 'pre_tasks', 'post_tasks', 'max_fail_percentage', 'any_errors_fatal', 'roles', 'pre_tasks', 'post_tasks', 'max_fail_percentage',
'su', 'su_user' 'su', 'su_user', 'vault_password'
] ]
# ************************************************* # *************************************************
def __init__(self, playbook, ds, basedir): def __init__(self, playbook, ds, basedir, vault_password=None):
''' constructor loads from a play datastructure ''' ''' constructor loads from a play datastructure '''
for x in ds.keys(): for x in ds.keys():
@ -64,6 +64,7 @@ class Play(object):
self.basedir = basedir self.basedir = basedir
self.roles = ds.get('roles', None) self.roles = ds.get('roles', None)
self.tags = ds.get('tags', None) self.tags = ds.get('tags', None)
self.vault_password = vault_password
if self.tags is None: if self.tags is None:
self.tags = [] self.tags = []
@ -88,6 +89,7 @@ class Play(object):
self.vars_files = ds.get('vars_files', []) self.vars_files = ds.get('vars_files', [])
if not isinstance(self.vars_files, list): if not isinstance(self.vars_files, list):
raise errors.AnsibleError('vars_files must be a list') raise errors.AnsibleError('vars_files must be a list')
self._update_vars_files_for_host(None) self._update_vars_files_for_host(None)
# template everything to be efficient, but do not pre-mature template # template everything to be efficient, but do not pre-mature template
@ -124,6 +126,7 @@ class Play(object):
self.max_fail_pct = int(ds.get('max_fail_percentage', 100)) self.max_fail_pct = int(ds.get('max_fail_percentage', 100))
self.su = ds.get('su', self.playbook.su) self.su = ds.get('su', self.playbook.su)
self.su_user = ds.get('su_user', self.playbook.su_user) self.su_user = ds.get('su_user', self.playbook.su_user)
#self.vault_password = vault_password
# Fail out if user specifies a sudo param with a su param in a given play # Fail out if user specifies a sudo param with a su param in a given play
if (ds.get('sudo') or ds.get('sudo_user')) and (ds.get('su') or ds.get('su_user')): if (ds.get('sudo') or ds.get('sudo_user')) and (ds.get('su') or ds.get('su_user')):
@ -197,7 +200,7 @@ class Play(object):
vars = self._resolve_main(utils.path_dwim(self.basedir, os.path.join(role_path, 'vars'))) vars = self._resolve_main(utils.path_dwim(self.basedir, os.path.join(role_path, 'vars')))
vars_data = {} vars_data = {}
if os.path.isfile(vars): if os.path.isfile(vars):
vars_data = utils.parse_yaml_from_file(vars) vars_data = utils.parse_yaml_from_file(vars, vault_password=self.vault_password)
if vars_data: if vars_data:
if not isinstance(vars_data, dict): if not isinstance(vars_data, dict):
raise errors.AnsibleError("vars from '%s' are not a dict" % vars) raise errors.AnsibleError("vars from '%s' are not a dict" % vars)
@ -205,12 +208,12 @@ class Play(object):
defaults = self._resolve_main(utils.path_dwim(self.basedir, os.path.join(role_path, 'defaults'))) defaults = self._resolve_main(utils.path_dwim(self.basedir, os.path.join(role_path, 'defaults')))
defaults_data = {} defaults_data = {}
if os.path.isfile(defaults): if os.path.isfile(defaults):
defaults_data = utils.parse_yaml_from_file(defaults) defaults_data = utils.parse_yaml_from_file(defaults, vault_password=self.vault_password)
# the meta directory contains the yaml that should # the meta directory contains the yaml that should
# hold the list of dependencies (if any) # hold the list of dependencies (if any)
meta = self._resolve_main(utils.path_dwim(self.basedir, os.path.join(role_path, 'meta'))) meta = self._resolve_main(utils.path_dwim(self.basedir, os.path.join(role_path, 'meta')))
if os.path.isfile(meta): if os.path.isfile(meta):
data = utils.parse_yaml_from_file(meta) data = utils.parse_yaml_from_file(meta, vault_password=self.vault_password)
if data: if data:
dependencies = data.get('dependencies',[]) dependencies = data.get('dependencies',[])
if dependencies is None: if dependencies is None:
@ -220,7 +223,7 @@ class Play(object):
(dep_path,dep_vars) = self._get_role_path(dep) (dep_path,dep_vars) = self._get_role_path(dep)
meta = self._resolve_main(utils.path_dwim(self.basedir, os.path.join(dep_path, 'meta'))) meta = self._resolve_main(utils.path_dwim(self.basedir, os.path.join(dep_path, 'meta')))
if os.path.isfile(meta): if os.path.isfile(meta):
meta_data = utils.parse_yaml_from_file(meta) meta_data = utils.parse_yaml_from_file(meta, vault_password=self.vault_password)
if meta_data: if meta_data:
allow_dupes = utils.boolean(meta_data.get('allow_duplicates','')) allow_dupes = utils.boolean(meta_data.get('allow_duplicates',''))
@ -241,13 +244,13 @@ class Play(object):
vars = self._resolve_main(utils.path_dwim(self.basedir, os.path.join(dep_path, 'vars'))) vars = self._resolve_main(utils.path_dwim(self.basedir, os.path.join(dep_path, 'vars')))
vars_data = {} vars_data = {}
if os.path.isfile(vars): if os.path.isfile(vars):
vars_data = utils.parse_yaml_from_file(vars) vars_data = utils.parse_yaml_from_file(vars, vault_password=self.vault_password)
if vars_data: if vars_data:
dep_vars = utils.combine_vars(vars_data, dep_vars) dep_vars = utils.combine_vars(vars_data, dep_vars)
defaults = self._resolve_main(utils.path_dwim(self.basedir, os.path.join(dep_path, 'defaults'))) defaults = self._resolve_main(utils.path_dwim(self.basedir, os.path.join(dep_path, 'defaults')))
dep_defaults_data = {} dep_defaults_data = {}
if os.path.isfile(defaults): if os.path.isfile(defaults):
dep_defaults_data = utils.parse_yaml_from_file(defaults) dep_defaults_data = utils.parse_yaml_from_file(defaults, vault_password=self.vault_password)
if 'role' in dep_vars: if 'role' in dep_vars:
del dep_vars['role'] del dep_vars['role']
@ -302,7 +305,7 @@ class Play(object):
default_vars = {} default_vars = {}
for filename in defaults_files: for filename in defaults_files:
if os.path.exists(filename): if os.path.exists(filename):
new_default_vars = utils.parse_yaml_from_file(filename) new_default_vars = utils.parse_yaml_from_file(filename, vault_password=self.vault_password)
if new_default_vars: if new_default_vars:
if type(new_default_vars) != dict: if type(new_default_vars) != dict:
raise errors.AnsibleError("%s must be stored as dictionary/hash: %s" % (filename, type(new_default_vars))) raise errors.AnsibleError("%s must be stored as dictionary/hash: %s" % (filename, type(new_default_vars)))
@ -540,7 +543,7 @@ class Play(object):
dirname = os.path.dirname(original_file) dirname = os.path.dirname(original_file)
include_file = template(dirname, tokens[0], mv) include_file = template(dirname, tokens[0], mv)
include_filename = utils.path_dwim(dirname, include_file) include_filename = utils.path_dwim(dirname, include_file)
data = utils.parse_yaml_from_file(include_filename) data = utils.parse_yaml_from_file(include_filename, vault_password=self.vault_password)
if 'role_name' in x and data is not None: if 'role_name' in x and data is not None:
for x in data: for x in data:
if 'include' in x: if 'include' in x:
@ -652,12 +655,12 @@ class Play(object):
# ************************************************* # *************************************************
def update_vars_files(self, hosts): def update_vars_files(self, hosts, vault_password=None):
''' calculate vars_files, which requires that setup runs first so ansible facts can be mixed in ''' ''' calculate vars_files, which requires that setup runs first so ansible facts can be mixed in '''
# now loop through all the hosts... # now loop through all the hosts...
for h in hosts: for h in hosts:
self._update_vars_files_for_host(h) self._update_vars_files_for_host(h, vault_password=vault_password)
# ************************************************* # *************************************************
@ -689,14 +692,14 @@ class Play(object):
# ************************************************* # *************************************************
def _update_vars_files_for_host(self, host): def _update_vars_files_for_host(self, host, vault_password=None):
if type(self.vars_files) != list: if type(self.vars_files) != list:
self.vars_files = [ self.vars_files ] self.vars_files = [ self.vars_files ]
if host is not None: if host is not None:
inject = {} inject = {}
inject.update(self.playbook.inventory.get_variables(host)) inject.update(self.playbook.inventory.get_variables(host, vault_password=vault_password))
inject.update(self.playbook.SETUP_CACHE[host]) inject.update(self.playbook.SETUP_CACHE[host])
for filename in self.vars_files: for filename in self.vars_files:
@ -715,7 +718,7 @@ class Play(object):
sequence.append(filename4) sequence.append(filename4)
if os.path.exists(filename4): if os.path.exists(filename4):
found = True found = True
data = utils.parse_yaml_from_file(filename4) data = utils.parse_yaml_from_file(filename4, vault_password=self.vault_password)
if type(data) != dict: if type(data) != dict:
raise errors.AnsibleError("%s must be stored as a dictionary/hash" % filename4) raise errors.AnsibleError("%s must be stored as a dictionary/hash" % filename4)
if host is not None: if host is not None:
@ -747,7 +750,7 @@ class Play(object):
filename4 = utils.path_dwim(self.basedir, filename3) filename4 = utils.path_dwim(self.basedir, filename3)
if self._has_vars_in(filename4): if self._has_vars_in(filename4):
continue continue
new_vars = utils.parse_yaml_from_file(filename4) new_vars = utils.parse_yaml_from_file(filename4, vault_password=self.vault_password)
if new_vars: if new_vars:
if type(new_vars) != dict: if type(new_vars) != dict:
raise errors.AnsibleError("%s must be stored as dictionary/hash: %s" % (filename4, type(new_vars))) raise errors.AnsibleError("%s must be stored as dictionary/hash: %s" % (filename4, type(new_vars)))

@ -144,6 +144,7 @@ class Runner(object):
su=False, # Are we running our command via su? su=False, # Are we running our command via su?
su_user=None, # User to su to when running command, ex: 'root' su_user=None, # User to su to when running command, ex: 'root'
su_pass=C.DEFAULT_SU_PASS, su_pass=C.DEFAULT_SU_PASS,
vault_pass=None,
run_hosts=None, # an optional list of pre-calculated hosts to run on run_hosts=None, # an optional list of pre-calculated hosts to run on
no_log=False, # option to enable/disable logging for a given task no_log=False, # option to enable/disable logging for a given task
): ):
@ -197,6 +198,7 @@ class Runner(object):
self.su_user_var = su_user self.su_user_var = su_user
self.su_user = None self.su_user = None
self.su_pass = su_pass self.su_pass = su_pass
self.vault_pass = vault_pass
self.no_log = no_log self.no_log = no_log
if self.transport == 'smart': if self.transport == 'smart':
@ -534,7 +536,7 @@ class Runner(object):
def _executor_internal(self, host, new_stdin): def _executor_internal(self, host, new_stdin):
''' executes any module one or more times ''' ''' executes any module one or more times '''
host_variables = self.inventory.get_variables(host) host_variables = self.inventory.get_variables(host, vault_password=self.vault_pass)
host_connection = host_variables.get('ansible_connection', self.transport) host_connection = host_variables.get('ansible_connection', self.transport)
if host_connection in [ 'paramiko', 'paramiko_alt', 'ssh', 'ssh_old', 'accelerate' ]: if host_connection in [ 'paramiko', 'paramiko_alt', 'ssh', 'ssh_old', 'accelerate' ]:
port = host_variables.get('ansible_ssh_port', self.remote_port) port = host_variables.get('ansible_ssh_port', self.remote_port)

@ -43,7 +43,7 @@ class ActionModule(object):
source = utils.path_dwim(self.runner.basedir, source) source = utils.path_dwim(self.runner.basedir, source)
if os.path.exists(source): if os.path.exists(source):
data = utils.parse_yaml_from_file(source) data = utils.parse_yaml_from_file(source, vault_password=self.runner.vault_pass)
if type(data) != dict: if type(data) != dict:
raise errors.AnsibleError("%s must be stored as a dictionary/hash" % source) raise errors.AnsibleError("%s must be stored as a dictionary/hash" % source)
result = dict(ansible_facts=data) result = dict(ansible_facts=data)

@ -43,6 +43,8 @@ import getpass
import sys import sys
import textwrap import textwrap
import vault
VERBOSITY=0 VERBOSITY=0
# list of all deprecation messages to prevent duplicate display # list of all deprecation messages to prevent duplicate display
@ -494,14 +496,22 @@ Should be written as:
raise errors.AnsibleYAMLValidationFailed(msg) raise errors.AnsibleYAMLValidationFailed(msg)
def parse_yaml_from_file(path): def parse_yaml_from_file(path, vault_password=None):
''' convert a yaml file to a data structure ''' ''' convert a yaml file to a data structure '''
data = None
#VAULT
if vault.is_encrypted(path):
data = vault.decrypt(path, vault_password)
else:
try:
data = open(path).read()
except IOError:
raise errors.AnsibleError("file could not read: %s" % path)
try: try:
data = file(path).read()
return parse_yaml(data) return parse_yaml(data)
except IOError:
raise errors.AnsibleError("file not found: %s" % path)
except yaml.YAMLError, exc: except yaml.YAMLError, exc:
process_yaml_error(exc, data, path) process_yaml_error(exc, data, path)
@ -693,6 +703,8 @@ def base_parser(constants=C, usage="", output_opts=False, runas_opts=False,
help='ask for sudo password') help='ask for sudo password')
parser.add_option('--ask-su-pass', default=False, dest='ask_su_pass', parser.add_option('--ask-su-pass', default=False, dest='ask_su_pass',
action='store_true', help='ask for su password') action='store_true', help='ask for su password')
parser.add_option('--ask-vault-pass', default=False, dest='ask_vault_pass',
action='store_true', help='ask for vault password')
parser.add_option('--list-hosts', dest='listhosts', action='store_true', parser.add_option('--list-hosts', dest='listhosts', action='store_true',
help='outputs a list of matching hosts; does not execute anything else') help='outputs a list of matching hosts; does not execute anything else')
parser.add_option('-M', '--module-path', dest='module_path', parser.add_option('-M', '--module-path', dest='module_path',
@ -751,10 +763,34 @@ def base_parser(constants=C, usage="", output_opts=False, runas_opts=False,
return parser return parser
def ask_passwords(ask_pass=False, ask_sudo_pass=False, ask_su_pass=False): def ask_vault_passwords(ask_vault_pass=False, ask_new_vault_pass=False, confirm_vault=False, confirm_new=False):
vault_pass = None
new_vault_pass = None
if ask_vault_pass:
vault_pass = getpass.getpass(prompt="Vault password: ")
if ask_vault_pass and confirm_vault:
vault_pass2 = getpass.getpass(prompt="Confirm Vault password: ")
if vault_pass != vault_pass2:
raise errors.AnsibleError("Passwords do not match")
if ask_new_vault_pass:
new_vault_pass = getpass.getpass(prompt="New Vault password: ")
if ask_new_vault_pass and confirm_new:
new_vault_pass2 = getpass.getpass(prompt="Confirm New Vault password: ")
if new_vault_pass != new_vault_pass2:
raise errors.AnsibleError("Passwords do not match")
return vault_pass, new_vault_pass
def ask_passwords(ask_pass=False, ask_sudo_pass=False, ask_su_pass=False, ask_vault_pass=False):
sshpass = None sshpass = None
sudopass = None sudopass = None
su_pass = None su_pass = None
vault_pass = None
sudo_prompt = "sudo password: " sudo_prompt = "sudo password: "
su_prompt = "su password: " su_prompt = "su password: "
@ -770,7 +806,10 @@ def ask_passwords(ask_pass=False, ask_sudo_pass=False, ask_su_pass=False):
if ask_su_pass: if ask_su_pass:
su_pass = getpass.getpass(prompt=su_prompt) su_pass = getpass.getpass(prompt=su_prompt)
return (sshpass, sudopass, su_pass) if ask_vault_pass:
vault_pass = getpass.getpass(prompt="Vault password: ")
return (sshpass, sudopass, su_pass, vault_pass)
def do_encrypt(result, encrypt, salt_size=None, salt=None): def do_encrypt(result, encrypt, salt_size=None, salt=None):
if PASSLIB_AVAILABLE: if PASSLIB_AVAILABLE:

@ -0,0 +1,445 @@
# (c) 2014, James Tanner <tanner.jc@gmail.com>
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
#
# ansible-pull is a script that runs ansible in local mode
# after checking out a playbooks directory from source repo. There is an
# example playbook to bootstrap this script in the examples/ dir which
# installs ansible and sets it up to run on cron.
import os
import shutil
import tempfile
from io import BytesIO
from subprocess import call
from ansible import errors
from hashlib import sha256
from hashlib import md5
from binascii import hexlify
from binascii import unhexlify
from ansible import constants as C
# AES IMPORTS
try:
from Crypto.Cipher import AES as AES_
HAS_AES = True
except ImportError:
HAS_AES = False
HEADER='$ANSIBLE_VAULT'
def is_encrypted(filename):
'''
Check a file for the encrypted header and return True or False
The first line should start with the header
defined by the global HEADER. If true, we
assume this is a properly encrypted file.
'''
# read first line of the file
with open(filename) as f:
head = f.next()
if head.startswith(HEADER):
return True
else:
return False
def decrypt(filename, password):
'''
Return a decrypted string of the contents in an encrypted file
This is used by the yaml loading code in ansible
to automatically determine the encryption type
and return a plaintext string of the unencrypted
data.
'''
if password is None:
raise errors.AnsibleError("A vault password must be specified to decrypt %s" % filename)
V = Vault(filename=filename, vault_password=password)
return_data = V._decrypt_to_string()
if not V._verify_decryption(return_data):
raise errors.AnsibleError("Decryption of %s failed" % filename)
this_sha, return_data = V._strip_sha(return_data)
return return_data.strip()
class Vault(object):
def __init__(self, filename=None, cipher=None, vault_password=None):
self.filename = filename
self.vault_password = vault_password
self.cipher = cipher
self.version = '1.0'
###############
# PUBLIC
###############
def eval_header(self):
""" Read first line of the file and parse header """
# read first line
with open(self.filename) as f:
#head=[f.next() for x in xrange(1)]
head = f.next()
this_version = None
this_cipher = None
# split segments
if len(head.split(';')) == 3:
this_version = head.split(';')[1].strip()
this_cipher = head.split(';')[2].strip()
else:
raise errors.AnsibleError("%s has an invalid header" % self.filename)
# validate acceptable version
this_version = float(this_version)
if this_version < C.VAULT_VERSION_MIN or this_version > C.VAULT_VERSION_MAX:
raise errors.AnsibleError("%s must have a version between %s and %s " % (self.filename,
C.VAULT_VERSION_MIN,
C.VAULT_VERSION_MAX))
# set properties
self.cipher = this_cipher
self.version = this_version
def create(self):
""" create a new encrypted file """
if os.path.isfile(self.filename):
raise errors.AnsibleError("%s exists, please use 'edit' instead" % self.filename)
# drop the user into vim on file
EDITOR = os.environ.get('EDITOR','vim')
call([EDITOR, self.filename])
self.encrypt()
def decrypt(self):
""" unencrypt a file inplace """
if not is_encrypted(self.filename):
raise errors.AnsibleError("%s is not encrypted" % self.filename)
# set cipher based on file header
self.eval_header()
# decrypt it
data = self._decrypt_to_string()
# verify sha and then strip it out
if not self._verify_decryption(data):
raise errors.AnsibleError("decryption of %s failed" % self.filename)
this_sha, clean_data = self._strip_sha(data)
# write back to original file
f = open(self.filename, "wb")
f.write(clean_data)
f.close()
def edit(self, filename=None, password=None, cipher=None, version=None):
if not is_encrypted(self.filename):
raise errors.AnsibleError("%s is not encrypted" % self.filename)
#decrypt to string
data = self._decrypt_to_string()
# verify sha and then strip it out
if not self._verify_decryption(data):
raise errors.AnsibleError("decryption of %s failed" % self.filename)
this_sha, clean_data = self._strip_sha(data)
# rewrite file without sha
_, in_path = tempfile.mkstemp()
f = open(in_path, "wb")
tmpdata = f.write(clean_data)
f.close()
# drop the user into vim on the unencrypted tmp file
EDITOR = os.environ.get('EDITOR','vim')
call([EDITOR, in_path])
f = open(in_path, "rb")
tmpdata = f.read()
f.close()
self._string_to_encrypted_file(tmpdata, self.filename)
def encrypt(self):
""" encrypt a file inplace """
if is_encrypted(self.filename):
raise errors.AnsibleError("%s is already encrypted" % self.filename)
#self.eval_header()
self.__load_cipher()
# read data
f = open(self.filename, "rb")
tmpdata = f.read()
f.close()
self._string_to_encrypted_file(tmpdata, self.filename)
def rekey(self, newpassword):
""" unencrypt file then encrypt with new password """
if not is_encrypted(self.filename):
raise errors.AnsibleError("%s is not encrypted" % self.filename)
# unencrypt to string with old password
data = self._decrypt_to_string()
# verify sha and then strip it out
if not self._verify_decryption(data):
raise errors.AnsibleError("decryption of %s failed" % self.filename)
this_sha, clean_data = self._strip_sha(data)
# set password
self.vault_password = newpassword
self._string_to_encrypted_file(clean_data, self.filename)
###############
# PRIVATE
###############
def __load_cipher(self):
"""
Load a cipher class by it's name
This is a lightweight "plugin" implementation to allow
for future support of other cipher types
"""
whitelist = ['AES']
if self.cipher in whitelist:
self.cipher_obj = None
if self.cipher in globals():
this_cipher = globals()[self.cipher]
self.cipher_obj = this_cipher()
else:
raise errors.AnsibleError("%s cipher could not be loaded" % self.cipher)
else:
raise errors.AnsibleError("%s is not an allowed encryption cipher" % self.cipher)
def _decrypt_to_string(self):
""" decrypt file to string """
if not is_encrypted(self.filename):
raise errors.AnsibleError("%s is not encrypted" % self.filename)
# figure out what this is
self.eval_header()
self.__load_cipher()
# strip out header and unhex the file
clean_stream = self._dirty_file_to_clean_file(self.filename)
# reset pointer
clean_stream.seek(0)
# create a byte stream to hold unencrypted
dst = BytesIO()
# decrypt from src stream to dst stream
self.cipher_obj.decrypt(clean_stream, dst, self.vault_password)
# read data from the unencrypted stream
data = dst.read()
return data
def _dirty_file_to_clean_file(self, dirty_filename):
""" Strip out headers from a file, unhex and write to new file"""
_, in_path = tempfile.mkstemp()
#_, out_path = tempfile.mkstemp()
# strip header from data, write rest to tmp file
f = open(dirty_filename, "rb")
tmpdata = f.readlines()
f.close()
tmpheader = tmpdata[0].strip()
tmpdata = ''.join(tmpdata[1:])
# strip out newline, join, unhex
tmpdata = [ x.strip() for x in tmpdata ]
tmpdata = unhexlify(''.join(tmpdata))
# create and return stream
clean_stream = BytesIO(tmpdata)
return clean_stream
def _clean_stream_to_dirty_stream(self, clean_stream):
# combine header and hexlified encrypted data in 80 char columns
clean_stream.seek(0)
tmpdata = clean_stream.read()
tmpdata = hexlify(tmpdata)
tmpdata = [tmpdata[i:i+80] for i in range(0, len(tmpdata), 80)]
dirty_data = HEADER + ";" + str(self.version) + ";" + self.cipher + "\n"
for l in tmpdata:
dirty_data += l + '\n'
dirty_stream = BytesIO(dirty_data)
return dirty_stream
def _string_to_encrypted_file(self, tmpdata, filename):
""" Write a string of data to a file with the format ...
HEADER;VERSION;CIPHER
HEX(ENCRYPTED(SHA256(STRING)+STRING))
"""
# sha256 the data
this_sha = sha256(tmpdata).hexdigest()
# combine sha + data to tmpfile
tmpdata = this_sha + "\n" + tmpdata
src_stream = BytesIO(tmpdata)
dst_stream = BytesIO()
# encrypt tmpfile
self.cipher_obj.encrypt(src_stream, dst_stream, self.password)
# hexlify tmpfile and combine with header
dirty_stream = self._clean_stream_to_dirty_stream(dst_stream)
if os.path.isfile(filename):
os.remove(filename)
# write back to original file
dirty_stream.seek(0)
f = open(filename, "wb")
f.write(dirty_stream.read())
f.close()
def _verify_decryption(self, data):
""" Split data to sha/data and check the sha """
# split the sha and other data
this_sha, clean_data = self._strip_sha(data)
# does the decrypted data match the sha ?
clean_sha = sha256(clean_data).hexdigest()
# compare, return result
if this_sha == clean_sha:
return True
else:
return False
def _strip_sha(self, data):
# is the first line a sha?
lines = data.split("\n")
this_sha = lines[0]
clean_data = '\n'.join(lines[1:])
return this_sha, clean_data
class AES(object):
# http://stackoverflow.com/a/16761459
def __init__(self):
if not HAS_AES:
raise errors.AnsibleError("pycrypto is not installed. Fix this with your package manager, for instance, yum-install python-crypto OR (apt equivalent)")
def aes_derive_key_and_iv(self, password, salt, key_length, iv_length):
""" Create a key and an initialization vector """
d = d_i = ''
while len(d) < key_length + iv_length:
d_i = md5(d_i + password + salt).digest()
d += d_i
key = d[:key_length]
iv = d[key_length:key_length+iv_length]
return key, iv
def encrypt(self, in_file, out_file, password, key_length=32):
""" Read plaintext data from in_file and write encrypted to out_file """
bs = AES_.block_size
# Get a block of random data. EL does not have Crypto.Random.new()
# so os.urandom is used for cross platform purposes
print "WARNING: if encryption hangs, add more entropy (suggest using mouse inputs)"
salt = os.urandom(bs - len('Salted__'))
key, iv = self.aes_derive_key_and_iv(password, salt, key_length, bs)
cipher = AES_.new(key, AES_.MODE_CBC, iv)
out_file.write('Salted__' + salt)
finished = False
while not finished:
chunk = in_file.read(1024 * bs)
if len(chunk) == 0 or len(chunk) % bs != 0:
padding_length = (bs - len(chunk) % bs) or bs
chunk += padding_length * chr(padding_length)
finished = True
out_file.write(cipher.encrypt(chunk))
def decrypt(self, in_file, out_file, password, key_length=32):
""" Read encrypted data from in_file and write decrypted to out_file """
# http://stackoverflow.com/a/14989032
bs = AES_.block_size
salt = in_file.read(bs)[len('Salted__'):]
key, iv = self.aes_derive_key_and_iv(password, salt, key_length, bs)
cipher = AES_.new(key, AES_.MODE_CBC, iv)
next_chunk = ''
finished = False
while not finished:
chunk, next_chunk = next_chunk, cipher.decrypt(in_file.read(1024 * bs))
if len(next_chunk) == 0:
padding_length = ord(chunk[-1])
chunk = chunk[:-padding_length]
finished = True
out_file.write(chunk)
# reset the stream pointer to the beginning
if hasattr(out_file, 'seek'):
out_file.seek(0)
Loading…
Cancel
Save