Perfy McPerferton (#58400)

* InventoryManager start of perf improvements

* 0 not 1

* More startswith to [0] improvements

* Remove unused var

* The hash doesn't need to be a string, start as a list, make it into a tuple

* set actually appears faster than frozenset, and these don't need to be frozen

* Cache hosts lists, to avoid extra get_hosts calls, pass to get_vars too

* negligible perf improvement, it could help with memory later

* Try the fast way, fallback to the safe way

* Revert to previous logic, linting fix

* Extend pre-caching to free

* Address test failures

* Hosts are strings

* Fix unit test

* host is a string

* update test assumption

* drop SharedPluginLoaderObj, pre-create a set, instead of 2 comparisons in the list comprehension

* Dedupe code

* Change to _hosts and _hosts_all in get_vars

* Add backwards compat for strategies that don't do set host caches

* Add deprecation message to SharedPluginLoaderObj

* Remove unused SharedPluginLoaderObj import

* Update docs/comments

* Remove debugging

* Indicate what patterh_hash is

* That won't work

* Re-fix tests

* Update _set_hosts_cache to accept the play directly, use without refresh in get_hosts_remaining and get_failed_hosts for backwards compat

* Rename variable to avoid confusion

* On add_host only manipulate _hosts_cache_all

* Add warning docs around _hosts and _hosts_all args
pull/50901/head
Matt Martz 6 years ago committed by GitHub
parent 6adf0c581e
commit 284dafe476
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -48,6 +48,18 @@ IGNORED_EXTS = [b'%s$' % to_bytes(re.escape(x)) for x in C.INVENTORY_IGNORE_EXTS
IGNORED = re.compile(b'|'.join(IGNORED_ALWAYS + IGNORED_PATTERNS + IGNORED_EXTS)) IGNORED = re.compile(b'|'.join(IGNORED_ALWAYS + IGNORED_PATTERNS + IGNORED_EXTS))
PATTERN_WITH_SUBSCRIPT = re.compile(
r'''^
(.+) # A pattern expression ending with...
\[(?: # A [subscript] expression comprising:
(-?[0-9]+)| # A single positive or negative number
([0-9]+)([:-]) # Or an x:y or x: range.
([0-9]*)
)\]
$
''', re.X
)
def order_patterns(patterns): def order_patterns(patterns):
''' takes a list of patterns and reorders them by modifier to apply them consistently ''' ''' takes a list of patterns and reorders them by modifier to apply them consistently '''
@ -57,9 +69,9 @@ def order_patterns(patterns):
pattern_intersection = [] pattern_intersection = []
pattern_exclude = [] pattern_exclude = []
for p in patterns: for p in patterns:
if p.startswith("!"): if p[0] == "!":
pattern_exclude.append(p) pattern_exclude.append(p)
elif p.startswith("&"): elif p[0] == "&":
pattern_intersection.append(p) pattern_intersection.append(p)
elif p: elif p:
pattern_regular.append(p) pattern_regular.append(p)
@ -316,7 +328,7 @@ class InventoryManager(object):
def _match_list(self, items, pattern_str): def _match_list(self, items, pattern_str):
# compile patterns # compile patterns
try: try:
if not pattern_str.startswith('~'): if not pattern_str[0] == '~':
pattern = re.compile(fnmatch.translate(pattern_str)) pattern = re.compile(fnmatch.translate(pattern_str))
else: else:
pattern = re.compile(pattern_str[1:]) pattern = re.compile(pattern_str[1:])
@ -341,41 +353,45 @@ class InventoryManager(object):
# Check if pattern already computed # Check if pattern already computed
if isinstance(pattern, list): if isinstance(pattern, list):
pattern_hash = u":".join(pattern) pattern_list = pattern[:]
else: else:
pattern_hash = pattern pattern_list = [pattern]
if pattern_hash: if pattern_list:
if not ignore_limits and self._subset: if not ignore_limits and self._subset:
pattern_hash += u":%s" % to_text(self._subset, errors='surrogate_or_strict') pattern_list.extend(self._subset)
if not ignore_restrictions and self._restriction: if not ignore_restrictions and self._restriction:
pattern_hash += u":%s" % to_text(self._restriction, errors='surrogate_or_strict') pattern_list.extend(self._restriction)
# This is only used as a hash key in the self._hosts_patterns_cache dict
# a tuple is faster than stringifying
pattern_hash = tuple(pattern_list)
if pattern_hash not in self._hosts_patterns_cache: if pattern_hash not in self._hosts_patterns_cache:
patterns = split_host_pattern(pattern) patterns = split_host_pattern(pattern)
hosts = self._evaluate_patterns(patterns) hosts[:] = self._evaluate_patterns(patterns)
# mainly useful for hostvars[host] access # mainly useful for hostvars[host] access
if not ignore_limits and self._subset: if not ignore_limits and self._subset:
# exclude hosts not in a subset, if defined # exclude hosts not in a subset, if defined
subset_uuids = [s._uuid for s in self._evaluate_patterns(self._subset)] subset_uuids = set(s._uuid for s in self._evaluate_patterns(self._subset))
hosts = [h for h in hosts if h._uuid in subset_uuids] hosts[:] = [h for h in hosts if h._uuid in subset_uuids]
if not ignore_restrictions and self._restriction: if not ignore_restrictions and self._restriction:
# exclude hosts mentioned in any restriction (ex: failed hosts) # exclude hosts mentioned in any restriction (ex: failed hosts)
hosts = [h for h in hosts if h.name in self._restriction] hosts[:] = [h for h in hosts if h.name in self._restriction]
self._hosts_patterns_cache[pattern_hash] = deduplicate_list(hosts) self._hosts_patterns_cache[pattern_hash] = deduplicate_list(hosts)
# sort hosts list if needed (should only happen when called from strategy) # sort hosts list if needed (should only happen when called from strategy)
if order in ['sorted', 'reverse_sorted']: if order in ['sorted', 'reverse_sorted']:
hosts = sorted(self._hosts_patterns_cache[pattern_hash][:], key=attrgetter('name'), reverse=(order == 'reverse_sorted')) hosts[:] = sorted(self._hosts_patterns_cache[pattern_hash][:], key=attrgetter('name'), reverse=(order == 'reverse_sorted'))
elif order == 'reverse_inventory': elif order == 'reverse_inventory':
hosts = self._hosts_patterns_cache[pattern_hash][::-1] hosts[:] = self._hosts_patterns_cache[pattern_hash][::-1]
else: else:
hosts = self._hosts_patterns_cache[pattern_hash][:] hosts[:] = self._hosts_patterns_cache[pattern_hash][:]
if order == 'shuffle': if order == 'shuffle':
shuffle(hosts) shuffle(hosts)
elif order not in [None, 'inventory']: elif order not in [None, 'inventory']:
@ -398,12 +414,15 @@ class InventoryManager(object):
hosts.append(self._inventory.get_host(p)) hosts.append(self._inventory.get_host(p))
else: else:
that = self._match_one_pattern(p) that = self._match_one_pattern(p)
if p.startswith("!"): if p[0] == "!":
hosts = [h for h in hosts if h not in frozenset(that)] that = set(that)
elif p.startswith("&"): hosts = [h for h in hosts if h not in that]
hosts = [h for h in hosts if h in frozenset(that)] elif p[0] == "&":
that = set(that)
hosts = [h for h in hosts if h in that]
else: else:
hosts.extend([h for h in that if h.name not in frozenset([y.name for y in hosts])]) existing_hosts = set(y.name for y in hosts)
hosts.extend([h for h in that if h.name not in existing_hosts])
return hosts return hosts
def _match_one_pattern(self, pattern): def _match_one_pattern(self, pattern):
@ -444,7 +463,7 @@ class InventoryManager(object):
Duplicate matches are always eliminated from the results. Duplicate matches are always eliminated from the results.
""" """
if pattern.startswith("&") or pattern.startswith("!"): if pattern[0] in ("&", "!"):
pattern = pattern[1:] pattern = pattern[1:]
if pattern not in self._pattern_cache: if pattern not in self._pattern_cache:
@ -469,27 +488,15 @@ class InventoryManager(object):
""" """
# Do not parse regexes for enumeration info # Do not parse regexes for enumeration info
if pattern.startswith('~'): if pattern[0] == '~':
return (pattern, None) return (pattern, None)
# We want a pattern followed by an integer or range subscript. # We want a pattern followed by an integer or range subscript.
# (We can't be more restrictive about the expression because the # (We can't be more restrictive about the expression because the
# fnmatch semantics permit [\[:\]] to occur.) # fnmatch semantics permit [\[:\]] to occur.)
pattern_with_subscript = re.compile(
r'''^
(.+) # A pattern expression ending with...
\[(?: # A [subscript] expression comprising:
(-?[0-9]+)| # A single positive or negative number
([0-9]+)([:-]) # Or an x:y or x: range.
([0-9]*)
)\]
$
''', re.X
)
subscript = None subscript = None
m = pattern_with_subscript.match(pattern) m = PATTERN_WITH_SUBSCRIPT.match(pattern)
if m: if m:
(pattern, idx, start, sep, end) = m.groups() (pattern, idx, start, sep, end) = m.groups()
if idx: if idx:
@ -535,7 +542,7 @@ class InventoryManager(object):
results.extend(self._inventory.groups[groupname].get_hosts()) results.extend(self._inventory.groups[groupname].get_hosts())
# check hosts if no groups matched or it is a regex/glob pattern # check hosts if no groups matched or it is a regex/glob pattern
if not matching_groups or pattern.startswith('~') or any(special in pattern for special in ('.', '?', '*', '[')): if not matching_groups or pattern[0] == '~' or any(special in pattern for special in ('.', '?', '*', '[')):
# pattern might match host # pattern might match host
matching_hosts = self._match_list(self._inventory.hosts, pattern) matching_hosts = self._match_list(self._inventory.hosts, pattern)
if matching_hosts: if matching_hosts:
@ -585,7 +592,7 @@ class InventoryManager(object):
return return
elif not isinstance(restriction, list): elif not isinstance(restriction, list):
restriction = [restriction] restriction = [restriction]
self._restriction = [h.name for h in restriction] self._restriction = set(to_text(h.name) for h in restriction)
def subset(self, subset_pattern): def subset(self, subset_pattern):
""" """
@ -601,12 +608,12 @@ class InventoryManager(object):
results = [] results = []
# allow Unix style @filename data # allow Unix style @filename data
for x in subset_patterns: for x in subset_patterns:
if x.startswith("@"): if x[0] == "@":
fd = open(x[1:]) fd = open(x[1:])
results.extend([l.strip() for l in fd.read().split("\n")]) results.extend([to_text(l.strip()) for l in fd.read().split("\n")])
fd.close() fd.close()
else: else:
results.append(x) results.append(to_text(x))
self._subset = results self._subset = results
def remove_restriction(self): def remove_restriction(self):

@ -10,7 +10,6 @@ from ansible import constants as C
from ansible.plugins.callback import CallbackBase from ansible.plugins.callback import CallbackBase
from ansible.utils.color import colorize, hostcolor from ansible.utils.color import colorize, hostcolor
from ansible.template import Templar from ansible.template import Templar
from ansible.plugins.strategy import SharedPluginLoaderObj
from ansible.playbook.task_include import TaskInclude from ansible.playbook.task_include import TaskInclude
DOCUMENTATION = ''' DOCUMENTATION = '''

@ -45,7 +45,7 @@ from ansible.module_utils.connection import Connection, ConnectionError
from ansible.playbook.helpers import load_list_of_blocks from ansible.playbook.helpers import load_list_of_blocks
from ansible.playbook.included_file import IncludedFile from ansible.playbook.included_file import IncludedFile
from ansible.playbook.task_include import TaskInclude from ansible.playbook.task_include import TaskInclude
from ansible.plugins.loader import action_loader, connection_loader, filter_loader, lookup_loader, module_loader, test_loader from ansible.plugins import loader as plugin_loader
from ansible.template import Templar from ansible.template import Templar
from ansible.utils.display import Display from ansible.utils.display import Display
from ansible.utils.vars import combine_vars from ansible.utils.vars import combine_vars
@ -60,21 +60,12 @@ class StrategySentinel:
pass pass
# TODO: this should probably be in the plugins/__init__.py, with def SharedPluginLoaderObj():
# a smarter mechanism to set all of the attributes based on '''This only exists for backwards compat, do not use.
# the loaders created there
class SharedPluginLoaderObj:
''' '''
A simple object to make pass the various plugin loaders to display.deprecated('SharedPluginLoaderObj is deprecated, please directly use ansible.plugins.loader',
the forked processes over the queue easier version='2.11')
''' return plugin_loader
def __init__(self):
self.action_loader = action_loader
self.connection_loader = connection_loader
self.filter_loader = filter_loader
self.test_loader = test_loader
self.lookup_loader = lookup_loader
self.module_loader = module_loader
_sentinel = StrategySentinel() _sentinel = StrategySentinel()
@ -207,8 +198,29 @@ class StrategyBase:
# play completion # play completion
self._active_connections = dict() self._active_connections = dict()
# Caches for get_host calls, to avoid calling excessively
# These values should be set at the top of the ``run`` method of each
# strategy plugin. Use ``_set_hosts_cache`` to set these values
self._hosts_cache = []
self._hosts_cache_all = []
self.debugger_active = C.ENABLE_TASK_DEBUGGER self.debugger_active = C.ENABLE_TASK_DEBUGGER
def _set_hosts_cache(self, play, refresh=True):
"""Responsible for setting _hosts_cache and _hosts_cache_all
See comment in ``__init__`` for the purpose of these caches
"""
if not refresh and all((self._hosts_cache, self._hosts_cache_all)):
return
if Templar(None).is_template(play.hosts):
_pattern = 'all'
else:
_pattern = play.hosts or 'all'
self._hosts_cache_all = [h.name for h in self._inventory.get_hosts(pattern=_pattern, ignore_restrictions=True)]
self._hosts_cache = [h.name for h in self._inventory.get_hosts(play.hosts, order=play.order)]
def cleanup(self): def cleanup(self):
# close active persistent connections # close active persistent connections
for sock in itervalues(self._active_connections): for sock in itervalues(self._active_connections):
@ -227,8 +239,12 @@ class StrategyBase:
# This should be safe, as everything should be ITERATING_COMPLETE by # This should be safe, as everything should be ITERATING_COMPLETE by
# this point, though the strategy may not advance the hosts itself. # this point, though the strategy may not advance the hosts itself.
inv_hosts = self._inventory.get_hosts(iterator._play.hosts, order=iterator._play.order) for host in self._hosts_cache:
[iterator.get_next_task_for_host(host) for host in inv_hosts if host.name not in self._tqm._unreachable_hosts] if host not in self._tqm._unreachable_hosts:
try:
iterator.get_next_task_for_host(self._inventory.hosts[host])
except KeyError:
iterator.get_next_task_for_host(self._inventory.get_host(host))
# save the failed/unreachable hosts, as the run_handlers() # save the failed/unreachable hosts, as the run_handlers()
# method will clear that information during its execution # method will clear that information during its execution
@ -258,19 +274,21 @@ class StrategyBase:
return self._tqm.RUN_OK return self._tqm.RUN_OK
def get_hosts_remaining(self, play): def get_hosts_remaining(self, play):
return [host for host in self._inventory.get_hosts(play.hosts) self._set_hosts_cache(play, refresh=False)
if host.name not in self._tqm._failed_hosts and host.name not in self._tqm._unreachable_hosts] ignore = set(self._tqm._failed_hosts).union(self._tqm._unreachable_hosts)
return [host for host in self._hosts_cache if host not in ignore]
def get_failed_hosts(self, play): def get_failed_hosts(self, play):
return [host for host in self._inventory.get_hosts(play.hosts) if host.name in self._tqm._failed_hosts] self._set_hosts_cache(play, refresh=False)
return [host for host in self._hosts_cache if host in self._tqm._failed_hosts]
def add_tqm_variables(self, vars, play): def add_tqm_variables(self, vars, play):
''' '''
Base class method to add extra variables/information to the list of task Base class method to add extra variables/information to the list of task
vars sent through the executor engine regarding the task queue manager state. vars sent through the executor engine regarding the task queue manager state.
''' '''
vars['ansible_current_hosts'] = [h.name for h in self.get_hosts_remaining(play)] vars['ansible_current_hosts'] = self.get_hosts_remaining(play)
vars['ansible_failed_hosts'] = [h.name for h in self.get_failed_hosts(play)] vars['ansible_failed_hosts'] = self.get_failed_hosts(play)
def _queue_task(self, host, task, task_vars, play_context): def _queue_task(self, host, task, task_vars, play_context):
''' handles queueing the task up to be sent to a worker ''' ''' handles queueing the task up to be sent to a worker '''
@ -294,11 +312,6 @@ class StrategyBase:
# and then queue the new task # and then queue the new task
try: try:
# create a dummy object with plugin loaders set as an easier
# way to share them with the forked processes
shared_loader_obj = SharedPluginLoaderObj()
queued = False queued = False
starting_worker = self._cur_worker starting_worker = self._cur_worker
while True: while True:
@ -311,7 +324,7 @@ class StrategyBase:
'play_context': play_context 'play_context': play_context
} }
worker_prc = WorkerProcess(self._final_q, task_vars, host, task, play_context, self._loader, self._variable_manager, shared_loader_obj) worker_prc = WorkerProcess(self._final_q, task_vars, host, task, play_context, self._loader, self._variable_manager, plugin_loader)
self._workers[self._cur_worker] = worker_prc self._workers[self._cur_worker] = worker_prc
self._tqm.send_callback('v2_runner_on_start', host, task) self._tqm.send_callback('v2_runner_on_start', host, task)
worker_prc.start() worker_prc.start()
@ -334,24 +347,19 @@ class StrategyBase:
def get_task_hosts(self, iterator, task_host, task): def get_task_hosts(self, iterator, task_host, task):
if task.run_once: if task.run_once:
host_list = [host for host in self._inventory.get_hosts(iterator._play.hosts) if host.name not in self._tqm._unreachable_hosts] host_list = [host for host in self._hosts_cache if host not in self._tqm._unreachable_hosts]
else: else:
host_list = [task_host] host_list = [task_host.name]
return host_list return host_list
def get_delegated_hosts(self, result, task): def get_delegated_hosts(self, result, task):
host_name = result.get('_ansible_delegated_vars', {}).get('ansible_delegated_host', None) host_name = result.get('_ansible_delegated_vars', {}).get('ansible_delegated_host', None)
if host_name is not None: return [host_name or task.delegate_to]
actual_host = self._inventory.get_host(host_name)
if actual_host is None:
actual_host = Host(name=host_name)
else:
actual_host = Host(name=task.delegate_to)
return [actual_host]
def get_handler_templar(self, handler_task, iterator): def get_handler_templar(self, handler_task, iterator):
handler_vars = self._variable_manager.get_vars(play=iterator._play, task=handler_task) handler_vars = self._variable_manager.get_vars(play=iterator._play, task=handler_task,
_hosts=self._hosts_cache,
_hosts_all=self._hosts_cache_all)
return Templar(loader=self._loader, variables=handler_vars) return Templar(loader=self._loader, variables=handler_vars)
@debug_closure @debug_closure
@ -703,6 +711,7 @@ class StrategyBase:
# Check if host in inventory, add if not # Check if host in inventory, add if not
if host_name not in self._inventory.hosts: if host_name not in self._inventory.hosts:
self._inventory.add_host(host_name, 'all') self._inventory.add_host(host_name, 'all')
self._hosts_cache_all.append(host_name)
new_host = self._inventory.hosts.get(host_name) new_host = self._inventory.hosts.get(host_name)
# Set/update the vars for this host # Set/update the vars for this host
@ -882,7 +891,7 @@ class StrategyBase:
bypass_host_loop = False bypass_host_loop = False
try: try:
action = action_loader.get(handler.action, class_only=True) action = plugin_loader.action_loader.get(handler.action, class_only=True)
if getattr(action, 'BYPASS_HOST_LOOP', False): if getattr(action, 'BYPASS_HOST_LOOP', False):
bypass_host_loop = True bypass_host_loop = True
except KeyError: except KeyError:
@ -893,7 +902,8 @@ class StrategyBase:
host_results = [] host_results = []
for host in notified_hosts: for host in notified_hosts:
if not iterator.is_failed(host) or iterator._play.force_handlers: if not iterator.is_failed(host) or iterator._play.force_handlers:
task_vars = self._variable_manager.get_vars(play=iterator._play, host=host, task=handler) task_vars = self._variable_manager.get_vars(play=iterator._play, host=host, task=handler,
_hosts=self._hosts_cache, _hosts_all=self._hosts_cache_all)
self.add_tqm_variables(task_vars, play=iterator._play) self.add_tqm_variables(task_vars, play=iterator._play)
templar = Templar(loader=self._loader, variables=task_vars) templar = Templar(loader=self._loader, variables=task_vars)
if not handler.cached_name: if not handler.cached_name:
@ -993,7 +1003,8 @@ class StrategyBase:
meta_action = task.args.get('_raw_params') meta_action = task.args.get('_raw_params')
def _evaluate_conditional(h): def _evaluate_conditional(h):
all_vars = self._variable_manager.get_vars(play=iterator._play, host=h, task=task) all_vars = self._variable_manager.get_vars(play=iterator._play, host=h, task=task,
_hosts=self._hosts_cache, _hosts_all=self._hosts_cache_all)
templar = Templar(loader=self._loader, variables=all_vars) templar = Templar(loader=self._loader, variables=all_vars)
return task.evaluate_conditional(templar, all_vars) return task.evaluate_conditional(templar, all_vars)
@ -1015,6 +1026,7 @@ class StrategyBase:
if task.when: if task.when:
self._cond_not_supported_warn(meta_action) self._cond_not_supported_warn(meta_action)
self._inventory.refresh_inventory() self._inventory.refresh_inventory()
self._set_hosts_cache(iterator._play)
msg = "inventory successfully refreshed" msg = "inventory successfully refreshed"
elif meta_action == 'clear_facts': elif meta_action == 'clear_facts':
if _evaluate_conditional(target_host): if _evaluate_conditional(target_host):
@ -1047,7 +1059,8 @@ class StrategyBase:
skipped = True skipped = True
msg = "end_host conditional evaluated to false, continuing execution for %s" % target_host.name msg = "end_host conditional evaluated to false, continuing execution for %s" % target_host.name
elif meta_action == 'reset_connection': elif meta_action == 'reset_connection':
all_vars = self._variable_manager.get_vars(play=iterator._play, host=target_host, task=task) all_vars = self._variable_manager.get_vars(play=iterator._play, host=target_host, task=task,
_hosts=self._hosts_cache, _hosts_all=self._hosts_cache_all)
templar = Templar(loader=self._loader, variables=all_vars) templar = Templar(loader=self._loader, variables=all_vars)
# apply the given task's information to the connection info, # apply the given task's information to the connection info,
@ -1075,7 +1088,7 @@ class StrategyBase:
connection = Connection(self._active_connections[target_host]) connection = Connection(self._active_connections[target_host])
del self._active_connections[target_host] del self._active_connections[target_host]
else: else:
connection = connection_loader.get(play_context.connection, play_context, os.devnull) connection = plugin_loader.connection_loader.get(play_context.connection, play_context, os.devnull)
play_context.set_attributes_from_plugin(connection) play_context.set_attributes_from_plugin(connection)
if connection: if connection:
@ -1104,9 +1117,12 @@ class StrategyBase:
''' returns list of available hosts for this iterator by filtering out unreachables ''' ''' returns list of available hosts for this iterator by filtering out unreachables '''
hosts_left = [] hosts_left = []
for host in self._inventory.get_hosts(iterator._play.hosts, order=iterator._play.order): for host in self._hosts_cache:
if host.name not in self._tqm._unreachable_hosts: if host not in self._tqm._unreachable_hosts:
hosts_left.append(host) try:
hosts_left.append(self._inventory.hosts[host])
except KeyError:
hosts_left.append(self._inventory.get_host(host))
return hosts_left return hosts_left
def update_active_connections(self, results): def update_active_connections(self, results):

@ -82,6 +82,8 @@ class StrategyModule(StrategyBase):
# start with all workers being counted as being free # start with all workers being counted as being free
workers_free = len(self._workers) workers_free = len(self._workers)
self._set_hosts_cache(iterator._play)
work_to_do = True work_to_do = True
while work_to_do and not self._tqm._terminated: while work_to_do and not self._tqm._terminated:
@ -129,7 +131,9 @@ class StrategyModule(StrategyBase):
action = None action = None
display.debug("getting variables", host=host_name) display.debug("getting variables", host=host_name)
task_vars = self._variable_manager.get_vars(play=iterator._play, host=host, task=task) task_vars = self._variable_manager.get_vars(play=iterator._play, host=host, task=task,
_hosts=self._hosts_cache,
_hosts_all=self._hosts_cache_all)
self.add_tqm_variables(task_vars, play=iterator._play) self.add_tqm_variables(task_vars, play=iterator._play)
templar = Templar(loader=self._loader, variables=task_vars) templar = Templar(loader=self._loader, variables=task_vars)
display.debug("done getting variables", host=host_name) display.debug("done getting variables", host=host_name)
@ -231,7 +235,9 @@ class StrategyModule(StrategyBase):
continue continue
for new_block in new_blocks: for new_block in new_blocks:
task_vars = self._variable_manager.get_vars(play=iterator._play, task=new_block._parent) task_vars = self._variable_manager.get_vars(play=iterator._play, task=new_block._parent,
_hosts=self._hosts_cache,
_hosts_all=self._hosts_cache_all)
final_block = new_block.filter_tagged_tasks(task_vars) final_block = new_block.filter_tagged_tasks(task_vars)
for host in hosts_left: for host in hosts_left:
if host in included_file._hosts: if host in included_file._hosts:

@ -205,6 +205,9 @@ class StrategyModule(StrategyBase):
# iterate over each task, while there is one left to run # iterate over each task, while there is one left to run
result = self._tqm.RUN_OK result = self._tqm.RUN_OK
work_to_do = True work_to_do = True
self._set_hosts_cache(iterator._play)
while work_to_do and not self._tqm._terminated: while work_to_do and not self._tqm._terminated:
try: try:
@ -275,7 +278,8 @@ class StrategyModule(StrategyBase):
break break
display.debug("getting variables") display.debug("getting variables")
task_vars = self._variable_manager.get_vars(play=iterator._play, host=host, task=task) task_vars = self._variable_manager.get_vars(play=iterator._play, host=host, task=task,
_hosts=self._hosts_cache, _hosts_all=self._hosts_cache_all)
self.add_tqm_variables(task_vars, play=iterator._play) self.add_tqm_variables(task_vars, play=iterator._play)
templar = Templar(loader=self._loader, variables=task_vars) templar = Templar(loader=self._loader, variables=task_vars)
display.debug("done getting variables") display.debug("done getting variables")
@ -358,7 +362,9 @@ class StrategyModule(StrategyBase):
for new_block in new_blocks: for new_block in new_blocks:
task_vars = self._variable_manager.get_vars( task_vars = self._variable_manager.get_vars(
play=iterator._play, play=iterator._play,
task=new_block._parent task=new_block._parent,
_hosts=self._hosts_cache,
_hosts_all=self._hosts_cache_all,
) )
display.debug("filtering new block on tags") display.debug("filtering new block on tags")
final_block = new_block.filter_tagged_tasks(task_vars) final_block = new_block.filter_tagged_tasks(task_vars)

@ -140,7 +140,8 @@ class VariableManager:
def set_inventory(self, inventory): def set_inventory(self, inventory):
self._inventory = inventory self._inventory = inventory
def get_vars(self, play=None, host=None, task=None, include_hostvars=True, include_delegate_to=True, use_cache=True): def get_vars(self, play=None, host=None, task=None, include_hostvars=True, include_delegate_to=True, use_cache=True,
_hosts=None, _hosts_all=None):
''' '''
Returns the variables, with optional "context" given via the parameters Returns the variables, with optional "context" given via the parameters
for the play, host, and task (which could possibly result in different for the play, host, and task (which could possibly result in different
@ -158,6 +159,10 @@ class VariableManager:
- task->get_vars (if there is a task context) - task->get_vars (if there is a task context)
- vars_cache[host] (if there is a host context) - vars_cache[host] (if there is a host context)
- extra vars - extra vars
``_hosts`` and ``_hosts_all`` should be considered private args, with only internal trusted callers relying
on the functionality they provide. These arguments may be removed at a later date without a deprecation
period and without warning.
''' '''
display.debug("in VariableManager get_vars()") display.debug("in VariableManager get_vars()")
@ -169,6 +174,8 @@ class VariableManager:
task=task, task=task,
include_hostvars=include_hostvars, include_hostvars=include_hostvars,
include_delegate_to=include_delegate_to, include_delegate_to=include_delegate_to,
_hosts=_hosts,
_hosts_all=_hosts_all,
) )
# default for all cases # default for all cases
@ -425,7 +432,8 @@ class VariableManager:
display.debug("done with get_vars()") display.debug("done with get_vars()")
return all_vars return all_vars
def _get_magic_variables(self, play, host, task, include_hostvars, include_delegate_to): def _get_magic_variables(self, play, host, task, include_hostvars, include_delegate_to,
_hosts=None, _hosts_all=None):
''' '''
Returns a dictionary of so-called "magic" variables in Ansible, Returns a dictionary of so-called "magic" variables in Ansible,
which are special variables we set internally for use. which are special variables we set internally for use.
@ -470,9 +478,14 @@ class VariableManager:
else: else:
pattern = play.hosts or 'all' pattern = play.hosts or 'all'
# add the list of hosts in the play, as adjusted for limit/filters # add the list of hosts in the play, as adjusted for limit/filters
variables['ansible_play_hosts_all'] = [x.name for x in self._inventory.get_hosts(pattern=pattern, ignore_restrictions=True)] if not _hosts_all:
_hosts_all = [h.name for h in self._inventory.get_hosts(pattern=pattern, ignore_restrictions=True)]
if not _hosts:
_hosts = [h.name for h in self._inventory.get_hosts()]
variables['ansible_play_hosts_all'] = _hosts_all[:]
variables['ansible_play_hosts'] = [x for x in variables['ansible_play_hosts_all'] if x not in play._removed_hosts] variables['ansible_play_hosts'] = [x for x in variables['ansible_play_hosts_all'] if x not in play._removed_hosts]
variables['ansible_play_batch'] = [x.name for x in self._inventory.get_hosts() if x.name not in play._removed_hosts] variables['ansible_play_batch'] = [x for x in _hosts if x not in play._removed_hosts]
# DEPRECATED: play_hosts should be deprecated in favor of ansible_play_batch, # DEPRECATED: play_hosts should be deprecated in favor of ansible_play_batch,
# however this would take work in the templating engine, so for now we'll add both # however this would take work in the templating engine, so for now we'll add both
@ -622,19 +635,19 @@ class VariableManager:
raise AnsibleAssertionError("the type of 'facts' to set for host_facts should be a Mapping but is a %s" % type(facts)) raise AnsibleAssertionError("the type of 'facts' to set for host_facts should be a Mapping but is a %s" % type(facts))
try: try:
host_cache = self._fact_cache[host.name] host_cache = self._fact_cache[host]
except KeyError: except KeyError:
# We get to set this as new # We get to set this as new
host_cache = facts host_cache = facts
else: else:
if not isinstance(host_cache, MutableMapping): if not isinstance(host_cache, MutableMapping):
raise TypeError('The object retrieved for {0} must be a MutableMapping but was' raise TypeError('The object retrieved for {0} must be a MutableMapping but was'
' a {1}'.format(host.name, type(host_cache))) ' a {1}'.format(host, type(host_cache)))
# Update the existing facts # Update the existing facts
host_cache.update(facts) host_cache.update(facts)
# Save the facts back to the backing store # Save the facts back to the backing store
self._fact_cache[host.name] = host_cache self._fact_cache[host] = host_cache
def set_nonpersistent_facts(self, host, facts): def set_nonpersistent_facts(self, host, facts):
''' '''
@ -645,18 +658,17 @@ class VariableManager:
raise AnsibleAssertionError("the type of 'facts' to set for nonpersistent_facts should be a Mapping but is a %s" % type(facts)) raise AnsibleAssertionError("the type of 'facts' to set for nonpersistent_facts should be a Mapping but is a %s" % type(facts))
try: try:
self._nonpersistent_fact_cache[host.name].update(facts) self._nonpersistent_fact_cache[host].update(facts)
except KeyError: except KeyError:
self._nonpersistent_fact_cache[host.name] = facts self._nonpersistent_fact_cache[host] = facts
def set_host_variable(self, host, varname, value): def set_host_variable(self, host, varname, value):
''' '''
Sets a value in the vars_cache for a host. Sets a value in the vars_cache for a host.
''' '''
host_name = host.get_name() if host not in self._vars_cache:
if host_name not in self._vars_cache: self._vars_cache[host] = dict()
self._vars_cache[host_name] = dict() if varname in self._vars_cache[host] and isinstance(self._vars_cache[host][varname], MutableMapping) and isinstance(value, MutableMapping):
if varname in self._vars_cache[host_name] and isinstance(self._vars_cache[host_name][varname], MutableMapping) and isinstance(value, MutableMapping): self._vars_cache[host] = combine_vars(self._vars_cache[host], {varname: value})
self._vars_cache[host_name] = combine_vars(self._vars_cache[host_name], {varname: value})
else: else:
self._vars_cache[host_name][varname] = value self._vars_cache[host][varname] = value

@ -39,7 +39,7 @@ class TestPlaybookCLI(unittest.TestCase):
fake_loader = DictDataLoader({'foobar.yml': ""}) fake_loader = DictDataLoader({'foobar.yml': ""})
inventory = InventoryManager(loader=fake_loader, sources='testhost,') inventory = InventoryManager(loader=fake_loader, sources='testhost,')
variable_manager.set_host_facts(inventory.get_host('testhost'), {'canary': True}) variable_manager.set_host_facts('testhost', {'canary': True})
self.assertTrue('testhost' in variable_manager._fact_cache) self.assertTrue('testhost' in variable_manager._fact_cache)
cli._flush_cache(inventory, variable_manager) cli._flush_cache(inventory, variable_manager)

@ -147,6 +147,8 @@ class TestStrategyBase(unittest.TestCase):
mock_host.has_hostkey = True mock_host.has_hostkey = True
mock_hosts.append(mock_host) mock_hosts.append(mock_host)
mock_hosts_names = [h.name for h in mock_hosts]
mock_inventory = MagicMock() mock_inventory = MagicMock()
mock_inventory.get_hosts.return_value = mock_hosts mock_inventory.get_hosts.return_value = mock_hosts
@ -158,17 +160,18 @@ class TestStrategyBase(unittest.TestCase):
mock_play.hosts = ["host%02d" % (i + 1) for i in range(0, 5)] mock_play.hosts = ["host%02d" % (i + 1) for i in range(0, 5)]
strategy_base = StrategyBase(tqm=mock_tqm) strategy_base = StrategyBase(tqm=mock_tqm)
strategy_base._hosts_cache = strategy_base._hosts_cache_all = mock_hosts_names
mock_tqm._failed_hosts = [] mock_tqm._failed_hosts = []
mock_tqm._unreachable_hosts = [] mock_tqm._unreachable_hosts = []
self.assertEqual(strategy_base.get_hosts_remaining(play=mock_play), mock_hosts) self.assertEqual(strategy_base.get_hosts_remaining(play=mock_play), [h.name for h in mock_hosts])
mock_tqm._failed_hosts = ["host01"] mock_tqm._failed_hosts = ["host01"]
self.assertEqual(strategy_base.get_hosts_remaining(play=mock_play), mock_hosts[1:]) self.assertEqual(strategy_base.get_hosts_remaining(play=mock_play), [h.name for h in mock_hosts[1:]])
self.assertEqual(strategy_base.get_failed_hosts(play=mock_play), [mock_hosts[0]]) self.assertEqual(strategy_base.get_failed_hosts(play=mock_play), [mock_hosts[0].name])
mock_tqm._unreachable_hosts = ["host02"] mock_tqm._unreachable_hosts = ["host02"]
self.assertEqual(strategy_base.get_hosts_remaining(play=mock_play), mock_hosts[2:]) self.assertEqual(strategy_base.get_hosts_remaining(play=mock_play), [h.name for h in mock_hosts[2:]])
strategy_base.cleanup() strategy_base.cleanup()
@patch.object(WorkerProcess, 'run') @patch.object(WorkerProcess, 'run')

@ -58,18 +58,19 @@ class TestStrategyLinear(unittest.TestCase):
p = Playbook.load('test_play.yml', loader=fake_loader, variable_manager=mock_var_manager) p = Playbook.load('test_play.yml', loader=fake_loader, variable_manager=mock_var_manager)
inventory = MagicMock()
inventory.hosts = {}
hosts = [] hosts = []
for i in range(0, 2): for i in range(0, 2):
host = MagicMock() host = MagicMock()
host.name = host.get_name.return_value = 'host%02d' % i host.name = host.get_name.return_value = 'host%02d' % i
hosts.append(host) hosts.append(host)
inventory.hosts[host.name] = host
mock_var_manager._fact_cache['host00'] = dict()
inventory = MagicMock()
inventory.get_hosts.return_value = hosts inventory.get_hosts.return_value = hosts
inventory.filter_hosts.return_value = hosts inventory.filter_hosts.return_value = hosts
mock_var_manager._fact_cache['host00'] = dict()
play_context = PlayContext(play=p._entries[0]) play_context = PlayContext(play=p._entries[0])
itr = PlayIterator( itr = PlayIterator(
@ -89,6 +90,8 @@ class TestStrategyLinear(unittest.TestCase):
) )
tqm._initialize_processes(3) tqm._initialize_processes(3)
strategy = StrategyModule(tqm) strategy = StrategyModule(tqm)
strategy._hosts_cache = [h.name for h in hosts]
strategy._hosts_cache_all = [h.name for h in hosts]
# implicit meta: flush_handlers # implicit meta: flush_handlers
hosts_left = strategy.get_hosts_left(itr) hosts_left = strategy.get_hosts_left(itr)

Loading…
Cancel
Save