fix various Jinja plugin caching issues (#79781)

* fix various Jinja plugin caching issues

* consolidate the wrapper plugin cache
* remove redundant cache in J2 filter/test interceptor

* intra-template loader bypass

* fix early exits swallowing some exception detail

* misc comment cleanup
pull/67103/merge
Matt Davis 1 year ago committed by GitHub
parent 4d40988876
commit dd79c49a4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,2 @@
bugfixes:
- PluginLoader - fix Jinja plugin performance issues (https://github.com/ansible/ansible/issues/79652)

@ -18,6 +18,10 @@ from collections import defaultdict, namedtuple
from traceback import format_exc from traceback import format_exc
import ansible.module_utils.compat.typing as t import ansible.module_utils.compat.typing as t
from .filter import AnsibleJinja2Filter
from .test import AnsibleJinja2Test
from ansible import __version__ as ansible_version from ansible import __version__ as ansible_version
from ansible import constants as C from ansible import constants as C
from ansible.errors import AnsibleError, AnsiblePluginCircularRedirect, AnsiblePluginRemovedError, AnsibleCollectionUnsupportedVersionError from ansible.errors import AnsibleError, AnsiblePluginCircularRedirect, AnsiblePluginRemovedError, AnsibleCollectionUnsupportedVersionError
@ -1067,28 +1071,17 @@ class Jinja2Loader(PluginLoader):
We need to do a few things differently in the base class because of file == plugin We need to do a few things differently in the base class because of file == plugin
assumptions and dedupe logic. assumptions and dedupe logic.
""" """
def __init__(self, class_name, package, config, subdir, aliases=None, required_base_class=None): def __init__(self, class_name, package, config, subdir, plugin_wrapper_type, aliases=None, required_base_class=None):
super(Jinja2Loader, self).__init__(class_name, package, config, subdir, aliases=aliases, required_base_class=required_base_class) super(Jinja2Loader, self).__init__(class_name, package, config, subdir, aliases=aliases, required_base_class=required_base_class)
self._loaded_j2_file_maps = [] self._plugin_wrapper_type = plugin_wrapper_type
self._cached_non_collection_wrappers = {}
def _clear_caches(self): def _clear_caches(self):
super(Jinja2Loader, self)._clear_caches() super(Jinja2Loader, self)._clear_caches()
self._loaded_j2_file_maps = [] self._cached_non_collection_wrappers = {}
def find_plugin(self, name, mod_type='', ignore_deprecated=False, check_aliases=False, collection_list=None): def find_plugin(self, name, mod_type='', ignore_deprecated=False, check_aliases=False, collection_list=None):
raise NotImplementedError('find_plugin is not supported on Jinja2Loader')
# TODO: handle collection plugin find, see 'get_with_context'
# this can really 'find plugin file'
plugin = super(Jinja2Loader, self).find_plugin(name, mod_type=mod_type, ignore_deprecated=ignore_deprecated, check_aliases=check_aliases,
collection_list=collection_list)
# if not found, try loading all non collection plugins and see if this in there
if not plugin:
all_plugins = self.all()
plugin = all_plugins.get(name, None)
return plugin
@property @property
def method_map_name(self): def method_map_name(self):
@ -1122,8 +1115,7 @@ class Jinja2Loader(PluginLoader):
for func_name, func in plugin_map: for func_name, func in plugin_map:
fq_name = '.'.join((collection, func_name)) fq_name = '.'.join((collection, func_name))
full = '.'.join((full_name, func_name)) full = '.'.join((full_name, func_name))
pclass = self._load_jinja2_class() plugin = self._plugin_wrapper_type(func)
plugin = pclass(func)
if plugin in plugins: if plugin in plugins:
continue continue
self._update_object(plugin, full, plugin_path, resolved=fq_name) self._update_object(plugin, full, plugin_path, resolved=fq_name)
@ -1131,21 +1123,22 @@ class Jinja2Loader(PluginLoader):
return plugins return plugins
# FUTURE: now that the resulting plugins are closer, refactor base class method with some extra
# hooks so we can avoid all the duplicated plugin metadata logic, and also cache the collection results properly here
def get_with_context(self, name, *args, **kwargs): def get_with_context(self, name, *args, **kwargs):
# pop N/A kwargs to avoid passthrough to parent methods
# found_in_cache = True kwargs.pop('class_only', False)
class_only = kwargs.pop('class_only', False) # just pop it, dont want to pass through kwargs.pop('collection_list', None)
collection_list = kwargs.pop('collection_list', None)
context = PluginLoadContext() context = PluginLoadContext()
# avoid collection path for legacy # avoid collection path for legacy
name = name.removeprefix('ansible.legacy.') name = name.removeprefix('ansible.legacy.')
if '.' not in name: self._ensure_non_collection_wrappers(*args, **kwargs)
# Filter/tests must always be FQCN except builtin and legacy
for known_plugin in self.all(*args, **kwargs): # check for stuff loaded via legacy/builtin paths first
if known_plugin.matches_name([name]): if known_plugin := self._cached_non_collection_wrappers.get(name):
context.resolved = True context.resolved = True
context.plugin_resolved_name = name context.plugin_resolved_name = name
context.plugin_resolved_path = known_plugin._original_path context.plugin_resolved_path = known_plugin._original_path
@ -1237,14 +1230,10 @@ class Jinja2Loader(PluginLoader):
# use 'parent' loader class to find files, but cannot return this as it can contain # use 'parent' loader class to find files, but cannot return this as it can contain
# multiple plugins per file # multiple plugins per file
plugin_impl = super(Jinja2Loader, self).get_with_context(module_name, *args, **kwargs) plugin_impl = super(Jinja2Loader, self).get_with_context(module_name, *args, **kwargs)
except Exception as e:
raise KeyError(to_native(e))
try:
method_map = getattr(plugin_impl.object, self.method_map_name) method_map = getattr(plugin_impl.object, self.method_map_name)
plugin_map = method_map().items() plugin_map = method_map().items()
except Exception as e: except Exception as e:
display.warning("Skipping %s plugins in '%s' as it seems to be invalid: %r" % (self.type, to_text(plugin_impl.object._original_path), e)) display.warning(f"Skipping {self.type} plugins in {module_name}'; an error occurred while loading: {e}")
continue continue
for func_name, func in plugin_map: for func_name, func in plugin_map:
@ -1253,11 +1242,11 @@ class Jinja2Loader(PluginLoader):
# TODO: load anyways into CACHE so we only match each at end of loop # TODO: load anyways into CACHE so we only match each at end of loop
# the files themselves should already be cached by base class caching of modules(python) # the files themselves should already be cached by base class caching of modules(python)
if key in (func_name, fq_name): if key in (func_name, fq_name):
pclass = self._load_jinja2_class() plugin = self._plugin_wrapper_type(func)
plugin = pclass(func)
if plugin: if plugin:
context = plugin_impl.plugin_load_context context = plugin_impl.plugin_load_context
self._update_object(plugin, src_name, plugin_impl.object._original_path, resolved=fq_name) self._update_object(plugin, src_name, plugin_impl.object._original_path, resolved=fq_name)
# FIXME: once we start caching these results, we'll be missing functions that would have loaded later
break # go to next file as it can override if dupe (don't break both loops) break # go to next file as it can override if dupe (don't break both loops)
except AnsiblePluginRemovedError as apre: except AnsiblePluginRemovedError as apre:
@ -1272,8 +1261,7 @@ class Jinja2Loader(PluginLoader):
return get_with_context_result(plugin, context) return get_with_context_result(plugin, context)
def all(self, *args, **kwargs): def all(self, *args, **kwargs):
kwargs.pop('_dedupe', None)
# inputs, we ignore 'dedupe' we always do, used in base class to find files for this one
path_only = kwargs.pop('path_only', False) path_only = kwargs.pop('path_only', False)
class_only = kwargs.pop('class_only', False) # basically ignored for test/filters since they are functions class_only = kwargs.pop('class_only', False) # basically ignored for test/filters since they are functions
@ -1281,9 +1269,19 @@ class Jinja2Loader(PluginLoader):
if path_only and class_only: if path_only and class_only:
raise AnsibleError('Do not set both path_only and class_only when calling PluginLoader.all()') raise AnsibleError('Do not set both path_only and class_only when calling PluginLoader.all()')
found = set() self._ensure_non_collection_wrappers(*args, **kwargs)
if path_only:
yield from (w._original_path for w in self._cached_non_collection_wrappers.values())
else:
yield from (w for w in self._cached_non_collection_wrappers.values())
def _ensure_non_collection_wrappers(self, *args, **kwargs):
if self._cached_non_collection_wrappers:
return
# get plugins from files in configured paths (multiple in each) # get plugins from files in configured paths (multiple in each)
for p_map in self._j2_all_file_maps(*args, **kwargs): for p_map in super(Jinja2Loader, self).all(*args, **kwargs):
is_builtin = p_map.ansible_name.startswith('ansible.builtin.')
# p_map is really object from file with class that holds multiple plugins # p_map is really object from file with class that holds multiple plugins
plugins_list = getattr(p_map, self.method_map_name) plugins_list = getattr(p_map, self.method_map_name)
@ -1294,57 +1292,35 @@ class Jinja2Loader(PluginLoader):
continue continue
for plugin_name in plugins.keys(): for plugin_name in plugins.keys():
if plugin_name in _PLUGIN_FILTERS[self.package]: if '.' in plugin_name:
display.debug("%s skipped due to a defined plugin filter" % plugin_name) display.debug(f'{plugin_name} skipped in {p_map._original_path}; Jinja plugin short names may not contain "."')
continue continue
if plugin_name in found: if plugin_name in _PLUGIN_FILTERS[self.package]:
display.debug("%s skipped as duplicate" % plugin_name) display.debug("%s skipped due to a defined plugin filter" % plugin_name)
continue continue
if path_only: # the plugin class returned by the loader may host multiple Jinja plugins, but we wrap each plugin in
result = p_map._original_path # its own surrogate wrapper instance here to ease the bookkeeping...
else: wrapper = self._plugin_wrapper_type(plugins[plugin_name])
# loader class is for the file with multiple plugins, but each plugin now has it's own class
pclass = self._load_jinja2_class()
result = pclass(plugins[plugin_name]) # if bad plugin, let exception rise
found.add(plugin_name)
fqcn = plugin_name fqcn = plugin_name
collection = '.'.join(p_map.ansible_name.split('.')[:2]) if p_map.ansible_name.count('.') >= 2 else '' collection = '.'.join(p_map.ansible_name.split('.')[:2]) if p_map.ansible_name.count('.') >= 2 else ''
if not plugin_name.startswith(collection): if not plugin_name.startswith(collection):
fqcn = f"{collection}.{plugin_name}" fqcn = f"{collection}.{plugin_name}"
self._update_object(result, plugin_name, p_map._original_path, resolved=fqcn) self._update_object(wrapper, plugin_name, p_map._original_path, resolved=fqcn)
yield result
def _load_jinja2_class(self):
""" override the normal method of plugin classname as these are used in the generic funciton
to access the 'multimap' of filter/tests to function, this is a 'singular' plugin for
each entry.
"""
class_name = 'AnsibleJinja2%s' % get_plugin_class(self.class_name).capitalize()
module = __import__(self.package, fromlist=[class_name])
return getattr(module, class_name)
def _j2_all_file_maps(self, *args, **kwargs): target_names = {plugin_name, fqcn}
""" if is_builtin:
* Unlike other plugin types, file != plugin, a file can contain multiple plugins (of same type). target_names.add(f'ansible.builtin.{plugin_name}')
This is why we do not deduplicate ansible file names at this point, we mostly care about
the names of the actual jinja2 plugins which are inside of our files.
* This method will NOT fetch collection plugin files, only those that would be expected under 'ansible.builtin/legacy'.
"""
# populate cache if needed
if not self._loaded_j2_file_maps:
# We don't deduplicate ansible file names. for target_name in target_names:
# Instead, calling code deduplicates jinja2 plugin names when loading each file. if existing_plugin := self._cached_non_collection_wrappers.get(target_name):
kwargs['_dedupe'] = False display.debug(f'Jinja plugin {target_name} from {p_map._original_path} skipped; '
f'shadowed by plugin from {existing_plugin._original_path})')
# To match correct precedence, call base class' all() to get a list of files, continue
self._loaded_j2_file_maps = list(super(Jinja2Loader, self).all(*args, **kwargs))
return self._loaded_j2_file_maps self._cached_non_collection_wrappers[target_name] = wrapper
def get_fqcr_and_name(resource, collection='ansible.builtin'): def get_fqcr_and_name(resource, collection='ansible.builtin'):
@ -1572,13 +1548,15 @@ filter_loader = Jinja2Loader(
'ansible.plugins.filter', 'ansible.plugins.filter',
C.DEFAULT_FILTER_PLUGIN_PATH, C.DEFAULT_FILTER_PLUGIN_PATH,
'filter_plugins', 'filter_plugins',
AnsibleJinja2Filter
) )
test_loader = Jinja2Loader( test_loader = Jinja2Loader(
'TestModule', 'TestModule',
'ansible.plugins.test', 'ansible.plugins.test',
C.DEFAULT_TEST_PLUGIN_PATH, C.DEFAULT_TEST_PLUGIN_PATH,
'test_plugins' 'test_plugins',
AnsibleJinja2Test
) )
strategy_loader = PluginLoader( strategy_loader = PluginLoader(

@ -445,11 +445,11 @@ class JinjaPluginIntercept(MutableMapping):
self._pluginloader = pluginloader self._pluginloader = pluginloader
# cache of resolved plugins # Jinja environment's mapping of known names (initially just J2 builtins)
self._delegatee = delegatee self._delegatee = delegatee
# track loaded plugins here as cache above includes 'jinja2' filters but ours should override # our names take precedence over Jinja's, but let things we've tried to resolve skip the pluginloader
self._loaded_builtins = set() self._seen_it = set()
def __getitem__(self, key): def __getitem__(self, key):
@ -457,7 +457,10 @@ class JinjaPluginIntercept(MutableMapping):
raise ValueError('key must be a string, got %s instead' % type(key)) raise ValueError('key must be a string, got %s instead' % type(key))
original_exc = None original_exc = None
if key not in self._loaded_builtins: if key not in self._seen_it:
# this looks too early to set this- it isn't. Setting it here keeps requests for Jinja builtins from
# going through the pluginloader more than once, which is extremely slow for something that won't ever succeed.
self._seen_it.add(key)
plugin = None plugin = None
try: try:
plugin = self._pluginloader.get(key) plugin = self._pluginloader.get(key)
@ -471,12 +474,12 @@ class JinjaPluginIntercept(MutableMapping):
if plugin: if plugin:
# set in filter cache and avoid expensive plugin load # set in filter cache and avoid expensive plugin load
self._delegatee[key] = plugin.j2_function self._delegatee[key] = plugin.j2_function
self._loaded_builtins.add(key)
# raise template syntax error if we could not find ours or jinja2 one # raise template syntax error if we could not find ours or jinja2 one
try: try:
func = self._delegatee[key] func = self._delegatee[key]
except KeyError as e: except KeyError as e:
self._seen_it.remove(key)
raise TemplateSyntaxError('Could not load "%s": %s' % (key, to_native(original_exc or e)), 0) raise TemplateSyntaxError('Could not load "%s": %s' % (key, to_native(original_exc or e)), 0)
# if i do have func and it is a filter, it needs wrapping # if i do have func and it is a filter, it needs wrapping

@ -151,8 +151,8 @@
- assert: - assert:
that: that:
- | # FUTURE: ensure that the warning was also issued with the actual failure details
'This is a broken filter plugin.' in result.msg - result is failed
- debug: - debug:
msg: "{{ 'foo'|missing.collection.filter }}" msg: "{{ 'foo'|missing.collection.filter }}"

Loading…
Cancel
Save