diff --git a/lib/ansible/inventory/manager.py b/lib/ansible/inventory/manager.py index 32ec5ff3ea0..7c328668d11 100644 --- a/lib/ansible/inventory/manager.py +++ b/lib/ansible/inventory/manager.py @@ -48,6 +48,18 @@ IGNORED_EXTS = [b'%s$' % to_bytes(re.escape(x)) for x in C.INVENTORY_IGNORE_EXTS IGNORED = re.compile(b'|'.join(IGNORED_ALWAYS + IGNORED_PATTERNS + IGNORED_EXTS)) +PATTERN_WITH_SUBSCRIPT = re.compile( + r'''^ + (.+) # A pattern expression ending with... + \[(?: # A [subscript] expression comprising: + (-?[0-9]+)| # A single positive or negative number + ([0-9]+)([:-]) # Or an x:y or x: range. + ([0-9]*) + )\] + $ + ''', re.X +) + def order_patterns(patterns): ''' takes a list of patterns and reorders them by modifier to apply them consistently ''' @@ -57,9 +69,9 @@ def order_patterns(patterns): pattern_intersection = [] pattern_exclude = [] for p in patterns: - if p.startswith("!"): + if p[0] == "!": pattern_exclude.append(p) - elif p.startswith("&"): + elif p[0] == "&": pattern_intersection.append(p) elif p: pattern_regular.append(p) @@ -316,7 +328,7 @@ class InventoryManager(object): def _match_list(self, items, pattern_str): # compile patterns try: - if not pattern_str.startswith('~'): + if not pattern_str[0] == '~': pattern = re.compile(fnmatch.translate(pattern_str)) else: pattern = re.compile(pattern_str[1:]) @@ -341,41 +353,45 @@ class InventoryManager(object): # Check if pattern already computed if isinstance(pattern, list): - pattern_hash = u":".join(pattern) + pattern_list = pattern[:] else: - pattern_hash = pattern + pattern_list = [pattern] - if pattern_hash: + if pattern_list: if not ignore_limits and self._subset: - pattern_hash += u":%s" % to_text(self._subset, errors='surrogate_or_strict') + pattern_list.extend(self._subset) if not ignore_restrictions and self._restriction: - pattern_hash += u":%s" % to_text(self._restriction, errors='surrogate_or_strict') + pattern_list.extend(self._restriction) + + # This is only used as a hash key in the self._hosts_patterns_cache dict + # a tuple is faster than stringifying + pattern_hash = tuple(pattern_list) if pattern_hash not in self._hosts_patterns_cache: patterns = split_host_pattern(pattern) - hosts = self._evaluate_patterns(patterns) + hosts[:] = self._evaluate_patterns(patterns) # mainly useful for hostvars[host] access if not ignore_limits and self._subset: # exclude hosts not in a subset, if defined - subset_uuids = [s._uuid for s in self._evaluate_patterns(self._subset)] - hosts = [h for h in hosts if h._uuid in subset_uuids] + subset_uuids = set(s._uuid for s in self._evaluate_patterns(self._subset)) + hosts[:] = [h for h in hosts if h._uuid in subset_uuids] if not ignore_restrictions and self._restriction: # exclude hosts mentioned in any restriction (ex: failed hosts) - hosts = [h for h in hosts if h.name in self._restriction] + hosts[:] = [h for h in hosts if h.name in self._restriction] self._hosts_patterns_cache[pattern_hash] = deduplicate_list(hosts) # sort hosts list if needed (should only happen when called from strategy) if order in ['sorted', 'reverse_sorted']: - hosts = sorted(self._hosts_patterns_cache[pattern_hash][:], key=attrgetter('name'), reverse=(order == 'reverse_sorted')) + hosts[:] = sorted(self._hosts_patterns_cache[pattern_hash][:], key=attrgetter('name'), reverse=(order == 'reverse_sorted')) elif order == 'reverse_inventory': - hosts = self._hosts_patterns_cache[pattern_hash][::-1] + hosts[:] = self._hosts_patterns_cache[pattern_hash][::-1] else: - hosts = self._hosts_patterns_cache[pattern_hash][:] + hosts[:] = self._hosts_patterns_cache[pattern_hash][:] if order == 'shuffle': shuffle(hosts) elif order not in [None, 'inventory']: @@ -398,12 +414,15 @@ class InventoryManager(object): hosts.append(self._inventory.get_host(p)) else: that = self._match_one_pattern(p) - if p.startswith("!"): - hosts = [h for h in hosts if h not in frozenset(that)] - elif p.startswith("&"): - hosts = [h for h in hosts if h in frozenset(that)] + if p[0] == "!": + that = set(that) + hosts = [h for h in hosts if h not in that] + elif p[0] == "&": + that = set(that) + hosts = [h for h in hosts if h in that] else: - hosts.extend([h for h in that if h.name not in frozenset([y.name for y in hosts])]) + existing_hosts = set(y.name for y in hosts) + hosts.extend([h for h in that if h.name not in existing_hosts]) return hosts def _match_one_pattern(self, pattern): @@ -444,7 +463,7 @@ class InventoryManager(object): Duplicate matches are always eliminated from the results. """ - if pattern.startswith("&") or pattern.startswith("!"): + if pattern[0] in ("&", "!"): pattern = pattern[1:] if pattern not in self._pattern_cache: @@ -469,27 +488,15 @@ class InventoryManager(object): """ # Do not parse regexes for enumeration info - if pattern.startswith('~'): + if pattern[0] == '~': return (pattern, None) # We want a pattern followed by an integer or range subscript. # (We can't be more restrictive about the expression because the # fnmatch semantics permit [\[:\]] to occur.) - pattern_with_subscript = re.compile( - r'''^ - (.+) # A pattern expression ending with... - \[(?: # A [subscript] expression comprising: - (-?[0-9]+)| # A single positive or negative number - ([0-9]+)([:-]) # Or an x:y or x: range. - ([0-9]*) - )\] - $ - ''', re.X - ) - subscript = None - m = pattern_with_subscript.match(pattern) + m = PATTERN_WITH_SUBSCRIPT.match(pattern) if m: (pattern, idx, start, sep, end) = m.groups() if idx: @@ -535,7 +542,7 @@ class InventoryManager(object): results.extend(self._inventory.groups[groupname].get_hosts()) # check hosts if no groups matched or it is a regex/glob pattern - if not matching_groups or pattern.startswith('~') or any(special in pattern for special in ('.', '?', '*', '[')): + if not matching_groups or pattern[0] == '~' or any(special in pattern for special in ('.', '?', '*', '[')): # pattern might match host matching_hosts = self._match_list(self._inventory.hosts, pattern) if matching_hosts: @@ -585,7 +592,7 @@ class InventoryManager(object): return elif not isinstance(restriction, list): restriction = [restriction] - self._restriction = [h.name for h in restriction] + self._restriction = set(to_text(h.name) for h in restriction) def subset(self, subset_pattern): """ @@ -601,12 +608,12 @@ class InventoryManager(object): results = [] # allow Unix style @filename data for x in subset_patterns: - if x.startswith("@"): + if x[0] == "@": fd = open(x[1:]) - results.extend([l.strip() for l in fd.read().split("\n")]) + results.extend([to_text(l.strip()) for l in fd.read().split("\n")]) fd.close() else: - results.append(x) + results.append(to_text(x)) self._subset = results def remove_restriction(self): diff --git a/lib/ansible/plugins/callback/counter_enabled.py b/lib/ansible/plugins/callback/counter_enabled.py index a173fe87b69..1c6e6697881 100644 --- a/lib/ansible/plugins/callback/counter_enabled.py +++ b/lib/ansible/plugins/callback/counter_enabled.py @@ -10,7 +10,6 @@ from ansible import constants as C from ansible.plugins.callback import CallbackBase from ansible.utils.color import colorize, hostcolor from ansible.template import Templar -from ansible.plugins.strategy import SharedPluginLoaderObj from ansible.playbook.task_include import TaskInclude DOCUMENTATION = ''' diff --git a/lib/ansible/plugins/strategy/__init__.py b/lib/ansible/plugins/strategy/__init__.py index 8106b0adcfc..e4fe388d2c1 100644 --- a/lib/ansible/plugins/strategy/__init__.py +++ b/lib/ansible/plugins/strategy/__init__.py @@ -45,7 +45,7 @@ from ansible.module_utils.connection import Connection, ConnectionError from ansible.playbook.helpers import load_list_of_blocks from ansible.playbook.included_file import IncludedFile from ansible.playbook.task_include import TaskInclude -from ansible.plugins.loader import action_loader, connection_loader, filter_loader, lookup_loader, module_loader, test_loader +from ansible.plugins import loader as plugin_loader from ansible.template import Templar from ansible.utils.display import Display from ansible.utils.vars import combine_vars @@ -60,21 +60,12 @@ class StrategySentinel: pass -# TODO: this should probably be in the plugins/__init__.py, with -# a smarter mechanism to set all of the attributes based on -# the loaders created there -class SharedPluginLoaderObj: +def SharedPluginLoaderObj(): + '''This only exists for backwards compat, do not use. ''' - A simple object to make pass the various plugin loaders to - the forked processes over the queue easier - ''' - def __init__(self): - self.action_loader = action_loader - self.connection_loader = connection_loader - self.filter_loader = filter_loader - self.test_loader = test_loader - self.lookup_loader = lookup_loader - self.module_loader = module_loader + display.deprecated('SharedPluginLoaderObj is deprecated, please directly use ansible.plugins.loader', + version='2.11') + return plugin_loader _sentinel = StrategySentinel() @@ -207,8 +198,29 @@ class StrategyBase: # play completion self._active_connections = dict() + # Caches for get_host calls, to avoid calling excessively + # These values should be set at the top of the ``run`` method of each + # strategy plugin. Use ``_set_hosts_cache`` to set these values + self._hosts_cache = [] + self._hosts_cache_all = [] + self.debugger_active = C.ENABLE_TASK_DEBUGGER + def _set_hosts_cache(self, play, refresh=True): + """Responsible for setting _hosts_cache and _hosts_cache_all + + See comment in ``__init__`` for the purpose of these caches + """ + if not refresh and all((self._hosts_cache, self._hosts_cache_all)): + return + + if Templar(None).is_template(play.hosts): + _pattern = 'all' + else: + _pattern = play.hosts or 'all' + self._hosts_cache_all = [h.name for h in self._inventory.get_hosts(pattern=_pattern, ignore_restrictions=True)] + self._hosts_cache = [h.name for h in self._inventory.get_hosts(play.hosts, order=play.order)] + def cleanup(self): # close active persistent connections for sock in itervalues(self._active_connections): @@ -227,8 +239,12 @@ class StrategyBase: # This should be safe, as everything should be ITERATING_COMPLETE by # this point, though the strategy may not advance the hosts itself. - inv_hosts = self._inventory.get_hosts(iterator._play.hosts, order=iterator._play.order) - [iterator.get_next_task_for_host(host) for host in inv_hosts if host.name not in self._tqm._unreachable_hosts] + for host in self._hosts_cache: + if host not in self._tqm._unreachable_hosts: + try: + iterator.get_next_task_for_host(self._inventory.hosts[host]) + except KeyError: + iterator.get_next_task_for_host(self._inventory.get_host(host)) # save the failed/unreachable hosts, as the run_handlers() # method will clear that information during its execution @@ -258,19 +274,21 @@ class StrategyBase: return self._tqm.RUN_OK def get_hosts_remaining(self, play): - return [host for host in self._inventory.get_hosts(play.hosts) - if host.name not in self._tqm._failed_hosts and host.name not in self._tqm._unreachable_hosts] + self._set_hosts_cache(play, refresh=False) + ignore = set(self._tqm._failed_hosts).union(self._tqm._unreachable_hosts) + return [host for host in self._hosts_cache if host not in ignore] def get_failed_hosts(self, play): - return [host for host in self._inventory.get_hosts(play.hosts) if host.name in self._tqm._failed_hosts] + self._set_hosts_cache(play, refresh=False) + return [host for host in self._hosts_cache if host in self._tqm._failed_hosts] def add_tqm_variables(self, vars, play): ''' Base class method to add extra variables/information to the list of task vars sent through the executor engine regarding the task queue manager state. ''' - vars['ansible_current_hosts'] = [h.name for h in self.get_hosts_remaining(play)] - vars['ansible_failed_hosts'] = [h.name for h in self.get_failed_hosts(play)] + vars['ansible_current_hosts'] = self.get_hosts_remaining(play) + vars['ansible_failed_hosts'] = self.get_failed_hosts(play) def _queue_task(self, host, task, task_vars, play_context): ''' handles queueing the task up to be sent to a worker ''' @@ -294,11 +312,6 @@ class StrategyBase: # and then queue the new task try: - - # create a dummy object with plugin loaders set as an easier - # way to share them with the forked processes - shared_loader_obj = SharedPluginLoaderObj() - queued = False starting_worker = self._cur_worker while True: @@ -311,7 +324,7 @@ class StrategyBase: 'play_context': play_context } - worker_prc = WorkerProcess(self._final_q, task_vars, host, task, play_context, self._loader, self._variable_manager, shared_loader_obj) + worker_prc = WorkerProcess(self._final_q, task_vars, host, task, play_context, self._loader, self._variable_manager, plugin_loader) self._workers[self._cur_worker] = worker_prc self._tqm.send_callback('v2_runner_on_start', host, task) worker_prc.start() @@ -334,24 +347,19 @@ class StrategyBase: def get_task_hosts(self, iterator, task_host, task): if task.run_once: - host_list = [host for host in self._inventory.get_hosts(iterator._play.hosts) if host.name not in self._tqm._unreachable_hosts] + host_list = [host for host in self._hosts_cache if host not in self._tqm._unreachable_hosts] else: - host_list = [task_host] + host_list = [task_host.name] return host_list def get_delegated_hosts(self, result, task): host_name = result.get('_ansible_delegated_vars', {}).get('ansible_delegated_host', None) - if host_name is not None: - actual_host = self._inventory.get_host(host_name) - if actual_host is None: - actual_host = Host(name=host_name) - else: - actual_host = Host(name=task.delegate_to) - - return [actual_host] + return [host_name or task.delegate_to] def get_handler_templar(self, handler_task, iterator): - handler_vars = self._variable_manager.get_vars(play=iterator._play, task=handler_task) + handler_vars = self._variable_manager.get_vars(play=iterator._play, task=handler_task, + _hosts=self._hosts_cache, + _hosts_all=self._hosts_cache_all) return Templar(loader=self._loader, variables=handler_vars) @debug_closure @@ -703,6 +711,7 @@ class StrategyBase: # Check if host in inventory, add if not if host_name not in self._inventory.hosts: self._inventory.add_host(host_name, 'all') + self._hosts_cache_all.append(host_name) new_host = self._inventory.hosts.get(host_name) # Set/update the vars for this host @@ -882,7 +891,7 @@ class StrategyBase: bypass_host_loop = False try: - action = action_loader.get(handler.action, class_only=True) + action = plugin_loader.action_loader.get(handler.action, class_only=True) if getattr(action, 'BYPASS_HOST_LOOP', False): bypass_host_loop = True except KeyError: @@ -893,7 +902,8 @@ class StrategyBase: host_results = [] for host in notified_hosts: if not iterator.is_failed(host) or iterator._play.force_handlers: - task_vars = self._variable_manager.get_vars(play=iterator._play, host=host, task=handler) + task_vars = self._variable_manager.get_vars(play=iterator._play, host=host, task=handler, + _hosts=self._hosts_cache, _hosts_all=self._hosts_cache_all) self.add_tqm_variables(task_vars, play=iterator._play) templar = Templar(loader=self._loader, variables=task_vars) if not handler.cached_name: @@ -993,7 +1003,8 @@ class StrategyBase: meta_action = task.args.get('_raw_params') def _evaluate_conditional(h): - all_vars = self._variable_manager.get_vars(play=iterator._play, host=h, task=task) + all_vars = self._variable_manager.get_vars(play=iterator._play, host=h, task=task, + _hosts=self._hosts_cache, _hosts_all=self._hosts_cache_all) templar = Templar(loader=self._loader, variables=all_vars) return task.evaluate_conditional(templar, all_vars) @@ -1015,6 +1026,7 @@ class StrategyBase: if task.when: self._cond_not_supported_warn(meta_action) self._inventory.refresh_inventory() + self._set_hosts_cache(iterator._play) msg = "inventory successfully refreshed" elif meta_action == 'clear_facts': if _evaluate_conditional(target_host): @@ -1047,7 +1059,8 @@ class StrategyBase: skipped = True msg = "end_host conditional evaluated to false, continuing execution for %s" % target_host.name elif meta_action == 'reset_connection': - all_vars = self._variable_manager.get_vars(play=iterator._play, host=target_host, task=task) + all_vars = self._variable_manager.get_vars(play=iterator._play, host=target_host, task=task, + _hosts=self._hosts_cache, _hosts_all=self._hosts_cache_all) templar = Templar(loader=self._loader, variables=all_vars) # apply the given task's information to the connection info, @@ -1075,7 +1088,7 @@ class StrategyBase: connection = Connection(self._active_connections[target_host]) del self._active_connections[target_host] else: - connection = connection_loader.get(play_context.connection, play_context, os.devnull) + connection = plugin_loader.connection_loader.get(play_context.connection, play_context, os.devnull) play_context.set_attributes_from_plugin(connection) if connection: @@ -1104,9 +1117,12 @@ class StrategyBase: ''' returns list of available hosts for this iterator by filtering out unreachables ''' hosts_left = [] - for host in self._inventory.get_hosts(iterator._play.hosts, order=iterator._play.order): - if host.name not in self._tqm._unreachable_hosts: - hosts_left.append(host) + for host in self._hosts_cache: + if host not in self._tqm._unreachable_hosts: + try: + hosts_left.append(self._inventory.hosts[host]) + except KeyError: + hosts_left.append(self._inventory.get_host(host)) return hosts_left def update_active_connections(self, results): diff --git a/lib/ansible/plugins/strategy/free.py b/lib/ansible/plugins/strategy/free.py index 4e779eb9a1c..ef998321fee 100644 --- a/lib/ansible/plugins/strategy/free.py +++ b/lib/ansible/plugins/strategy/free.py @@ -82,6 +82,8 @@ class StrategyModule(StrategyBase): # start with all workers being counted as being free workers_free = len(self._workers) + self._set_hosts_cache(iterator._play) + work_to_do = True while work_to_do and not self._tqm._terminated: @@ -129,7 +131,9 @@ class StrategyModule(StrategyBase): action = None display.debug("getting variables", host=host_name) - task_vars = self._variable_manager.get_vars(play=iterator._play, host=host, task=task) + task_vars = self._variable_manager.get_vars(play=iterator._play, host=host, task=task, + _hosts=self._hosts_cache, + _hosts_all=self._hosts_cache_all) self.add_tqm_variables(task_vars, play=iterator._play) templar = Templar(loader=self._loader, variables=task_vars) display.debug("done getting variables", host=host_name) @@ -231,7 +235,9 @@ class StrategyModule(StrategyBase): continue for new_block in new_blocks: - task_vars = self._variable_manager.get_vars(play=iterator._play, task=new_block._parent) + task_vars = self._variable_manager.get_vars(play=iterator._play, task=new_block._parent, + _hosts=self._hosts_cache, + _hosts_all=self._hosts_cache_all) final_block = new_block.filter_tagged_tasks(task_vars) for host in hosts_left: if host in included_file._hosts: diff --git a/lib/ansible/plugins/strategy/linear.py b/lib/ansible/plugins/strategy/linear.py index 34960323299..6883c9c1cb4 100644 --- a/lib/ansible/plugins/strategy/linear.py +++ b/lib/ansible/plugins/strategy/linear.py @@ -205,6 +205,9 @@ class StrategyModule(StrategyBase): # iterate over each task, while there is one left to run result = self._tqm.RUN_OK work_to_do = True + + self._set_hosts_cache(iterator._play) + while work_to_do and not self._tqm._terminated: try: @@ -275,7 +278,8 @@ class StrategyModule(StrategyBase): break display.debug("getting variables") - task_vars = self._variable_manager.get_vars(play=iterator._play, host=host, task=task) + task_vars = self._variable_manager.get_vars(play=iterator._play, host=host, task=task, + _hosts=self._hosts_cache, _hosts_all=self._hosts_cache_all) self.add_tqm_variables(task_vars, play=iterator._play) templar = Templar(loader=self._loader, variables=task_vars) display.debug("done getting variables") @@ -358,7 +362,9 @@ class StrategyModule(StrategyBase): for new_block in new_blocks: task_vars = self._variable_manager.get_vars( play=iterator._play, - task=new_block._parent + task=new_block._parent, + _hosts=self._hosts_cache, + _hosts_all=self._hosts_cache_all, ) display.debug("filtering new block on tags") final_block = new_block.filter_tagged_tasks(task_vars) diff --git a/lib/ansible/vars/manager.py b/lib/ansible/vars/manager.py index 4133b1b53da..3740e02afce 100644 --- a/lib/ansible/vars/manager.py +++ b/lib/ansible/vars/manager.py @@ -140,7 +140,8 @@ class VariableManager: def set_inventory(self, inventory): self._inventory = inventory - def get_vars(self, play=None, host=None, task=None, include_hostvars=True, include_delegate_to=True, use_cache=True): + def get_vars(self, play=None, host=None, task=None, include_hostvars=True, include_delegate_to=True, use_cache=True, + _hosts=None, _hosts_all=None): ''' Returns the variables, with optional "context" given via the parameters for the play, host, and task (which could possibly result in different @@ -158,6 +159,10 @@ class VariableManager: - task->get_vars (if there is a task context) - vars_cache[host] (if there is a host context) - extra vars + + ``_hosts`` and ``_hosts_all`` should be considered private args, with only internal trusted callers relying + on the functionality they provide. These arguments may be removed at a later date without a deprecation + period and without warning. ''' display.debug("in VariableManager get_vars()") @@ -169,6 +174,8 @@ class VariableManager: task=task, include_hostvars=include_hostvars, include_delegate_to=include_delegate_to, + _hosts=_hosts, + _hosts_all=_hosts_all, ) # default for all cases @@ -425,7 +432,8 @@ class VariableManager: display.debug("done with get_vars()") return all_vars - def _get_magic_variables(self, play, host, task, include_hostvars, include_delegate_to): + def _get_magic_variables(self, play, host, task, include_hostvars, include_delegate_to, + _hosts=None, _hosts_all=None): ''' Returns a dictionary of so-called "magic" variables in Ansible, which are special variables we set internally for use. @@ -470,9 +478,14 @@ class VariableManager: else: pattern = play.hosts or 'all' # add the list of hosts in the play, as adjusted for limit/filters - variables['ansible_play_hosts_all'] = [x.name for x in self._inventory.get_hosts(pattern=pattern, ignore_restrictions=True)] + if not _hosts_all: + _hosts_all = [h.name for h in self._inventory.get_hosts(pattern=pattern, ignore_restrictions=True)] + if not _hosts: + _hosts = [h.name for h in self._inventory.get_hosts()] + + variables['ansible_play_hosts_all'] = _hosts_all[:] variables['ansible_play_hosts'] = [x for x in variables['ansible_play_hosts_all'] if x not in play._removed_hosts] - variables['ansible_play_batch'] = [x.name for x in self._inventory.get_hosts() if x.name not in play._removed_hosts] + variables['ansible_play_batch'] = [x for x in _hosts if x not in play._removed_hosts] # DEPRECATED: play_hosts should be deprecated in favor of ansible_play_batch, # however this would take work in the templating engine, so for now we'll add both @@ -622,19 +635,19 @@ class VariableManager: raise AnsibleAssertionError("the type of 'facts' to set for host_facts should be a Mapping but is a %s" % type(facts)) try: - host_cache = self._fact_cache[host.name] + host_cache = self._fact_cache[host] except KeyError: # We get to set this as new host_cache = facts else: if not isinstance(host_cache, MutableMapping): raise TypeError('The object retrieved for {0} must be a MutableMapping but was' - ' a {1}'.format(host.name, type(host_cache))) + ' a {1}'.format(host, type(host_cache))) # Update the existing facts host_cache.update(facts) # Save the facts back to the backing store - self._fact_cache[host.name] = host_cache + self._fact_cache[host] = host_cache def set_nonpersistent_facts(self, host, facts): ''' @@ -645,18 +658,17 @@ class VariableManager: raise AnsibleAssertionError("the type of 'facts' to set for nonpersistent_facts should be a Mapping but is a %s" % type(facts)) try: - self._nonpersistent_fact_cache[host.name].update(facts) + self._nonpersistent_fact_cache[host].update(facts) except KeyError: - self._nonpersistent_fact_cache[host.name] = facts + self._nonpersistent_fact_cache[host] = facts def set_host_variable(self, host, varname, value): ''' Sets a value in the vars_cache for a host. ''' - host_name = host.get_name() - if host_name not in self._vars_cache: - self._vars_cache[host_name] = dict() - if varname in self._vars_cache[host_name] and isinstance(self._vars_cache[host_name][varname], MutableMapping) and isinstance(value, MutableMapping): - self._vars_cache[host_name] = combine_vars(self._vars_cache[host_name], {varname: value}) + if host not in self._vars_cache: + self._vars_cache[host] = dict() + if varname in self._vars_cache[host] and isinstance(self._vars_cache[host][varname], MutableMapping) and isinstance(value, MutableMapping): + self._vars_cache[host] = combine_vars(self._vars_cache[host], {varname: value}) else: - self._vars_cache[host_name][varname] = value + self._vars_cache[host][varname] = value diff --git a/test/units/cli/test_playbook.py b/test/units/cli/test_playbook.py index b3cd402e8cc..f25e54dfc37 100644 --- a/test/units/cli/test_playbook.py +++ b/test/units/cli/test_playbook.py @@ -39,7 +39,7 @@ class TestPlaybookCLI(unittest.TestCase): fake_loader = DictDataLoader({'foobar.yml': ""}) inventory = InventoryManager(loader=fake_loader, sources='testhost,') - variable_manager.set_host_facts(inventory.get_host('testhost'), {'canary': True}) + variable_manager.set_host_facts('testhost', {'canary': True}) self.assertTrue('testhost' in variable_manager._fact_cache) cli._flush_cache(inventory, variable_manager) diff --git a/test/units/plugins/strategy/test_strategy_base.py b/test/units/plugins/strategy/test_strategy_base.py index 541efd0a0bf..841828c9b53 100644 --- a/test/units/plugins/strategy/test_strategy_base.py +++ b/test/units/plugins/strategy/test_strategy_base.py @@ -147,6 +147,8 @@ class TestStrategyBase(unittest.TestCase): mock_host.has_hostkey = True mock_hosts.append(mock_host) + mock_hosts_names = [h.name for h in mock_hosts] + mock_inventory = MagicMock() mock_inventory.get_hosts.return_value = mock_hosts @@ -158,17 +160,18 @@ class TestStrategyBase(unittest.TestCase): mock_play.hosts = ["host%02d" % (i + 1) for i in range(0, 5)] strategy_base = StrategyBase(tqm=mock_tqm) + strategy_base._hosts_cache = strategy_base._hosts_cache_all = mock_hosts_names mock_tqm._failed_hosts = [] mock_tqm._unreachable_hosts = [] - self.assertEqual(strategy_base.get_hosts_remaining(play=mock_play), mock_hosts) + self.assertEqual(strategy_base.get_hosts_remaining(play=mock_play), [h.name for h in mock_hosts]) mock_tqm._failed_hosts = ["host01"] - self.assertEqual(strategy_base.get_hosts_remaining(play=mock_play), mock_hosts[1:]) - self.assertEqual(strategy_base.get_failed_hosts(play=mock_play), [mock_hosts[0]]) + self.assertEqual(strategy_base.get_hosts_remaining(play=mock_play), [h.name for h in mock_hosts[1:]]) + self.assertEqual(strategy_base.get_failed_hosts(play=mock_play), [mock_hosts[0].name]) mock_tqm._unreachable_hosts = ["host02"] - self.assertEqual(strategy_base.get_hosts_remaining(play=mock_play), mock_hosts[2:]) + self.assertEqual(strategy_base.get_hosts_remaining(play=mock_play), [h.name for h in mock_hosts[2:]]) strategy_base.cleanup() @patch.object(WorkerProcess, 'run') diff --git a/test/units/plugins/strategy/test_strategy_linear.py b/test/units/plugins/strategy/test_strategy_linear.py index 1c1624a6638..df95c2676c5 100644 --- a/test/units/plugins/strategy/test_strategy_linear.py +++ b/test/units/plugins/strategy/test_strategy_linear.py @@ -58,18 +58,19 @@ class TestStrategyLinear(unittest.TestCase): p = Playbook.load('test_play.yml', loader=fake_loader, variable_manager=mock_var_manager) + inventory = MagicMock() + inventory.hosts = {} hosts = [] for i in range(0, 2): host = MagicMock() host.name = host.get_name.return_value = 'host%02d' % i hosts.append(host) - - mock_var_manager._fact_cache['host00'] = dict() - - inventory = MagicMock() + inventory.hosts[host.name] = host inventory.get_hosts.return_value = hosts inventory.filter_hosts.return_value = hosts + mock_var_manager._fact_cache['host00'] = dict() + play_context = PlayContext(play=p._entries[0]) itr = PlayIterator( @@ -89,6 +90,8 @@ class TestStrategyLinear(unittest.TestCase): ) tqm._initialize_processes(3) strategy = StrategyModule(tqm) + strategy._hosts_cache = [h.name for h in hosts] + strategy._hosts_cache_all = [h.name for h in hosts] # implicit meta: flush_handlers hosts_left = strategy.get_hosts_left(itr)