Override from_string, instead of sprinkling code in do_template

pull/82224/head
Matt Martz 1 year ago
parent 5030a530d6
commit d926de5906
No known key found for this signature in database
GPG Key ID: 40832D88E9FC91D8

@ -30,7 +30,9 @@ from contextlib import contextmanager
from numbers import Number
from traceback import format_exc
from jinja2 import nodes
from jinja2.bccache import FileSystemBytecodeCache
from jinja2.environment import Template
from jinja2.exceptions import TemplateSyntaxError, UndefinedError
from jinja2.loaders import FileSystemLoader
from jinja2.nativetypes import NativeEnvironment
@ -48,6 +50,7 @@ from ansible.errors import (
from ansible.module_utils.six import string_types
from ansible.module_utils.common.text.converters import to_native, to_text, to_bytes
from ansible.module_utils.common.collections import is_sequence
from ansible.module_utils.compat import typing as t
from ansible.plugins.loader import filter_loader, lookup_loader, test_loader
from ansible.template.native_helpers import ansible_native_concat, ansible_eval_concat, ansible_concat
from ansible.template.template import AnsibleJ2Template
@ -539,6 +542,11 @@ class AnsibleEnvironment(NativeEnvironment):
context_class = AnsibleContext
template_class = AnsibleJ2Template
concat = staticmethod(ansible_eval_concat) # type: ignore[assignment]
_overrides = (
'autoescape', 'block_end_string', 'block_start_string', 'comment_end_string', 'comment_start_string',
'keep_trailing_newline', 'line_comment_prefix', 'line_statement_prefix', 'lstrip_blocks', 'newline_sequence',
'trim_blocks', 'variable_end_string', 'variable_start_string',
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -551,6 +559,32 @@ class AnsibleEnvironment(NativeEnvironment):
self.undefined = AnsibleUndefined
self.finalize = _ansible_finalize
def _make_bucket_name(self, source):
override = ';'.join(
f'{k}={getattr(self, k, None)!r}' for k in self._overrides
) + f';native={isinstance(self, AnsibleNativeEnvironment)};'
return f'{source}[overrides={override}]'
def from_string(
self,
source: str, # type: ignore[override]
globals: t.MutableMapping[str, t.Any] | None = None,
template_class: t.Type[Template] | None = None,
) -> Template:
cache_dir = os.path.join(C.DEFAULT_LOCAL_TMP, 'j2cache')
if not os.path.isdir(cache_dir):
os.makedirs(cache_dir, mode=0o700, exist_ok=True)
bcc = FileSystemBytecodeCache(cache_dir, '__ansible_j2_%s.cache')
bucket = bcc.get_bucket(self, self._make_bucket_name(source), None, source)
if not bucket.code:
bucket.code = self.compile(source)
bcc.set_bucket(bucket)
cls = template_class or self.template_class
return cls.from_code(self, bucket.code, self.globals, None)
class AnsibleNativeEnvironment(AnsibleEnvironment):
concat = staticmethod(ansible_native_concat) # type: ignore[assignment]
@ -593,8 +627,6 @@ class Templar:
self.jinja2_native = C.DEFAULT_JINJA2_NATIVE
self._cache_pattern = '__ansible_j2_%s.cache'
def _compile_single_var(self, env):
self.SINGLE_VAR = re.compile(r"^%s\s*(\w*)\s*%s$" % (env.variable_start_string, env.variable_end_string))
@ -923,32 +955,11 @@ class Templar:
hint = "Mandatory variable has not been overridden"
return AnsibleUndefined(hint)
@staticmethod
def _make_bucket_name(data, overrides, environment):
name = data
if overrides:
combined = overrides | {}
if environment.newline_sequence != '\n':
combined |= {'newline_sequence': environment.newline_sequence}
override_str = ';'.join(
f'{k}={v}' for k, v in sorted(combined.items())
)
name += f'[{override_str}]'
return name
def do_template(self, data, preserve_trailing_newlines=True, escape_backslashes=True, fail_on_undefined=None, overrides=None, disable_lookups=False,
convert_data=False):
if self.jinja2_native and not isinstance(data, string_types):
return data
mtime = None
if (data_source := getattr(data, '_data_source', None)):
mtime = os.stat(data_source).st_mtime
cache_dir = os.path.join(C.DEFAULT_LOCAL_TMP, 'j2cache')
if not os.path.isdir(cache_dir):
os.makedirs(cache_dir, mode=0o700, exist_ok=True)
bcc = FileSystemBytecodeCache(cache_dir, self._cache_pattern)
# For preserving the number of input newlines in the output (used
# later in this method)
data_newlines = _count_newlines_from_end(data)
@ -967,26 +978,8 @@ class Templar:
# Allow users to specify backslashes in playbooks as "\\" instead of as "\\\\".
data = _escape_backslashes(data, myenv)
bucket = bcc.get_bucket(
myenv,
self._make_bucket_name(data, overrides, myenv),
None,
data
)
cache_file = os.path.join(bcc.directory, self._cache_pattern % bucket.key)
if bucket.code and mtime and mtime > os.stat(cache_file).st_mtime:
# Why would this ever be the case, I really hope users aren't modifying things
# in the middle of a run taht could cause this
os.unlink(cache_file)
bucket.code = None
try:
if bucket.code:
bcc.load_bytecode(bucket)
else:
bucket.code = myenv.compile(data)
bcc.set_bucket(bucket)
t = myenv.template_class.from_code(myenv, bucket.code, myenv.globals, None)
t = myenv.from_string(data)
except TemplateSyntaxError as e:
raise AnsibleError("template error while templating string: %s. String: %s" % (to_native(e), to_native(data)), orig_exc=e)
except Exception as e:

Loading…
Cancel
Save