Tweaked merge_hash to also affect Runner behavior

pull/2583/head
George Miroshnykov 12 years ago
parent 94d189bc7f
commit 6826aa7360

@ -49,11 +49,7 @@ class VarsModule(object):
data = utils.parse_yaml_from_file(path) data = utils.parse_yaml_from_file(path)
if type(data) != dict: if type(data) != dict:
raise errors.AnsibleError("%s must be stored as a dictionary/hash" % path) raise errors.AnsibleError("%s must be stored as a dictionary/hash" % path)
if C.DEFAULT_HASH_BEHAVIOUR == "merge": results = utils.combine_vars(results, data);
# let data content override results if needed
results = utils.merge_hash(results, data)
else:
results.update(data)
# load vars in inventory_dir/hosts_vars/name_of_host # load vars in inventory_dir/hosts_vars/name_of_host
path = os.path.join(basedir, "host_vars/%s" % host.name) path = os.path.join(basedir, "host_vars/%s" % host.name)
@ -61,10 +57,6 @@ class VarsModule(object):
data = utils.parse_yaml_from_file(path) data = utils.parse_yaml_from_file(path)
if type(data) != dict: if type(data) != dict:
raise errors.AnsibleError("%s must be stored as a dictionary/hash" % path) raise errors.AnsibleError("%s must be stored as a dictionary/hash" % path)
if C.DEFAULT_HASH_BEHAVIOUR == "merge": results = utils.combine_vars(results, data);
# let data content override results if needed
results = utils.merge_hash(results, data)
else:
results.update(data)
return results return results

@ -330,8 +330,9 @@ class Play(object):
if host is not None and self._has_vars_in(filename2) and not self._has_vars_in(filename3): if host is not None and self._has_vars_in(filename2) and not self._has_vars_in(filename3):
# running a host specific pass and has host specific variables # running a host specific pass and has host specific variables
# load into setup cache # load into setup cache
self.playbook.SETUP_CACHE[host].update(new_vars) self.playbook.SETUP_CACHE[host] = utils.combine_vars(
self.playbook.SETUP_CACHE[host], new_vars)
self.playbook.callbacks.on_import_for_host(host, filename4) self.playbook.callbacks.on_import_for_host(host, filename4)
elif host is None: elif host is None:
# running a non-host specific pass and we can update the global vars instead # running a non-host specific pass and we can update the global vars instead
self.vars.update(new_vars) self.vars = utils.combine_vars(self.vars, new_vars)

@ -175,7 +175,7 @@ class Runner(object):
# ensure we are using unique tmp paths # ensure we are using unique tmp paths
random.seed() random.seed()
# ***************************************************** # *****************************************************
def _complex_args_hack(self, complex_args, module_args): def _complex_args_hack(self, complex_args, module_args):
@ -333,9 +333,9 @@ class Runner(object):
port = self.remote_port port = self.remote_port
inject = {} inject = {}
inject.update(host_variables) inject = utils.combine_vars(inject, host_variables)
inject.update(self.module_vars) inject = utils.combine_vars(inject, self.module_vars)
inject.update(self.setup_cache[host]) inject = utils.combine_vars(inject, self.setup_cache[host])
inject['hostvars'] = HostVars(self.setup_cache, self.inventory) inject['hostvars'] = HostVars(self.setup_cache, self.inventory)
inject['group_names'] = host_variables.get('group_names', []) inject['group_names'] = host_variables.get('group_names', [])
inject['groups'] = self.inventory.groups_list() inject['groups'] = self.inventory.groups_list()
@ -492,7 +492,7 @@ class Runner(object):
# all modules get a tempdir, action plugins get one unless they have NEEDS_TMPPATH set to False # all modules get a tempdir, action plugins get one unless they have NEEDS_TMPPATH set to False
if getattr(handler, 'NEEDS_TMPPATH', True): if getattr(handler, 'NEEDS_TMPPATH', True):
tmp = self._make_tmp_path(conn) tmp = self._make_tmp_path(conn)
result = handler.run(conn, tmp, module_name, module_args, inject, complex_args) result = handler.run(conn, tmp, module_name, module_args, inject, complex_args)
conn.close() conn.close()
@ -625,8 +625,8 @@ class Runner(object):
module_data = f.read() module_data = f.read()
if module_common.REPLACER in module_data: if module_common.REPLACER in module_data:
is_new_style=True is_new_style=True
complex_args_json = utils.jsonify(complex_args) complex_args_json = utils.jsonify(complex_args)
encoded_args = "\"\"\"%s\"\"\"" % module_args.replace("\"","\\\"") encoded_args = "\"\"\"%s\"\"\"" % module_args.replace("\"","\\\"")
encoded_lang = "\"\"\"%s\"\"\"" % C.DEFAULT_MODULE_LANG encoded_lang = "\"\"\"%s\"\"\"" % C.DEFAULT_MODULE_LANG
encoded_complex = "\"\"\"%s\"\"\"" % complex_args_json.replace("\\", "\\\\") encoded_complex = "\"\"\"%s\"\"\"" % complex_args_json.replace("\\", "\\\\")
@ -635,7 +635,7 @@ class Runner(object):
module_data = module_data.replace(module_common.REPLACER_ARGS, encoded_args) module_data = module_data.replace(module_common.REPLACER_ARGS, encoded_args)
module_data = module_data.replace(module_common.REPLACER_LANG, encoded_lang) module_data = module_data.replace(module_common.REPLACER_LANG, encoded_lang)
module_data = module_data.replace(module_common.REPLACER_COMPLEX, encoded_complex) module_data = module_data.replace(module_common.REPLACER_COMPLEX, encoded_complex)
if is_new_style: if is_new_style:
facility = C.DEFAULT_SYSLOG_FACILITY facility = C.DEFAULT_SYSLOG_FACILITY
if 'ansible_syslog_facility' in inject: if 'ansible_syslog_facility' in inject:
@ -737,7 +737,7 @@ class Runner(object):
# run once per hostgroup, rather than pausing once per each # run once per hostgroup, rather than pausing once per each
# host. # host.
p = utils.plugins.action_loader.get(self.module_name, self) p = utils.plugins.action_loader.get(self.module_name, self)
if p and getattr(p, 'BYPASS_HOST_LOOP', None): if p and getattr(p, 'BYPASS_HOST_LOOP', None):
# Expose the current hostgroup to the bypassing plugins # Expose the current hostgroup to the bypassing plugins

@ -306,7 +306,7 @@ def merge_hash(a, b):
for k, v in b.iteritems(): for k, v in b.iteritems():
if k in a and isinstance(a[k], dict): if k in a and isinstance(a[k], dict):
# if this key is a hash and exists in a # if this key is a hash and exists in a
# we recursively call ourselves with # we recursively call ourselves with
# the key value of b # the key value of b
a[k] = merge_hash(a[k], v) a[k] = merge_hash(a[k], v)
else: else:
@ -663,8 +663,13 @@ def get_diff(diff):
return ">> the files are different, but the diff library cannot compare unicode strings" return ">> the files are different, but the diff library cannot compare unicode strings"
def is_list_of_strings(items): def is_list_of_strings(items):
for x in items: for x in items:
if not isinstance(x, basestring): if not isinstance(x, basestring):
return False return False
return True return True
def combine_vars(a, b):
if C.DEFAULT_HASH_BEHAVIOUR == "merge":
return merge_hash(a, b)
else:
return dict(a.items() + b.items())

@ -10,6 +10,7 @@ import ansible.utils as utils
import ansible.callbacks as ans_callbacks import ansible.callbacks as ans_callbacks
import os import os
import shutil import shutil
import ansible.constants as C
EVENTS = [] EVENTS = []
@ -93,6 +94,8 @@ class TestPlaybook(unittest.TestCase):
os.unlink('/tmp/ansible_test_data_copy.out') os.unlink('/tmp/ansible_test_data_copy.out')
if os.path.exists('/tmp/ansible_test_data_template.out'): if os.path.exists('/tmp/ansible_test_data_template.out'):
os.unlink('/tmp/ansible_test_data_template.out') os.unlink('/tmp/ansible_test_data_template.out')
if os.path.exists('/tmp/ansible_test_messages.out'):
os.unlink('/tmp/ansible_test_messages.out')
def _prepare_stage_dir(self): def _prepare_stage_dir(self):
stage_path = os.path.join(self.test_dir, 'test_data') stage_path = os.path.join(self.test_dir, 'test_data')
@ -236,3 +239,69 @@ class TestPlaybook(unittest.TestCase):
play = ansible.playbook.Play(playbook, playbook.playbook[0], os.getcwd()) play = ansible.playbook.Play(playbook, playbook.playbook[0], os.getcwd())
assert play.hosts == ';'.join(('host1', 'host2', 'host3')) assert play.hosts == ';'.join(('host1', 'host2', 'host3'))
def test_playbook_hash_replace(self):
# save default hash behavior so we can restore it in the end of the test
saved_hash_behavior = C.DEFAULT_HASH_BEHAVIOUR
C.DEFAULT_HASH_BEHAVIOUR = "replace"
test_callbacks = TestCallbacks()
playbook = ansible.playbook.PlayBook(
playbook=os.path.join(self.test_dir, 'test_hash_behavior', 'playbook.yml'),
host_list='test/ansible_hosts',
stats=ans_callbacks.AggregateStats(),
callbacks=test_callbacks,
runner_callbacks=test_callbacks
)
playbook.run()
with open('/tmp/ansible_test_messages.out') as f:
actual = [l.strip() for l in f.readlines()]
print "**ACTUAL**"
print actual
expected = [
"goodbye: Goodbye World!"
]
print "**EXPECTED**"
print expected
assert actual == expected
# restore default hash behavior
C.DEFAULT_HASH_BEHAVIOUR = saved_hash_behavior
def test_playbook_hash_merge(self):
# save default hash behavior so we can restore it in the end of the test
saved_hash_behavior = C.DEFAULT_HASH_BEHAVIOUR
C.DEFAULT_HASH_BEHAVIOUR = "merge"
test_callbacks = TestCallbacks()
playbook = ansible.playbook.PlayBook(
playbook=os.path.join(self.test_dir, 'test_hash_behavior', 'playbook.yml'),
host_list='test/ansible_hosts',
stats=ans_callbacks.AggregateStats(),
callbacks=test_callbacks,
runner_callbacks=test_callbacks
)
playbook.run()
with open('/tmp/ansible_test_messages.out') as f:
actual = [l.strip() for l in f.readlines()]
print "**ACTUAL**"
print actual
expected = [
"hello: Hello World!",
"goodbye: Goodbye World!"
]
print "**EXPECTED**"
print expected
assert actual == expected
# restore default hash behavior
C.DEFAULT_HASH_BEHAVIOUR = saved_hash_behavior

@ -0,0 +1,3 @@
---
messages:
goodbye: "Goodbye World!"

@ -0,0 +1,3 @@
---
messages:
hello: "Hello World!"

@ -0,0 +1,3 @@
{% for k, v in messages.iteritems() %}
{{ k }}: {{ v }}
{% endfor %}

@ -0,0 +1,11 @@
---
- hosts: all
connection: local
vars_files:
- hello.yml
- goodbye.yml
tasks:
- name: generate messages
action: template src=message.j2 dest=/tmp/ansible_test_messages.out
Loading…
Cancel
Save