Move resolving parent attributes to FA + cleanup

ci_complete
pull/86158/head
Martin Krizek 3 weeks ago
parent 6bb7bd760f
commit 45d777f6ff

@ -0,0 +1,2 @@
minor_changes:
- internals - simplify ``dep_chain`` handling in playbook objects.

@ -77,8 +77,6 @@ class TaskExecutor:
self._loop_eval_error = None self._loop_eval_error = None
self._task_templar = TemplateEngine(loader=self._loader, variables=self._job_vars) self._task_templar = TemplateEngine(loader=self._loader, variables=self._job_vars)
self._task.squash()
def run(self): def run(self):
""" """
The main executor entrypoint, where we determine if the specified The main executor entrypoint, where we determine if the specified
@ -369,7 +367,6 @@ class TaskExecutor:
if self._task.loop_control and self._task.loop_control.break_when: if self._task.loop_control and self._task.loop_control.break_when:
break_when = self._task.loop_control.get_validated_value( break_when = self._task.loop_control.get_validated_value(
'break_when', 'break_when',
self._task.loop_control.fattributes.get('break_when'),
self._task.loop_control.break_when, self._task.loop_control.break_when,
templar, templar,
) )

@ -90,41 +90,13 @@ class Attribute:
def __set_name__(self, owner, name): def __set_name__(self, owner, name):
self.name = name self.name = name
def __eq__(self, other):
return other.priority == self.priority
def __ne__(self, other):
return other.priority != self.priority
# NB: higher priority numbers sort first # NB: higher priority numbers sort first
# __lt__ is sufficient for sorted() which is our only use case
def __lt__(self, other): def __lt__(self, other):
return other.priority < self.priority return other.priority < self.priority
def __gt__(self, other):
return other.priority > self.priority
def __le__(self, other):
return other.priority <= self.priority
def __ge__(self, other):
return other.priority >= self.priority
def __get__(self, obj: FieldAttributeBase, obj_type=None): def __get__(self, obj: FieldAttributeBase, obj_type=None):
method = f'_get_attr_{self.name}' if (value := getattr(obj, f'_{self.name}', Sentinel)) is Sentinel:
if hasattr(obj, method):
# NOTE this appears to be not used in the codebase,
# _get_attr_connection has been replaced by ConnectionFieldAttribute.
# Leaving it here for test_attr_method from
# test/units/playbook/test_base.py to pass and for backwards compat.
if getattr(obj, '_squashed', False):
value = getattr(obj, f'_{self.name}', Sentinel)
else:
value = getattr(obj, method)()
else:
value = getattr(obj, f'_{self.name}', Sentinel)
if value is Sentinel:
value = self.default value = self.default
if callable(value): if callable(value):
value = value() value = value()
@ -137,17 +109,30 @@ class Attribute:
if self.alias is not None: if self.alias is not None:
setattr(obj, f'_{self.alias}', value) setattr(obj, f'_{self.alias}', value)
# NOTE this appears to be not needed in the codebase,
# leaving it here for test_attr_int_del from
# test/units/playbook/test_base.py to pass.
def __delete__(self, obj):
delattr(obj, f'_{self.name}')
class NonInheritableFieldAttribute(Attribute): class NonInheritableFieldAttribute(Attribute):
... ...
def _get_parent_static_chain(obj):
# NOTE similar code in Taggable
# FIXME this encapsulates the mess caused by not having proper parent chain on playbook objects
parent = getattr(obj, '_parent', None)
while parent:
# If parent is static, we can grab attrs from the parent
# otherwise, defer to the grandparent
if getattr(parent, 'statically_loaded', True):
yield parent
parent = getattr(parent, '_parent', None)
if role := getattr(obj, '_role', None):
yield obj._role
if dep_chain := obj.get_dep_chain():
yield from reversed(dep_chain)
yield obj.play
class FieldAttribute(Attribute): class FieldAttribute(Attribute):
def __init__(self, extend=False, prepend=False, **kwargs): def __init__(self, extend=False, prepend=False, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
@ -156,24 +141,27 @@ class FieldAttribute(Attribute):
self.prepend = prepend self.prepend = prepend
def __get__(self, obj, obj_type=None): def __get__(self, obj, obj_type=None):
if getattr(obj, '_squashed', False) or getattr(obj, '_finalized', False):
value = getattr(obj, f'_{self.name}', Sentinel)
else:
try:
value = obj._get_parent_attribute(self.name)
except AttributeError:
method = f'_get_attr_{self.name}'
if hasattr(obj, method):
# NOTE this appears to be not needed in the codebase,
# _get_attr_connection has been replaced by ConnectionFieldAttribute.
# Leaving it here for test_attr_method from
# test/units/playbook/test_base.py to pass and for backwards compat.
if getattr(obj, '_squashed', False):
value = getattr(obj, f'_{self.name}', Sentinel) value = getattr(obj, f'_{self.name}', Sentinel)
if not obj.finalized:
if self.extend:
# NOTE implements the historic behavior of now removed Base._extend_value
all_values = value if value not in (Sentinel, None) else []
_extend = all_values.extend
for parent in _get_parent_static_chain(obj):
if (parent_value := getattr(parent, f'_{self.name}', Sentinel)) not in (Sentinel, None):
if self.prepend:
all_values[:0] = parent_value
else: else:
value = getattr(obj, method)() _extend(parent_value)
else:
value = getattr(obj, f'_{self.name}', Sentinel) # some FAs contain non-hashable values, skip dedup in such a case
# FIXME deal with list of dicts differently?
value = all_values if self.name in ('module_defaults', 'environment') else list(dict.fromkeys(all_values))
elif value is Sentinel:
for parent in _get_parent_static_chain(obj):
if (parent_value := getattr(parent, f'_{self.name}', Sentinel)) is not Sentinel:
value = parent_value
break
if value is Sentinel: if value is Sentinel:
value = self.default value = self.default

@ -5,7 +5,6 @@
from __future__ import annotations from __future__ import annotations
import decimal import decimal
import itertools
import operator import operator
import os import os
@ -30,6 +29,9 @@ from ansible.utils.display import Display
from ansible.utils.vars import combine_vars, get_unique_id, validate_variable_name from ansible.utils.vars import combine_vars, get_unique_id, validate_variable_name
from ansible._internal._templating._engine import TemplateEngine from ansible._internal._templating._engine import TemplateEngine
if t.TYPE_CHECKING:
from ansible.playbook.role import Role
display = Display() display = Display()
@ -112,13 +114,13 @@ class FieldAttributeBase:
self._origin: Origin | None = None self._origin: Origin | None = None
# other internal params # other internal params
self._validated = False
self._squashed = False
self._finalized = False self._finalized = False
# every object gets a random uuid: # every object gets a random uuid:
self._uuid = get_unique_id() self._uuid = get_unique_id()
self._ds = None
@property @property
def finalized(self): def finalized(self):
return self._finalized return self._finalized
@ -130,9 +132,7 @@ class FieldAttributeBase:
display.debug("%s- %s (%s, id=%s)" % (" " * depth, self.__class__.__name__, self, id(self))) display.debug("%s- %s (%s, id=%s)" % (" " * depth, self.__class__.__name__, self, id(self)))
if hasattr(self, '_parent') and self._parent: if hasattr(self, '_parent') and self._parent:
self._parent.dump_me(depth + 2) self._parent.dump_me(depth + 2)
dep_chain = self._parent.get_dep_chain() for dep in self._parent.get_dep_chain():
if dep_chain:
for dep in dep_chain:
dep.dump_me(depth + 2) dep.dump_me(depth + 2)
if hasattr(self, '_play') and self._play: if hasattr(self, '_play') and self._play:
self._play.dump_me(depth + 2) self._play.dump_me(depth + 2)
@ -148,7 +148,7 @@ class FieldAttributeBase:
raise AnsibleAssertionError('ds (%s) should not be None but it is.' % ds) raise AnsibleAssertionError('ds (%s) should not be None but it is.' % ds)
# cache the datastructure internally # cache the datastructure internally
setattr(self, '_ds', ds) self._ds = ds
# the variable manager class is used to manage and merge variables # the variable manager class is used to manage and merge variables
# down to a single dictionary for reference in templating, etc. # down to a single dictionary for reference in templating, etc.
@ -185,10 +185,7 @@ class FieldAttributeBase:
return self return self
def get_ds(self): def get_ds(self):
try: return self._ds
return getattr(self, '_ds')
except AttributeError:
return None
def get_loader(self): def get_loader(self):
return self._loader return self._loader
@ -218,26 +215,14 @@ class FieldAttributeBase:
if key not in valid_attrs: if key not in valid_attrs:
raise AnsibleParserError("'%s' is not a valid attribute for a %s" % (key, self.__class__.__name__), obj=key) raise AnsibleParserError("'%s' is not a valid attribute for a %s" % (key, self.__class__.__name__), obj=key)
def validate(self, all_vars=None): def validate(self):
""" validation that is done at parse time, not load time """ """ validation that is done at parse time, not load time """
if not self._validated: for name, attribute in self.fattributes.items():
# walk all fields in the object if (value := getattr(self, f'_{name}', Sentinel)) is Sentinel:
for (name, attribute) in self.fattributes.items(): continue
# run validator only if present
method = getattr(self, '_validate_%s' % name, None)
if method:
method(attribute, name, getattr(self, name))
else:
# and make sure the attribute is of the type it should be
value = getattr(self, f'_{name}', Sentinel)
if value is not None:
if attribute.isa == 'string' and isinstance(value, (list, dict)):
raise AnsibleParserError(
"The field '%s' is supposed to be a string type,"
" however the incoming data structure is a %s" % (name, type(value)), obj=self.get_ds()
)
self._validated = True if method := getattr(self, f'_validate_{name}', None):
method(attribute, name, value)
def _load_module_defaults(self, name, value): def _load_module_defaults(self, name, value):
if value is None: if value is None:
@ -297,8 +282,8 @@ class FieldAttributeBase:
@property @property
def play(self): def play(self):
if hasattr(self, '_play'): if _play := getattr(self, '_play', None):
play = self._play play = _play
elif hasattr(self, '_parent') and hasattr(self._parent, '_play'): elif hasattr(self, '_parent') and hasattr(self._parent, '_play'):
play = self._parent._play play = self._parent._play
else: else:
@ -410,17 +395,6 @@ class FieldAttributeBase:
raise AnsibleParserError("Could not resolve action %s in module_defaults" % action_name) raise AnsibleParserError("Could not resolve action %s in module_defaults" % action_name)
display.vvvvv("Could not resolve action %s in module_defaults" % action_name) display.vvvvv("Could not resolve action %s in module_defaults" % action_name)
def squash(self):
"""
Evaluates all attributes and sets them to the evaluated version,
so that all future accesses of attributes do not need to evaluate
parent attributes.
"""
if not self._squashed:
for name in self.fattributes:
setattr(self, name, getattr(self, name))
self._squashed = True
def copy(self): def copy(self):
""" """
Create a copy of this object and return it. Create a copy of this object and return it.
@ -437,17 +411,17 @@ class FieldAttributeBase:
new_me._loader = self._loader new_me._loader = self._loader
new_me._variable_manager = self._variable_manager new_me._variable_manager = self._variable_manager
new_me._origin = self._origin new_me._origin = self._origin
new_me._validated = self._validated
new_me._finalized = self._finalized new_me._finalized = self._finalized
new_me._uuid = self._uuid new_me._uuid = self._uuid
# if the ds value was set on the object, copy it to the new copy too # if the ds value was set on the object, copy it to the new copy too
if hasattr(self, '_ds'): if _ds := self.get_ds():
new_me._ds = self._ds new_me._ds = _ds
return new_me return new_me
def get_validated_value(self, name, attribute, value, templar): def get_validated_value(self, name: str, value: object, templar: TemplateEngine):
attribute: Attribute = self.fattributes[name]
try: try:
return self._get_validated_value(name, attribute, value, templar) return self._get_validated_value(name, attribute, value, templar)
except (TypeError, ValueError): except (TypeError, ValueError):
@ -455,6 +429,13 @@ class FieldAttributeBase:
def _get_validated_value(self, name, attribute, value, templar): def _get_validated_value(self, name, attribute, value, templar):
if attribute.isa == 'string': if attribute.isa == 'string':
if isinstance(value, (list, dict)):
# NOTE historically this check has been in validate()
raise AnsibleParserError(
message=f"The field {name!r} is supposed to be a string type, "
f"however the incoming data structure is a {type(value)}",
obj=self.get_ds(),
)
value = to_text(value) value = to_text(value)
elif attribute.isa == 'int': elif attribute.isa == 'int':
if not isinstance(value, int): if not isinstance(value, int):
@ -509,20 +490,8 @@ class FieldAttributeBase:
def set_to_context(self, name: str) -> t.Any: def set_to_context(self, name: str) -> t.Any:
""" set to parent inherited value or Sentinel as appropriate""" """ set to parent inherited value or Sentinel as appropriate"""
setattr(self, name, Sentinel)
attribute = self.fattributes[name] return getattr(self, name, Sentinel)
if isinstance(attribute, NonInheritableFieldAttribute):
# setting to sentinel will trigger 'default/default()' on getter
value = Sentinel
else:
try:
value = self._get_parent_attribute(name, omit=True)
except AttributeError:
# mostly playcontext as only tasks/handlers/blocks really resolve parent
value = Sentinel
setattr(self, name, value)
return value
def post_validate(self, templar): def post_validate(self, templar):
""" """
@ -541,33 +510,23 @@ class FieldAttributeBase:
self._finalized = True self._finalized = True
def post_validate_attribute(self, name: str, *, templar: TemplateEngine): def post_validate_attribute(self, name: str, *, templar: TemplateEngine):
attribute: FieldAttribute = self.fattributes[name] attribute: Attribute = self.fattributes[name]
# DTFIX-FUTURE: this can probably be used in many getattr cases below, but the value may be out-of-date in some cases
original_value = getattr(self, name) # we save this original (likely Origin-tagged) value to pass as `obj` for errors original_value = getattr(self, name) # we save this original (likely Origin-tagged) value to pass as `obj` for errors
if attribute.static: if attribute.static:
value = getattr(self, name)
# we don't template 'vars' but allow template as values for later use # we don't template 'vars' but allow template as values for later use
if name not in ('vars',) and templar.is_template(value): if name not in ('vars',) and templar.is_template(original_value):
display.warning('"%s" is not templatable, but we found: %s, ' display.warning('"%s" is not templatable, but we found: %s, '
'it will not be templated and will be used "as is".' % (name, value)) 'it will not be templated and will be used "as is".' % (name, original_value))
return Sentinel return Sentinel
if getattr(self, name) is None: if original_value is None:
if not attribute.required: if not attribute.required:
return Sentinel return Sentinel
raise AnsibleFieldAttributeError(f'The field {name!r} is required but was not set.', obj=self.get_ds()) raise AnsibleFieldAttributeError(f'The field {name!r} is required but was not set.', obj=self.get_ds())
from .role_include import IncludeRole
if not attribute.always_post_validate and isinstance(self, IncludeRole) and self.statically_loaded: # import_role
# normal field attributes should not go through post validation on import_role/import_tasks
# only import_role is checked here because import_tasks never reaches this point
return Sentinel
# Skip post validation unless always_post_validate is True, or the object requires post validation. # Skip post validation unless always_post_validate is True, or the object requires post validation.
if not attribute.always_post_validate and not self._post_validate_object: if not attribute.always_post_validate and not self._post_validate_object:
# Intermediate objects like Play() won't have their fields validated by # Intermediate objects like Play() won't have their fields validated by
@ -581,13 +540,13 @@ class FieldAttributeBase:
method = getattr(self, '_post_validate_%s' % name, None) method = getattr(self, '_post_validate_%s' % name, None)
if method: if method:
value = method(attribute, getattr(self, name), templar) value = method(attribute, original_value, templar)
elif attribute.isa == 'class': elif attribute.isa == 'class':
value = getattr(self, name) value = original_value
else: else:
try: try:
# if the attribute contains a variable, template it now # if the attribute contains a variable, template it now
value = templar.template(getattr(self, name)) value = templar.template(original_value)
except AnsibleValueOmittedError: except AnsibleValueOmittedError:
# If this evaluated to the omit value, set the value back to inherited by context # If this evaluated to the omit value, set the value back to inherited by context
# or default specified in the FieldAttribute and move on # or default specified in the FieldAttribute and move on
@ -598,7 +557,7 @@ class FieldAttributeBase:
# and make sure the attribute is of the type it should be # and make sure the attribute is of the type it should be
if value is not None: if value is not None:
value = self.get_validated_value(name, attribute, value, templar) value = self.get_validated_value(name, value, templar)
# returning the value results in assigning the massaged value back to the attribute field # returning the value results in assigning the massaged value back to the attribute field
return value return value
@ -627,31 +586,6 @@ class FieldAttributeBase:
except TypeError as ex: except TypeError as ex:
raise AnsibleParserError(f"Invalid variable name in vars specified for {self.__class__.__name__}.", obj=ds) from ex raise AnsibleParserError(f"Invalid variable name in vars specified for {self.__class__.__name__}.", obj=ds) from ex
def _extend_value(self, value, new_value, prepend=False):
"""
Will extend the value given with new_value (and will turn both
into lists if they are not so already). The values are run through
a set to remove duplicate values.
"""
if not isinstance(value, list):
value = [value]
if not isinstance(new_value, list):
new_value = [new_value]
# Due to where _extend_value may run for some attributes
# it is possible to end up with Sentinel in the list of values
# ensure we strip them
value = [v for v in value if v is not Sentinel]
new_value = [v for v in new_value if v is not Sentinel]
if prepend:
combined = new_value + value
else:
combined = value + new_value
return [i for i, dummy in itertools.groupby(combined) if i is not None]
def dump_attrs(self): def dump_attrs(self):
""" """
Dumps all attributes to a dictionary Dumps all attributes to a dictionary
@ -719,8 +653,9 @@ class Base(FieldAttributeBase):
become_flags = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_flags')) become_flags = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_flags'))
become_exe = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_exe')) become_exe = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_exe'))
# used to hold sudo/su stuff def _validate_environment(self, attr, name, value):
DEPRECATED_ATTRIBUTES = [] # type: list[str] if not isinstance(value, list):
setattr(self, name, [value])
def update_result_no_log(self, templar: TemplateEngine, result: dict[str, t.Any]) -> None: def update_result_no_log(self, templar: TemplateEngine, result: dict[str, t.Any]) -> None:
"""Set the post-validated no_log value for the result, falling back to a default on validation/templating failure with a warning.""" """Set the post-validated no_log value for the result, falling back to a default on validation/templating failure with a warning."""
@ -761,12 +696,11 @@ class Base(FieldAttributeBase):
return path return path
def get_dep_chain(self): def get_dep_chain(self) -> list[Role]:
if role := getattr(self, '_role', None):
if hasattr(self, '_parent') and self._parent: return role.get_dep_chain() + [role]
return self._parent.get_dep_chain()
else: else:
return None return []
def get_search_path(self): def get_search_path(self):
""" """
@ -775,9 +709,8 @@ class Base(FieldAttributeBase):
""" """
path_stack = [] path_stack = []
dep_chain = self.get_dep_chain()
# inside role: add the dependency chain from current to dependent # inside role: add the dependency chain from current to dependent
if dep_chain: if dep_chain := self.get_dep_chain():
path_stack.extend(reversed([x._role_path for x in dep_chain if hasattr(x, '_role_path')])) path_stack.extend(reversed([x._role_path for x in dep_chain if hasattr(x, '_role_path')]))
# add path of task itself, unless it is already in the list # add path of task itself, unless it is already in the list

@ -18,7 +18,6 @@
from __future__ import annotations from __future__ import annotations
from ansible.errors import AnsibleParserError from ansible.errors import AnsibleParserError
from ansible.module_utils.common.sentinel import Sentinel
from ansible.playbook.attribute import NonInheritableFieldAttribute from ansible.playbook.attribute import NonInheritableFieldAttribute
from ansible.playbook.base import Base from ansible.playbook.base import Base
from ansible.playbook.conditional import Conditional from ansible.playbook.conditional import Conditional
@ -40,13 +39,11 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab
# similar to the 'else' clause for exceptions # similar to the 'else' clause for exceptions
# otherwise = FieldAttribute(isa='list') # otherwise = FieldAttribute(isa='list')
def __init__(self, play=None, parent_block=None, role=None, task_include=None, use_handlers=False, implicit=False): def __init__(self, play=None, parent_block=None, role=None, task_include=None, use_handlers=False):
self._play = play self._play = play
self._role = role self._role = role
self._parent = None self._parent = None
self._dep_chain = None
self._use_handlers = use_handlers self._use_handlers = use_handlers
self._implicit = implicit
if task_include: if task_include:
self._parent = task_include self._parent = task_include
@ -77,25 +74,18 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab
if self._parent: if self._parent:
all_vars |= self._parent.get_vars() all_vars |= self._parent.get_vars()
all_vars |= self.vars.copy() all_vars |= self.vars
return all_vars return all_vars
@staticmethod @staticmethod
def load(data, play=None, parent_block=None, role=None, task_include=None, use_handlers=False, variable_manager=None, loader=None): def load(data, play=None, parent_block=None, role=None, task_include=None, use_handlers=False, variable_manager=None, loader=None):
implicit = not Block.is_block(data) b = Block(play=play, parent_block=parent_block, role=role, task_include=task_include, use_handlers=use_handlers)
b = Block(play=play, parent_block=parent_block, role=role, task_include=task_include, use_handlers=use_handlers, implicit=implicit)
return b.load_data(data, variable_manager=variable_manager, loader=loader) return b.load_data(data, variable_manager=variable_manager, loader=loader)
@staticmethod @staticmethod
def is_block(ds): def is_block(ds):
is_block = False return isinstance(ds, dict) and 'block' in ds
if isinstance(ds, dict):
for attr in ('block', 'rescue', 'always'):
if attr in ds:
is_block = True
break
return is_block
def preprocess_data(self, ds): def preprocess_data(self, ds):
""" """
@ -111,8 +101,6 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab
return super(Block, self).preprocess_data(ds) return super(Block, self).preprocess_data(ds)
# FIXME: these do nothing but augment the exception message; DRY and nuke
def _load_block(self, attr, ds): def _load_block(self, attr, ds):
try: try:
return load_list_of_tasks( return load_list_of_tasks(
@ -126,37 +114,9 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab
use_handlers=self._use_handlers, use_handlers=self._use_handlers,
) )
except AssertionError as ex: except AssertionError as ex:
raise AnsibleParserError("A malformed block was encountered while loading a block", obj=self._ds) from ex raise AnsibleParserError(f"A malformed block was encountered while loading a {attr}", obj=self._ds) from ex
def _load_rescue(self, attr, ds): _load_rescue = _load_always = _load_block
try:
return load_list_of_tasks(
ds,
play=self._play,
block=self,
role=self._role,
task_include=None,
variable_manager=self._variable_manager,
loader=self._loader,
use_handlers=self._use_handlers,
)
except AssertionError as ex:
raise AnsibleParserError("A malformed block was encountered while loading rescue.", obj=self._ds) from ex
def _load_always(self, attr, ds):
try:
return load_list_of_tasks(
ds,
play=self._play,
block=self,
role=self._role,
task_include=None,
variable_manager=self._variable_manager,
loader=self._loader,
use_handlers=self._use_handlers,
)
except AssertionError as ex:
raise AnsibleParserError("A malformed block was encountered while loading always", obj=self._ds) from ex
def _validate_always(self, attr, name, value): def _validate_always(self, attr, name, value):
if value and not self.block: if value and not self.block:
@ -164,15 +124,6 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab
_validate_rescue = _validate_always _validate_rescue = _validate_always
def get_dep_chain(self):
if self._dep_chain is None:
if self._parent:
return self._parent.get_dep_chain()
else:
return None
else:
return self._dep_chain[:]
def copy(self, exclude_parent=False, exclude_tasks=False): def copy(self, exclude_parent=False, exclude_tasks=False):
def _dupe_task_list(task_list, new_block): def _dupe_task_list(task_list, new_block):
new_task_list = [] new_task_list = []
@ -199,9 +150,6 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab
new_me._play = self._play new_me._play = self._play
new_me._use_handlers = self._use_handlers new_me._use_handlers = self._use_handlers
if self._dep_chain is not None:
new_me._dep_chain = self._dep_chain[:]
new_me._parent = None new_me._parent = None
if self._parent and not exclude_parent: if self._parent and not exclude_parent:
new_me._parent = self._parent.copy(exclude_tasks=True) new_me._parent = self._parent.copy(exclude_tasks=True)
@ -211,11 +159,8 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab
new_me.rescue = _dupe_task_list(self.rescue or [], new_me) new_me.rescue = _dupe_task_list(self.rescue or [], new_me)
new_me.always = _dupe_task_list(self.always or [], new_me) new_me.always = _dupe_task_list(self.always or [], new_me)
new_me._role = None
if self._role:
new_me._role = self._role new_me._role = self._role
new_me.validate()
return new_me return new_me
def set_loader(self, loader): def set_loader(self, loader):
@ -225,84 +170,9 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab
elif self._role: elif self._role:
self._role.set_loader(loader) self._role.set_loader(loader)
dep_chain = self.get_dep_chain() for dep in self.get_dep_chain():
if dep_chain:
for dep in dep_chain:
dep.set_loader(loader) dep.set_loader(loader)
def _get_parent_attribute(self, attr, omit=False):
"""
Generic logic to get the attribute or parent attribute for a block value.
"""
fattr = self.fattributes[attr]
extend = fattr.extend
prepend = fattr.prepend
try:
# omit self, and only get parent values
if omit:
value = Sentinel
else:
value = getattr(self, f'_{attr}', Sentinel)
# If parent is static, we can grab attrs from the parent
# otherwise, defer to the grandparent
if getattr(self._parent, 'statically_loaded', True):
_parent = self._parent
else:
_parent = self._parent._parent
if _parent and (value is Sentinel or extend):
try:
if getattr(_parent, 'statically_loaded', True):
if hasattr(_parent, '_get_parent_attribute'):
parent_value = _parent._get_parent_attribute(attr)
else:
parent_value = getattr(_parent, f'_{attr}', Sentinel)
if extend:
value = self._extend_value(value, parent_value, prepend)
else:
value = parent_value
except AttributeError:
pass
if self._role and (value is Sentinel or extend):
try:
parent_value = getattr(self._role, f'_{attr}', Sentinel)
if extend:
value = self._extend_value(value, parent_value, prepend)
else:
value = parent_value
dep_chain = self.get_dep_chain()
if dep_chain and (value is Sentinel or extend):
dep_chain.reverse()
for dep in dep_chain:
dep_value = getattr(dep, f'_{attr}', Sentinel)
if extend:
value = self._extend_value(value, dep_value, prepend)
else:
value = dep_value
if value is not Sentinel and not extend:
break
except AttributeError:
pass
if self._play and (value is Sentinel or extend):
try:
play_value = getattr(self._play, f'_{attr}', Sentinel)
if play_value is not Sentinel:
if extend:
value = self._extend_value(value, play_value, prepend)
else:
value = play_value
except AttributeError:
pass
except KeyError:
pass
return value
def filter_tagged_tasks(self, all_vars): def filter_tagged_tasks(self, all_vars):
""" """
Creates a new block, with task lists filtered based on the tags. Creates a new block, with task lists filtered based on the tags.
@ -348,7 +218,7 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab
return evaluate_block(self) return evaluate_block(self)
def has_tasks(self): def has_tasks(self):
return len(self.block) > 0 or len(self.rescue) > 0 or len(self.always) > 0 return bool(self.block or self.rescue or self.always)
def get_include_params(self): def get_include_params(self):
if self._parent: if self._parent:
@ -356,21 +226,6 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab
else: else:
return dict() return dict()
def all_parents_static(self):
"""
Determine if all of the parents of this block were statically loaded
or not. Since Task/TaskInclude objects may be in the chain, they simply
call their parents all_parents_static() method. Only Block objects in
the chain check the statically_loaded value of the parent.
"""
from ansible.playbook.task_include import TaskInclude
if self._parent:
if isinstance(self._parent, TaskInclude) and not self._parent.statically_loaded:
return False
return self._parent.all_parents_static()
return True
def get_first_parent_include(self): def get_first_parent_include(self):
from ansible.playbook.task_include import TaskInclude from ansible.playbook.task_include import TaskInclude
if self._parent: if self._parent:

@ -5,9 +5,6 @@ from __future__ import annotations
from ansible.playbook.attribute import FieldAttribute from ansible.playbook.attribute import FieldAttribute
from ansible.utils.collection_loader import AnsibleCollectionConfig from ansible.utils.collection_loader import AnsibleCollectionConfig
from ansible.utils.display import Display
display = Display()
def _ensure_default_collection(collection_list=None): def _ensure_default_collection(collection_list=None):
@ -36,7 +33,7 @@ class CollectionSearch:
def _load_collections(self, attr, ds): def _load_collections(self, attr, ds):
# We are always a mixin with Base, so we can validate this untemplated # We are always a mixin with Base, so we can validate this untemplated
# field early on to guarantee we are dealing with a list. # field early on to guarantee we are dealing with a list.
ds = self.get_validated_value('collections', self.fattributes.get('collections'), ds, None) ds = self.get_validated_value('collections', ds, None)
# this will only be called if someone specified a value; call the shared value # this will only be called if someone specified a value; call the shared value
_ensure_default_collection(collection_list=ds) _ensure_default_collection(collection_list=ds)

@ -18,9 +18,6 @@
from __future__ import annotations from __future__ import annotations
from ansible.playbook.attribute import FieldAttribute from ansible.playbook.attribute import FieldAttribute
from ansible.utils.display import Display
display = Display()
class Conditional: class Conditional:
@ -31,9 +28,6 @@ class Conditional:
when = FieldAttribute(isa='list', default=list, extend=True, prepend=True) when = FieldAttribute(isa='list', default=list, extend=True, prepend=True)
def __init__(self, *args, **kwargs):
super().__init__()
def _validate_when(self, attr, name, value): def _validate_when(self, attr, name, value):
if not isinstance(value, list): if not isinstance(value, list):
setattr(self, name, [value]) setattr(self, name, [value])

@ -38,7 +38,7 @@ class Handler(Task):
return "HANDLER: %s" % self.get_name() return "HANDLER: %s" % self.get_name()
def _validate_listen(self, attr, name, value): def _validate_listen(self, attr, name, value):
new_value = self.get_validated_value(name, attr, value, None) new_value = self.get_validated_value(name, value, None)
if self._role is not None: if self._role is not None:
for listener in new_value.copy(): for listener in new_value.copy():
new_value.extend([ new_value.extend([

@ -17,7 +17,6 @@
from __future__ import annotations from __future__ import annotations
# from ansible.inventory.host import Host
from ansible.playbook.handler import Handler from ansible.playbook.handler import Handler
from ansible.playbook.task_include import TaskInclude from ansible.playbook.task_include import TaskInclude

@ -291,7 +291,7 @@ def load_list_of_tasks(ds, play, block=None, role=None, task_include=None, use_h
def load_list_of_roles(ds, play, current_role_path=None, variable_manager=None, loader=None, collection_search_list=None): def load_list_of_roles(ds, play, current_role_path=None, variable_manager=None, loader=None, collection_search_list=None):
""" """
Loads and returns a list of RoleInclude objects from the ds list of role definitions Loads and returns a list of RoleDefinition objects from the ds list of role definitions
:param ds: list of roles to load :param ds: list of roles to load
:param play: calling Play object :param play: calling Play object
:param current_role_path: path of the owning role, if any :param current_role_path: path of the owning role, if any
@ -301,14 +301,14 @@ def load_list_of_roles(ds, play, current_role_path=None, variable_manager=None,
:return: :return:
""" """
# we import here to prevent a circular dependency with imports # we import here to prevent a circular dependency with imports
from ansible.playbook.role.include import RoleInclude from ansible.playbook.role.definition import RoleDefinition
if not isinstance(ds, list): if not isinstance(ds, list):
raise AnsibleAssertionError('ds (%s) should be a list but was a %s' % (ds, type(ds))) raise AnsibleAssertionError('ds (%s) should be a list but was a %s' % (ds, type(ds)))
roles = [] roles = []
for role_def in ds: for role_def in ds:
i = RoleInclude.load(role_def, play=play, current_role_path=current_role_path, variable_manager=variable_manager, i = RoleDefinition.load(role_def, play=play, current_role_path=current_role_path, variable_manager=variable_manager,
loader=loader, collection_list=collection_search_list) loader=loader, collection_list=collection_search_list)
roles.append(i) roles.append(i)

@ -31,9 +31,6 @@ class LoopControl(FieldAttributeBase):
extended_allitems = NonInheritableFieldAttribute(isa='bool', default=True, always_post_validate=True) extended_allitems = NonInheritableFieldAttribute(isa='bool', default=True, always_post_validate=True)
break_when = NonInheritableFieldAttribute(isa='list', default=list) break_when = NonInheritableFieldAttribute(isa='list', default=list)
def __init__(self):
super(LoopControl, self).__init__()
@staticmethod @staticmethod
def load(data, variable_manager=None, loader=None): def load(data, variable_manager=None, loader=None):
t = LoopControl() t = LoopControl()

@ -35,12 +35,9 @@ from ansible.playbook.role import Role
from ansible.playbook.task import Task from ansible.playbook.task import Task
from ansible.playbook.taggable import Taggable from ansible.playbook.taggable import Taggable
from ansible.parsing.vault import EncryptedString from ansible.parsing.vault import EncryptedString
from ansible.utils.display import Display
from ansible._internal._templating._engine import TemplateEngine as _TE from ansible._internal._templating._engine import TemplateEngine as _TE
display = Display()
__all__ = ['Play'] __all__ = ['Play']
@ -73,7 +70,7 @@ class Play(Base, Taggable, CollectionSearch):
validate_argspec = NonInheritableFieldAttribute(isa='string', always_post_validate=True) validate_argspec = NonInheritableFieldAttribute(isa='string', always_post_validate=True)
# Role Attributes # Role Attributes
roles = NonInheritableFieldAttribute(isa='list', default=list, priority=90) roles = NonInheritableFieldAttribute(isa='list', default=list, priority=-1)
# Block (Task) Lists Attributes # Block (Task) Lists Attributes
handlers = NonInheritableFieldAttribute(isa='list', default=list, priority=-1) handlers = NonInheritableFieldAttribute(isa='list', default=list, priority=-1)
@ -156,27 +153,11 @@ class Play(Base, Taggable, CollectionSearch):
""" """
Adjusts play datastructure to cleanup old/legacy items Adjusts play datastructure to cleanup old/legacy items
""" """
if not isinstance(ds, dict): if not isinstance(ds, dict):
raise AnsibleAssertionError('while preprocessing data (%s), ds should be a dict but was a %s' % (ds, type(ds))) raise AnsibleAssertionError('while preprocessing data (%s), ds should be a dict but was a %s' % (ds, type(ds)))
# The use of 'user' in the Play datastructure was deprecated to
# line up with the same change for Tasks, due to the fact that
# 'user' conflicted with the user module.
if 'user' in ds:
# this should never happen, but error out with a helpful message
# to the user if it does...
if 'remote_user' in ds:
raise AnsibleParserError("both 'user' and 'remote_user' are set for this play. "
"The use of 'user' is deprecated, and should be removed", obj=ds)
ds['remote_user'] = ds['user']
del ds['user']
return super(Play, self).preprocess_data(ds) return super(Play, self).preprocess_data(ds)
# DTFIX-FUTURE: these do nothing but augment the exception message; DRY and nuke
def _load_tasks(self, attr, ds): def _load_tasks(self, attr, ds):
""" """
Loads a list of blocks from a list which may be mixed tasks/blocks. Loads a list of blocks from a list which may be mixed tasks/blocks.
@ -185,27 +166,9 @@ class Play(Base, Taggable, CollectionSearch):
try: try:
return load_list_of_blocks(ds=ds, play=self, variable_manager=self._variable_manager, loader=self._loader) return load_list_of_blocks(ds=ds, play=self, variable_manager=self._variable_manager, loader=self._loader)
except AssertionError as ex: except AssertionError as ex:
raise AnsibleParserError("A malformed block was encountered while loading tasks.", obj=self._ds) from ex raise AnsibleParserError(f"A malformed block was encountered while loading {attr}.", obj=self._ds) from ex
def _load_pre_tasks(self, attr, ds): _load_pre_tasks = _load_post_tasks = _load_tasks
"""
Loads a list of blocks from a list which may be mixed tasks/blocks.
Bare tasks outside of a block are given an implicit block.
"""
try:
return load_list_of_blocks(ds=ds, play=self, variable_manager=self._variable_manager, loader=self._loader)
except AssertionError as ex:
raise AnsibleParserError("A malformed block was encountered while loading pre_tasks.", obj=self._ds) from ex
def _load_post_tasks(self, attr, ds):
"""
Loads a list of blocks from a list which may be mixed tasks/blocks.
Bare tasks outside of a block are given an implicit block.
"""
try:
return load_list_of_blocks(ds=ds, play=self, variable_manager=self._variable_manager, loader=self._loader)
except AssertionError as ex:
raise AnsibleParserError("A malformed block was encountered while loading post_tasks.", obj=self._ds) from ex
def _load_handlers(self, attr, ds): def _load_handlers(self, attr, ds):
""" """
@ -213,17 +176,13 @@ class Play(Base, Taggable, CollectionSearch):
Bare handlers outside of a block are given an implicit block. Bare handlers outside of a block are given an implicit block.
""" """
try: try:
return self._extend_value( return load_list_of_blocks(ds=ds, play=self, use_handlers=True, variable_manager=self._variable_manager, loader=self._loader) + self.handlers
self.handlers,
load_list_of_blocks(ds=ds, play=self, use_handlers=True, variable_manager=self._variable_manager, loader=self._loader),
prepend=True
)
except AssertionError as ex: except AssertionError as ex:
raise AnsibleParserError("A malformed block was encountered while loading handlers.", obj=self._ds) from ex raise AnsibleParserError("A malformed block was encountered while loading handlers.", obj=self._ds) from ex
def _load_roles(self, attr, ds): def _load_roles(self, attr, ds):
""" """
Loads and returns a list of RoleInclude objects from the datastructure Loads and returns a list of RoleDefinition objects from the datastructure
list of role definitions and creates the Role from those objects list of role definitions and creates the Role from those objects
""" """
@ -382,21 +341,6 @@ class Play(Base, Taggable, CollectionSearch):
return [self.vars_files] return [self.vars_files]
return self.vars_files return self.vars_files
def get_handlers(self):
return self.handlers[:]
def get_roles(self):
return self.roles[:]
def get_tasks(self):
tasklist = []
for task in self.pre_tasks + self.tasks + self.post_tasks:
if isinstance(task, Block):
tasklist.append(task.block + task.rescue + task.always)
else:
tasklist.append(task)
return tasklist
def copy(self): def copy(self):
new_me = super(Play, self).copy() new_me = super(Play, self).copy()
new_me.role_cache = self.role_cache.copy() new_me.role_cache = self.role_cache.copy()

@ -41,12 +41,6 @@ from ansible.utils.display import Display
from ansible.utils.path import is_subpath from ansible.utils.path import is_subpath
from ansible.utils.vars import combine_vars from ansible.utils.vars import combine_vars
# NOTE: This import is only needed for the type-checking in __init__. While there's an alternative
# available by using forward references this seems not to work well with commonly used IDEs.
# Therefore the TYPE_CHECKING hack seems to be a more universal approach, even if not being very elegant.
# References:
# * https://stackoverflow.com/q/39740632/199513
# * https://peps.python.org/pep-0484/#forward-references
if _t.TYPE_CHECKING: if _t.TYPE_CHECKING:
from ansible.playbook.block import Block from ansible.playbook.block import Block
from ansible.playbook.play import Play from ansible.playbook.play import Play
@ -155,6 +149,8 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable):
self._completed: dict[str, bool] = dict() self._completed: dict[str, bool] = dict()
self._should_validate: bool = validate self._should_validate: bool = validate
self._dep_chain: list[Role] | None = None
if from_files is None: if from_files is None:
from_files = {} from_files = {}
self._from_files: dict[str, list[str]] = from_files self._from_files: dict[str, list[str]] = from_files
@ -206,7 +202,7 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable):
return self._get_hash_dict() == other._get_hash_dict() return self._get_hash_dict() == other._get_hash_dict()
@staticmethod @staticmethod
def load(role_include, play, parent_role=None, from_files=None, from_include=False, validate=True, public=None, static=True, rescuable=True): def load(role_definition, play, parent_role=None, from_files=None, from_include=False, validate=True, public=None, static=True, rescuable=True):
if from_files is None: if from_files is None:
from_files = {} from_files = {}
try: try:
@ -215,7 +211,7 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable):
# that role?) # that role?)
# see https://github.com/ansible/ansible/issues/61527 # see https://github.com/ansible/ansible/issues/61527
r = Role(play=play, from_files=from_files, from_include=from_include, validate=validate, public=public, static=static, rescuable=rescuable) r = Role(play=play, from_files=from_files, from_include=from_include, validate=validate, public=public, static=static, rescuable=rescuable)
r._load_role_data(role_include, parent_role=parent_role) r._load_role_data(role_definition, parent_role=parent_role)
role_path = r.get_role_path() role_path = r.get_role_path()
if role_path not in play.role_cache: if role_path not in play.role_cache:
@ -230,23 +226,23 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable):
except RecursionError as ex: except RecursionError as ex:
raise AnsibleError("A recursion loop was detected with the roles specified. Make sure child roles do not have dependencies on parent roles", raise AnsibleError("A recursion loop was detected with the roles specified. Make sure child roles do not have dependencies on parent roles",
obj=role_include._ds) from ex obj=role_definition._ds) from ex
def _load_role_data(self, role_include, parent_role=None): def _load_role_data(self, role_definition, parent_role=None):
self._role_name = role_include.role self._role_name = role_definition.role
self._role_path = role_include.get_role_path() self._role_path = role_definition.get_role_path()
self._role_collection = role_include._role_collection self._role_collection = role_definition._role_collection
self._role_params = role_include.get_role_params() self._role_params = role_definition.get_role_params()
self._variable_manager = role_include.get_variable_manager() self._variable_manager = role_definition.get_variable_manager()
self._loader = role_include.get_loader() self._loader = role_definition.get_loader()
if parent_role: if parent_role:
self.add_parent(parent_role) self.add_parent(parent_role)
# copy over all field attributes from the RoleInclude # copy over all field attributes from the RoleDefinition
# update self._attr directly, to avoid squashing # update self._attr directly, to avoid squashing
for attr_name in self.fattributes: for attr_name in self.fattributes:
setattr(self, f'_{attr_name}', getattr(role_include, f'_{attr_name}', Sentinel)) setattr(self, f'_{attr_name}', getattr(role_definition, f'_{attr_name}', Sentinel))
# vars and default vars are regular dictionaries # vars and default vars are regular dictionaries
self._role_vars = self._load_role_yaml('vars', main=self._from_files.get('vars'), allow_dir=True) self._role_vars = self._load_role_yaml('vars', main=self._from_files.get('vars'), allow_dir=True)
@ -466,14 +462,12 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable):
""" """
deps = [] deps = []
for role_include in self._metadata.dependencies: for role_definition in self._metadata.dependencies:
r = Role.load(role_include, play=self._play, parent_role=self, static=self.static) r = Role.load(role_definition, play=self._play, parent_role=self, static=self.static)
deps.append(r) deps.append(r)
return deps return deps
# other functions
def add_parent(self, parent_role): def add_parent(self, parent_role):
""" adds a role to the list of this roles parents """ """ adds a role to the list of this roles parents """
if not isinstance(parent_role, Role): if not isinstance(parent_role, Role):
@ -485,12 +479,15 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable):
def get_parents(self): def get_parents(self):
return self._parents return self._parents
def get_dep_chain(self): def get_dep_chain(self) -> list[Role]:
"""Returns a copy of the parent chain list."""
if self._dep_chain is None:
dep_chain = [] dep_chain = []
for parent in self._parents: for parent in self._parents:
dep_chain.extend(parent.get_dep_chain()) dep_chain.extend(parent.get_dep_chain())
dep_chain.append(parent) dep_chain.append(parent)
return dep_chain self._dep_chain = dep_chain
return self._dep_chain[:]
def get_default_vars(self, dep_chain=None): def get_default_vars(self, dep_chain=None):
dep_chain = [] if dep_chain is None else dep_chain dep_chain = [] if dep_chain is None else dep_chain
@ -529,8 +526,6 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable):
def get_vars(self, dep_chain=None, include_params=True, only_exports=False): def get_vars(self, dep_chain=None, include_params=True, only_exports=False):
dep_chain = [] if dep_chain is None else dep_chain dep_chain = [] if dep_chain is None else dep_chain
all_vars = {}
# get role_vars: from parent objects # get role_vars: from parent objects
# TODO: is this right precedence for inherited role_vars? # TODO: is this right precedence for inherited role_vars?
all_vars = self.get_inherited_vars(dep_chain, only_exports=only_exports) all_vars = self.get_inherited_vars(dep_chain, only_exports=only_exports)
@ -586,23 +581,17 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable):
# #
# ``get_handler_blocks`` may be called when handling ``import_role`` during parsing # ``get_handler_blocks`` may be called when handling ``import_role`` during parsing
# as well as with ``Play.compile_roles_handlers`` from ``TaskExecutor`` # as well as with ``Play.compile_roles_handlers`` from ``TaskExecutor``
# FIXME deprecate unused dep_chain parameter
if self._compiled_handler_blocks: if self._compiled_handler_blocks:
return self._compiled_handler_blocks return self._compiled_handler_blocks
self._compiled_handler_blocks = block_list = [] self._compiled_handler_blocks = block_list = []
# update the dependency chain here
if dep_chain is None:
dep_chain = []
new_dep_chain = dep_chain + [self]
for dep in self.get_direct_dependencies(): for dep in self.get_direct_dependencies():
dep_blocks = dep.get_handler_blocks(play=play, dep_chain=new_dep_chain) block_list.extend(dep.get_handler_blocks(play=play))
block_list.extend(dep_blocks)
for task_block in self._handler_blocks: for task_block in self._handler_blocks:
new_task_block = task_block.copy() new_task_block = task_block.copy()
new_task_block._dep_chain = new_dep_chain
new_task_block._play = play new_task_block._play = play
block_list.append(new_task_block) block_list.append(new_task_block)
@ -626,24 +615,18 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable):
with each task, so tasks know by which route they were found, and with each task, so tasks know by which route they were found, and
can correctly take their parent's tags/conditionals into account. can correctly take their parent's tags/conditionals into account.
""" """
# FIXME deprecate unused dep_chain parameter
from ansible.playbook.block import Block from ansible.playbook.block import Block
from ansible.playbook.task import Task from ansible.playbook.task import Task
block_list = [] block_list = []
# update the dependency chain here for dep in self.get_direct_dependencies():
if dep_chain is None: dep_blocks = dep.compile(play=play)
dep_chain = []
new_dep_chain = dep_chain + [self]
deps = self.get_direct_dependencies()
for dep in deps:
dep_blocks = dep.compile(play=play, dep_chain=new_dep_chain)
block_list.extend(dep_blocks) block_list.extend(dep_blocks)
for task_block in self._task_blocks: for task_block in self._task_blocks:
new_task_block = task_block.copy() new_task_block = task_block.copy()
new_task_block._dep_chain = new_dep_chain
new_task_block._play = play new_task_block._play = play
block_list.append(new_task_block) block_list.append(new_task_block)

@ -20,12 +20,13 @@ from __future__ import annotations
import os import os
from ansible import constants as C from ansible import constants as C
from ansible.errors import AnsibleError, AnsibleAssertionError from ansible.errors import AnsibleError, AnsibleAssertionError, AnsibleParserError
from ansible.module_utils._internal._datatag import AnsibleTagHelper from ansible.module_utils._internal._datatag import AnsibleTagHelper
from ansible.playbook.attribute import NonInheritableFieldAttribute from ansible.playbook.attribute import NonInheritableFieldAttribute
from ansible.playbook.base import Base from ansible.playbook.base import Base
from ansible.playbook.collectionsearch import CollectionSearch from ansible.playbook.collectionsearch import CollectionSearch
from ansible.playbook.conditional import Conditional from ansible.playbook.conditional import Conditional
from ansible.playbook.delegatable import Delegatable
from ansible.playbook.taggable import Taggable from ansible.playbook.taggable import Taggable
from ansible._internal._templating._engine import TemplateEngine from ansible._internal._templating._engine import TemplateEngine
from ansible.utils.collection_loader import AnsibleCollectionRef from ansible.utils.collection_loader import AnsibleCollectionRef
@ -38,13 +39,12 @@ __all__ = ['RoleDefinition']
display = Display() display = Display()
class RoleDefinition(Base, Conditional, Taggable, CollectionSearch): class RoleDefinition(Base, Conditional, Taggable, Delegatable, CollectionSearch):
role = NonInheritableFieldAttribute(isa='string') role = NonInheritableFieldAttribute(isa='string')
def __init__(self, play=None, role_basedir=None, variable_manager=None, loader=None, collection_list=None): def __init__(self, play=None, role_basedir=None, variable_manager=None, loader=None, collection_list=None):
super().__init__()
super(RoleDefinition, self).__init__()
self._play = play self._play = play
self._variable_manager = variable_manager self._variable_manager = variable_manager
@ -56,12 +56,16 @@ class RoleDefinition(Base, Conditional, Taggable, CollectionSearch):
self._role_params = dict() self._role_params = dict()
self._collection_list = collection_list self._collection_list = collection_list
# def __repr__(self):
# return 'ROLEDEF: ' + self._attributes.get('role', '<no name set>')
@staticmethod @staticmethod
def load(data, variable_manager=None, loader=None): def load(data, play, current_role_path=None, parent_role=None, variable_manager=None, loader=None, collection_list=None):
raise AnsibleError("not implemented") if not (isinstance(data, str) or isinstance(data, dict)):
raise AnsibleParserError("Invalid role definition.", obj=data)
if isinstance(data, str) and ',' in data:
raise AnsibleError("Invalid old style role requirement: %s" % data)
rd = RoleDefinition(play=play, role_basedir=current_role_path, variable_manager=variable_manager, loader=loader, collection_list=collection_list)
return rd.load_data(data, variable_manager=variable_manager, loader=loader)
def preprocess_data(self, ds): def preprocess_data(self, ds):
# role names that are simply numbers can be parsed by PyYAML # role names that are simply numbers can be parsed by PyYAML

@ -1,49 +0,0 @@
# (c) 2014 Michael DeHaan, <michael@ansible.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations
from ansible.errors import AnsibleError, AnsibleParserError
from ansible.playbook.delegatable import Delegatable
from ansible.playbook.role.definition import RoleDefinition
__all__ = ['RoleInclude']
class RoleInclude(RoleDefinition, Delegatable):
"""
A derivative of RoleDefinition, used by playbook code when a role
is included for execution in a play.
"""
def __init__(self, play=None, role_basedir=None, variable_manager=None, loader=None, collection_list=None):
super(RoleInclude, self).__init__(play=play, role_basedir=role_basedir, variable_manager=variable_manager,
loader=loader, collection_list=collection_list)
@staticmethod
def load(data, play, current_role_path=None, parent_role=None, variable_manager=None, loader=None, collection_list=None):
if not (isinstance(data, str) or isinstance(data, dict)):
raise AnsibleParserError("Invalid role definition.", obj=data)
if isinstance(data, str) and ',' in data:
raise AnsibleError("Invalid old style role requirement: %s" % data)
ri = RoleInclude(play=play, role_basedir=current_role_path, variable_manager=variable_manager, loader=loader, collection_list=collection_list)
return ri.load_data(data, variable_manager=variable_manager, loader=loader)

@ -21,7 +21,7 @@ import os
from ansible.errors import AnsibleParserError, AnsibleError from ansible.errors import AnsibleParserError, AnsibleError
from ansible.playbook.attribute import NonInheritableFieldAttribute from ansible.playbook.attribute import NonInheritableFieldAttribute
from ansible.playbook.base import Base from ansible.playbook.base import FieldAttributeBase
from ansible.playbook.collectionsearch import CollectionSearch from ansible.playbook.collectionsearch import CollectionSearch
from ansible.playbook.helpers import load_list_of_roles from ansible.playbook.helpers import load_list_of_roles
from ansible.playbook.role.requirement import RoleRequirement from ansible.playbook.role.requirement import RoleRequirement
@ -29,7 +29,7 @@ from ansible.playbook.role.requirement import RoleRequirement
__all__ = ['RoleMetadata'] __all__ = ['RoleMetadata']
class RoleMetadata(Base, CollectionSearch): class RoleMetadata(FieldAttributeBase, CollectionSearch):
""" """
This class wraps the parsing and validation of the optional metadata This class wraps the parsing and validation of the optional metadata
within each Role (meta/main.yml). within each Role (meta/main.yml).
@ -59,7 +59,7 @@ class RoleMetadata(Base, CollectionSearch):
def _load_dependencies(self, attr, ds): def _load_dependencies(self, attr, ds):
""" """
This is a helper loading function for the dependencies list, This is a helper loading function for the dependencies list,
which returns a list of RoleInclude objects which returns a list of RoleDefinition objects
""" """
roles = [] roles = []

@ -18,8 +18,6 @@
from __future__ import annotations from __future__ import annotations
from ansible.errors import AnsibleError from ansible.errors import AnsibleError
from ansible.playbook.role.definition import RoleDefinition
from ansible.utils.display import Display
from ansible.utils.galaxy import scm_archive_resource from ansible.utils.galaxy import scm_archive_resource
__all__ = ['RoleRequirement'] __all__ = ['RoleRequirement']
@ -32,19 +30,13 @@ VALID_SPEC_KEYS = [
'version', 'version',
] ]
display = Display()
class RoleRequirement(RoleDefinition):
class RoleRequirement:
""" """
Helper class for Galaxy, which is used to parse both dependencies Helper class for Galaxy, which is used to parse both dependencies
specified in meta/main.yml and requirements.yml files. specified in meta/main.yml and requirements.yml files.
""" """
def __init__(self):
pass
@staticmethod @staticmethod
def repo_url_to_role_name(repo_url): def repo_url_to_role_name(repo_url):
# gets the role name out of a repo like # gets the role name out of a repo like

@ -21,14 +21,11 @@ from ansible.errors import AnsibleError, AnsibleParserError
from ansible.playbook.attribute import NonInheritableFieldAttribute from ansible.playbook.attribute import NonInheritableFieldAttribute
from ansible.playbook.task_include import TaskInclude from ansible.playbook.task_include import TaskInclude
from ansible.playbook.role import Role from ansible.playbook.role import Role
from ansible.playbook.role.include import RoleInclude from ansible.playbook.role.definition import RoleDefinition
from ansible.utils.display import Display
from ansible._internal._templating._engine import TemplateEngine from ansible._internal._templating._engine import TemplateEngine
__all__ = ['IncludeRole'] __all__ = ['IncludeRole']
display = Display()
class IncludeRole(TaskInclude): class IncludeRole(TaskInclude):
@ -59,6 +56,11 @@ class IncludeRole(TaskInclude):
self._parent_role = role self._parent_role = role
self._role_name = None self._role_name = None
self._role_path = None self._role_path = None
self.statically_loaded = False
@property
def _post_validate_object(self):
return not self.statically_loaded
def get_name(self): def get_name(self):
""" return the name of the task """ """ return the name of the task """
@ -73,13 +75,13 @@ class IncludeRole(TaskInclude):
myplay = play myplay = play
try: try:
ri = RoleInclude.load(self._role_name, play=myplay, variable_manager=variable_manager, loader=loader, collection_list=self.collections) rd = RoleDefinition.load(self._role_name, play=myplay, variable_manager=variable_manager, loader=loader, collection_list=self.collections)
except AnsibleError as e: except AnsibleError as e:
if not self.rescuable: if not self.rescuable:
raise AnsibleParserError("Could not include role.") from e raise AnsibleParserError("Could not include role.") from e
raise raise
ri.vars |= self.vars rd.vars |= self.vars
if variable_manager is not None: if variable_manager is not None:
available_variables = variable_manager.get_vars(play=myplay, task=self) available_variables = variable_manager.get_vars(play=myplay, task=self)
@ -89,7 +91,7 @@ class IncludeRole(TaskInclude):
from_files = templar.template(self._from_files) from_files = templar.template(self._from_files)
# build role # build role
actual_role = Role.load(ri, myplay, parent_role=self._parent_role, from_files=from_files, from_include=True, actual_role = Role.load(rd, myplay, parent_role=self._parent_role, from_files=from_files, from_include=True,
validate=self.rolespec_validate, public=self.public, static=self.statically_loaded, rescuable=self.rescuable) validate=self.rolespec_validate, public=self.public, static=self.statically_loaded, rescuable=self.rescuable)
actual_role._metadata.allow_duplicates = self.allow_duplicates actual_role._metadata.allow_duplicates = self.allow_duplicates
@ -99,23 +101,19 @@ class IncludeRole(TaskInclude):
# save this for later use # save this for later use
self._role_path = actual_role._role_path self._role_path = actual_role._role_path
# compile role with parent roles as dependencies to ensure they inherit
# variables
dep_chain = actual_role.get_dep_chain()
p_block = self.build_parent_block() p_block = self.build_parent_block()
# collections value is not inherited; override with the value we calculated during role setup # collections value is not inherited; override with the value we calculated during role setup
p_block.collections = actual_role.collections p_block.collections = actual_role.collections
blocks = actual_role.compile(play=myplay, dep_chain=dep_chain) blocks = actual_role.compile(play=myplay)
for b in blocks: for b in blocks:
b._parent = p_block b._parent = p_block
# HACK: parent inheritance doesn't seem to have a way to handle this intermediate override until squashed/finalized # HACK: parent inheritance doesn't seem to have a way to handle this intermediate override until squashed/finalized
b.collections = actual_role.collections b.collections = actual_role.collections
# updated available handlers in play # updated available handlers in play
handlers = actual_role.get_handler_blocks(play=myplay, dep_chain=dep_chain) handlers = actual_role.get_handler_blocks(play=myplay)
for h in handlers: for h in handlers:
h._parent = p_block h._parent = p_block
myplay.handlers = myplay.handlers + handlers myplay.handlers = myplay.handlers + handlers

@ -71,7 +71,7 @@ class Taggable:
obj = obj._parent obj = obj._parent
yield self.get_play() yield self.play
def evaluate_tags(self, only_tags, skip_tags, all_vars): def evaluate_tags(self, only_tags, skip_tags, all_vars):
"""Check if the current item should be executed depending on the specified tags. """Check if the current item should be executed depending on the specified tags.

@ -20,7 +20,6 @@ from __future__ import annotations
import typing as t import typing as t
from ansible import constants as C from ansible import constants as C
from ansible.module_utils.common.sentinel import Sentinel
from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleAssertionError, AnsibleValueOmittedError from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleAssertionError, AnsibleValueOmittedError
from ansible.executor.module_common import _get_action_arg_defaults from ansible.executor.module_common import _get_action_arg_defaults
from ansible.module_utils.common.text.converters import to_native from ansible.module_utils.common.text.converters import to_native
@ -30,7 +29,6 @@ from ansible.plugins.action import ActionBase
from ansible.plugins.loader import action_loader, module_loader, lookup_loader from ansible.plugins.loader import action_loader, module_loader, lookup_loader
from ansible.playbook.attribute import NonInheritableFieldAttribute from ansible.playbook.attribute import NonInheritableFieldAttribute
from ansible.playbook.base import Base from ansible.playbook.base import Base
from ansible.playbook.block import Block
from ansible.playbook.collectionsearch import CollectionSearch from ansible.playbook.collectionsearch import CollectionSearch
from ansible.playbook.conditional import Conditional from ansible.playbook.conditional import Conditional
from ansible.playbook.delegatable import Delegatable from ansible.playbook.delegatable import Delegatable
@ -97,7 +95,6 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl
""" constructors a task, without the Task.load classmethod, it will be pretty blank """ """ constructors a task, without the Task.load classmethod, it will be pretty blank """
self._role = role self._role = role
self._parent = None
self.implicit = False self.implicit = False
self._resolved_action: str | None = None self._resolved_action: str | None = None
@ -156,20 +153,6 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl
else: else:
return "%s" % (self.action,) return "%s" % (self.action,)
def _merge_kv(self, ds):
if ds is None:
return ""
elif isinstance(ds, str):
return ds
elif isinstance(ds, dict):
buf = ""
for (k, v) in ds.items():
if k.startswith('_'):
continue
buf = buf + "%s=%s " % (k, v)
buf = buf.strip()
return buf
@staticmethod @staticmethod
def load(data, block=None, role=None, task_include=None, variable_manager=None, loader=None): def load(data, block=None, role=None, task_include=None, variable_manager=None, loader=None):
task = Task(block=block, role=role, task_include=task_include) task = Task(block=block, role=role, task_include=task_include)
@ -283,7 +266,7 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl
else: else:
# Validate this untemplated field early on to guarantee we are dealing with a list. # Validate this untemplated field early on to guarantee we are dealing with a list.
# This is also done in CollectionSearch._load_collections() but this runs before that call. # This is also done in CollectionSearch._load_collections() but this runs before that call.
collections_list = self.get_validated_value('collections', self.fattributes.get('collections'), collections_list, None) collections_list = self.get_validated_value('collections', collections_list, None)
if default_collection and not self._role: # FIXME: and not a collections role if default_collection and not self._role: # FIXME: and not a collections role
if collections_list: if collections_list:
@ -375,17 +358,6 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl
except Exception as ex: except Exception as ex:
raise AnsibleParserError("Invalid 'register' specified.", obj=value) from ex raise AnsibleParserError("Invalid 'register' specified.", obj=value) from ex
def post_validate(self, templar):
"""
Override of base class post_validate, to also do final validation on
the block and task include (if any) to which this task belongs.
"""
if self._parent:
self._parent.post_validate(templar)
super(Task, self).post_validate(templar)
def _post_validate_loop(self, attr, value, templar): def _post_validate_loop(self, attr, value, templar):
""" """
Override post validation for the loop field, which is templated Override post validation for the loop field, which is templated
@ -425,7 +397,6 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl
return return
raise raise
# NB: the environment FieldAttribute definition ensures that value is always a list
for env_item in value: for env_item in value:
if isinstance(env_item, dict): if isinstance(env_item, dict):
for k in env_item: for k in env_item:
@ -471,10 +442,8 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl
all_vars |= self.vars all_vars |= self.vars
if 'tags' in all_vars: all_vars.pop('tags', None)
del all_vars['tags'] all_vars.pop('when', None)
if 'when' in all_vars:
del all_vars['when']
return all_vars return all_vars
@ -482,8 +451,6 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl
all_vars = dict() all_vars = dict()
if self._parent: if self._parent:
all_vars |= self._parent.get_include_params() all_vars |= self._parent.get_include_params()
if self.action in C._ACTION_ALL_INCLUDES:
all_vars |= self.vars
return all_vars return all_vars
def copy(self, exclude_parent: bool = False, exclude_tasks: bool = False) -> Task: def copy(self, exclude_parent: bool = False, exclude_tasks: bool = False) -> Task:
@ -493,13 +460,10 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl
if self._parent and not exclude_parent: if self._parent and not exclude_parent:
new_me._parent = self._parent.copy(exclude_tasks=exclude_tasks) new_me._parent = self._parent.copy(exclude_tasks=exclude_tasks)
new_me._role = None
if self._role:
new_me._role = self._role new_me._role = self._role
new_me.implicit = self.implicit new_me.implicit = self.implicit
new_me._resolved_action = self._resolved_action new_me._resolved_action = self._resolved_action
new_me._uuid = self._uuid
return new_me return new_me
@ -515,51 +479,6 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl
if self._parent: if self._parent:
self._parent.set_loader(loader) self._parent.set_loader(loader)
def _get_parent_attribute(self, attr, omit=False):
"""
Generic logic to get the attribute or parent attribute for a task value.
"""
fattr = self.fattributes[attr]
extend = fattr.extend
prepend = fattr.prepend
try:
# omit self, and only get parent values
if omit:
value = Sentinel
else:
value = getattr(self, f'_{attr}', Sentinel)
# If parent is static, we can grab attrs from the parent
# otherwise, defer to the grandparent
if getattr(self._parent, 'statically_loaded', True):
_parent = self._parent
else:
_parent = self._parent._parent
if _parent and (value is Sentinel or extend):
if getattr(_parent, 'statically_loaded', True):
# vars are always inheritable, other attributes might not be for the parent but still should be for other ancestors
if attr != 'vars' and hasattr(_parent, '_get_parent_attribute'):
parent_value = _parent._get_parent_attribute(attr)
else:
parent_value = getattr(_parent, f'_{attr}', Sentinel)
if extend:
value = self._extend_value(value, parent_value, prepend)
else:
value = parent_value
except KeyError:
pass
return value
def all_parents_static(self):
if self._parent:
return self._parent.all_parents_static()
return True
def get_first_parent_include(self): def get_first_parent_include(self):
from ansible.playbook.task_include import TaskInclude from ansible.playbook.task_include import TaskInclude
if self._parent: if self._parent:
@ -568,12 +487,6 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl
return self._parent.get_first_parent_include() return self._parent.get_first_parent_include()
return None return None
def get_play(self):
parent = self._parent
while not isinstance(parent, Block):
parent = parent._parent
return parent._play
def dump_attrs(self): def dump_attrs(self):
"""Override to smuggle important non-FieldAttribute values back to the controller.""" """Override to smuggle important non-FieldAttribute values back to the controller."""
attrs = super().dump_attrs() attrs = super().dump_attrs()
@ -585,10 +498,9 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl
# from_attrs is only used to create a finalized task # from_attrs is only used to create a finalized task
# from attrs from the Worker/TaskExecutor # from attrs from the Worker/TaskExecutor
# Those attrs are finalized and squashed in the TE # Those attrs are finalized in the TE
# and controller side use needs to reflect that # and controller side use needs to reflect that
self._finalized = True self._finalized = True
self._squashed = True
def _resolve_conditional( def _resolve_conditional(
self, self,

@ -123,3 +123,8 @@ class TaskInclude(Task):
p_block = self p_block = self
return p_block return p_block
def get_include_params(self):
v = super().get_include_params()
v |= self.vars
return v

@ -542,7 +542,7 @@ class VariableManager:
delegated_vars['ansible_delegated_vars'] = { delegated_vars['ansible_delegated_vars'] = {
delegated_host_name: self.get_vars( delegated_host_name: self.get_vars(
play=task.get_play(), play=task.play,
host=delegated_host, host=delegated_host,
task=task, task=task,
include_hostvars=True, include_hostvars=True,

@ -1,6 +0,0 @@
- hosts: localhost
remote_user: a
user: b
tasks:
- debug:
msg: did not run

@ -8,20 +8,10 @@ ansible-playbook -i ../../inventory types.yml -v "$@"
# test timeout # test timeout
ansible-playbook -i ../../inventory timeout.yml -v "$@" ansible-playbook -i ../../inventory timeout.yml -v "$@"
# our Play class allows for 'user' or 'remote_user', but not both.
# first test that both user and remote_user work individually
set +e set +e
result="$(ansible-playbook -i ../../inventory user.yml -v "$@" 2>&1)" result="$(ansible-playbook -i ../../inventory user.yml -v "$@" 2>&1)"
set -e set -e
grep -q "worked with user" <<< "$result" grep -q "is not a valid attribute for a Play" <<< "$result"
grep -q "worked with remote_user" <<< "$result"
# then test that the play errors if user and remote_user both exist
echo "EXPECTED ERROR: Ensure we fail properly if a play has both user and remote_user."
set +e
result="$(ansible-playbook -i ../../inventory remote_user_and_user.yml -v "$@" 2>&1)"
set -e
grep -q "both 'user' and 'remote_user' are set for this play." <<< "$result"
# test that playbook errors if len(plays) == 0 # test that playbook errors if len(plays) == 0
echo "EXPECTED ERROR: Ensure we fail properly if a playbook is an empty list." echo "EXPECTED ERROR: Ensure we fail properly if a playbook is an empty list."

@ -13,11 +13,6 @@
- hosts: localhost - hosts: localhost
user: "{{ me }}" user: "{{ me }}"
tasks: tasks:
- debug: - name: should not happen
debug:
msg: worked with user ({{ me }}) msg: worked with user ({{ me }})
- hosts: localhost
remote_user: "{{ me }}"
tasks:
- debug:
msg: worked with remote_user ({{ me }})

@ -50,7 +50,6 @@ lib/ansible/module_utils/six/__init__.py pylint:trailing-comma-tuple
lib/ansible/module_utils/six/__init__.py pylint:unidiomatic-typecheck lib/ansible/module_utils/six/__init__.py pylint:unidiomatic-typecheck
lib/ansible/module_utils/six/__init__.py replace-urlopen lib/ansible/module_utils/six/__init__.py replace-urlopen
lib/ansible/module_utils/urls.py replace-urlopen lib/ansible/module_utils/urls.py replace-urlopen
lib/ansible/playbook/role/include.py pylint:arguments-renamed
lib/ansible/plugins/action/normal.py action-plugin-docs # default action plugin for modules without a dedicated action plugin lib/ansible/plugins/action/normal.py action-plugin-docs # default action plugin for modules without a dedicated action plugin
lib/ansible/plugins/cache/base.py ansible-doc!skip # not a plugin, but a stub for backwards compatibility lib/ansible/plugins/cache/base.py ansible-doc!skip # not a plugin, but a stub for backwards compatibility
lib/ansible/plugins/callback/__init__.py pylint:arguments-renamed lib/ansible/plugins/callback/__init__.py pylint:arguments-renamed

@ -31,7 +31,7 @@ from units.mock.loader import DictDataLoader
from units.mock.path import mock_unfrackpath_noop from units.mock.path import mock_unfrackpath_noop
from ansible.playbook.role import Role from ansible.playbook.role import Role
from ansible.playbook.role.include import RoleInclude from ansible.playbook.role.definition import RoleDefinition
from ansible.playbook.role import hash_params from ansible.playbook.role import hash_params
@ -168,7 +168,7 @@ class TestRole(unittest.TestCase):
mock_play = MagicMock() mock_play = MagicMock()
mock_play.role_cache = {} mock_play.role_cache = {}
i = RoleInclude.load('foo_tasks', play=mock_play, loader=fake_loader) i = RoleDefinition.load('foo_tasks', play=mock_play, loader=fake_loader)
r = Role.load(i, play=mock_play) r = Role.load(i, play=mock_play)
self.assertEqual(str(r), 'foo_tasks') self.assertEqual(str(r), 'foo_tasks')
@ -190,7 +190,7 @@ class TestRole(unittest.TestCase):
mock_play = MagicMock() mock_play = MagicMock()
mock_play.role_cache = {} mock_play.role_cache = {}
i = RoleInclude.load('foo_tasks', play=mock_play, loader=fake_loader) i = RoleDefinition.load('foo_tasks', play=mock_play, loader=fake_loader)
r = Role.load(i, play=mock_play, from_files=dict(tasks='custom_main')) r = Role.load(i, play=mock_play, from_files=dict(tasks='custom_main'))
self.assertEqual(r._task_blocks[0]._ds[0]['command'], 'baz') self.assertEqual(r._task_blocks[0]._ds[0]['command'], 'baz')
@ -208,7 +208,7 @@ class TestRole(unittest.TestCase):
mock_play = MagicMock() mock_play = MagicMock()
mock_play.role_cache = {} mock_play.role_cache = {}
i = RoleInclude.load('foo_handlers', play=mock_play, loader=fake_loader) i = RoleDefinition.load('foo_handlers', play=mock_play, loader=fake_loader)
r = Role.load(i, play=mock_play) r = Role.load(i, play=mock_play)
self.assertEqual(len(r._handler_blocks), 1) self.assertEqual(len(r._handler_blocks), 1)
@ -229,7 +229,7 @@ class TestRole(unittest.TestCase):
mock_play = MagicMock() mock_play = MagicMock()
mock_play.role_cache = {} mock_play.role_cache = {}
i = RoleInclude.load('foo_vars', play=mock_play, loader=fake_loader) i = RoleDefinition.load('foo_vars', play=mock_play, loader=fake_loader)
r = Role.load(i, play=mock_play) r = Role.load(i, play=mock_play)
self.assertEqual(r._default_vars, dict(foo='bar')) self.assertEqual(r._default_vars, dict(foo='bar'))
@ -250,7 +250,7 @@ class TestRole(unittest.TestCase):
mock_play = MagicMock() mock_play = MagicMock()
mock_play.role_cache = {} mock_play.role_cache = {}
i = RoleInclude.load('foo_vars', play=mock_play, loader=fake_loader) i = RoleDefinition.load('foo_vars', play=mock_play, loader=fake_loader)
r = Role.load(i, play=mock_play) r = Role.load(i, play=mock_play)
self.assertEqual(r._default_vars, dict(foo='bar')) self.assertEqual(r._default_vars, dict(foo='bar'))
@ -271,7 +271,7 @@ class TestRole(unittest.TestCase):
mock_play = MagicMock() mock_play = MagicMock()
mock_play.role_cache = {} mock_play.role_cache = {}
i = RoleInclude.load('foo_vars', play=mock_play, loader=fake_loader) i = RoleDefinition.load('foo_vars', play=mock_play, loader=fake_loader)
r = Role.load(i, play=mock_play) r = Role.load(i, play=mock_play)
self.assertEqual(r._default_vars, dict(foo='bar')) self.assertEqual(r._default_vars, dict(foo='bar'))
@ -294,7 +294,7 @@ class TestRole(unittest.TestCase):
mock_play = MagicMock() mock_play = MagicMock()
mock_play.role_cache = {} mock_play.role_cache = {}
i = RoleInclude.load('foo_vars', play=mock_play, loader=fake_loader) i = RoleDefinition.load('foo_vars', play=mock_play, loader=fake_loader)
r = Role.load(i, play=mock_play) r = Role.load(i, play=mock_play)
self.assertEqual(r._default_vars, dict(foo='bar', a=1, b=2)) self.assertEqual(r._default_vars, dict(foo='bar', a=1, b=2))
@ -314,7 +314,7 @@ class TestRole(unittest.TestCase):
mock_play = MagicMock() mock_play = MagicMock()
mock_play.role_cache = {} mock_play.role_cache = {}
i = RoleInclude.load('foo_vars', play=mock_play, loader=fake_loader) i = RoleDefinition.load('foo_vars', play=mock_play, loader=fake_loader)
r = Role.load(i, play=mock_play) r = Role.load(i, play=mock_play)
self.assertEqual(r._role_vars, dict(foo='bam')) self.assertEqual(r._role_vars, dict(foo='bam'))
@ -361,7 +361,7 @@ class TestRole(unittest.TestCase):
mock_play.collections = None mock_play.collections = None
mock_play.role_cache = {} mock_play.role_cache = {}
i = RoleInclude.load('foo_metadata', play=mock_play, loader=fake_loader) i = RoleDefinition.load('foo_metadata', play=mock_play, loader=fake_loader)
r = Role.load(i, play=mock_play) r = Role.load(i, play=mock_play)
role_deps = r.get_direct_dependencies() role_deps = r.get_direct_dependencies()
@ -379,16 +379,16 @@ class TestRole(unittest.TestCase):
self.assertEqual(all_deps[1].get_name(), 'baz_metadata') self.assertEqual(all_deps[1].get_name(), 'baz_metadata')
self.assertEqual(all_deps[2].get_name(), 'bar_metadata') self.assertEqual(all_deps[2].get_name(), 'bar_metadata')
i = RoleInclude.load('bad1_metadata', play=mock_play, loader=fake_loader) i = RoleDefinition.load('bad1_metadata', play=mock_play, loader=fake_loader)
self.assertRaises(AnsibleParserError, Role.load, i, play=mock_play) self.assertRaises(AnsibleParserError, Role.load, i, play=mock_play)
i = RoleInclude.load('bad2_metadata', play=mock_play, loader=fake_loader) i = RoleDefinition.load('bad2_metadata', play=mock_play, loader=fake_loader)
self.assertRaises(AnsibleParserError, Role.load, i, play=mock_play) self.assertRaises(AnsibleParserError, Role.load, i, play=mock_play)
# TODO: re-enable this test once Ansible has proper role dep cycle detection # TODO: re-enable this test once Ansible has proper role dep cycle detection
# that doesn't rely on stack overflows being recoverable (as they aren't in Py3.7+) # that doesn't rely on stack overflows being recoverable (as they aren't in Py3.7+)
# see https://github.com/ansible/ansible/issues/61527 # see https://github.com/ansible/ansible/issues/61527
# i = RoleInclude.load('recursive1_metadata', play=mock_play, loader=fake_loader) # i = RoleDefinition.load('recursive1_metadata', play=mock_play, loader=fake_loader)
# self.assertRaises(AnsibleError, Role.load, i, play=mock_play) # self.assertRaises(AnsibleError, Role.load, i, play=mock_play)
@patch('ansible.playbook.role.definition.unfrackpath', mock_unfrackpath_noop) @patch('ansible.playbook.role.definition.unfrackpath', mock_unfrackpath_noop)
@ -406,7 +406,7 @@ class TestRole(unittest.TestCase):
mock_play = MagicMock() mock_play = MagicMock()
mock_play.role_cache = {} mock_play.role_cache = {}
i = RoleInclude.load(dict(role='foo_complex'), play=mock_play, loader=fake_loader) i = RoleDefinition.load(dict(role='foo_complex'), play=mock_play, loader=fake_loader)
r = Role.load(i, play=mock_play) r = Role.load(i, play=mock_play)
self.assertEqual(r.get_name(), "foo_complex") self.assertEqual(r.get_name(), "foo_complex")

@ -27,30 +27,7 @@ class TestAttribute(unittest.TestCase):
self.one = Attribute(priority=100) self.one = Attribute(priority=100)
self.two = Attribute(priority=0) self.two = Attribute(priority=0)
def test_eq(self):
self.assertTrue(self.one == self.one)
self.assertFalse(self.one == self.two)
def test_ne(self):
self.assertFalse(self.one != self.one)
self.assertTrue(self.one != self.two)
def test_lt(self): def test_lt(self):
self.assertFalse(self.one < self.one) self.assertFalse(self.one < self.one)
self.assertTrue(self.one < self.two) self.assertTrue(self.one < self.two)
self.assertFalse(self.two < self.one) self.assertFalse(self.two < self.one)
def test_gt(self):
self.assertFalse(self.one > self.one)
self.assertFalse(self.one > self.two)
self.assertTrue(self.two > self.one)
def test_le(self):
self.assertTrue(self.one <= self.one)
self.assertTrue(self.one <= self.two)
self.assertFalse(self.two <= self.one)
def test_ge(self):
self.assertTrue(self.one >= self.one)
self.assertFalse(self.one >= self.two)
self.assertTrue(self.two >= self.one)

@ -47,8 +47,6 @@ class TestBase(unittest.TestCase):
bsc = self.ClassUnderTest() bsc = self.ClassUnderTest()
parent = ExampleParentBaseSubClass() parent = ExampleParentBaseSubClass()
bsc._parent = parent bsc._parent = parent
bsc._dep_chain = [parent]
parent._dep_chain = None
bsc.load_data(ds) bsc.load_data(ds)
fake_loader = DictDataLoader({}) fake_loader = DictDataLoader({})
templar = TemplateEngine(loader=fake_loader) templar = TemplateEngine(loader=fake_loader)
@ -110,11 +108,8 @@ class TestBase(unittest.TestCase):
def test_load_data_invalid_attr_type(self): def test_load_data_invalid_attr_type(self):
ds = {'environment': True} ds = {'environment': True}
# environment is supposed to be a list. This
# seems like it shouldn't work?
ret = self.b.load_data(ds) ret = self.b.load_data(ds)
self.assertEqual(True, ret._environment) self.assertEqual([True], ret._environment)
def test_post_validate(self): def test_post_validate(self):
ds = {'environment': [], ds = {'environment': [],
@ -170,10 +165,6 @@ class TestBase(unittest.TestCase):
b = self._base_validate(ds) b = self._base_validate(ds)
self.assertEqual(b.vars, {}) self.assertEqual(b.vars, {})
def test_validate_empty(self):
self.b.validate()
self.assertTrue(self.b._validated)
def test_getters(self): def test_getters(self):
# not sure why these exist, but here are tests anyway # not sure why these exist, but here are tests anyway
loader = self.b.get_loader() loader = self.b.get_loader()
@ -182,70 +173,6 @@ class TestBase(unittest.TestCase):
self.assertEqual(variable_manager, self.b._variable_manager) self.assertEqual(variable_manager, self.b._variable_manager)
class TestExtendValue(unittest.TestCase):
# _extend_value could be a module or staticmethod but since its
# not, the test is here.
def test_extend_value_list_newlist(self):
b = base.Base()
value_list = ['first', 'second']
new_value_list = ['new_first', 'new_second']
ret = b._extend_value(value_list, new_value_list)
self.assertEqual(value_list + new_value_list, ret)
def test_extend_value_list_newlist_prepend(self):
b = base.Base()
value_list = ['first', 'second']
new_value_list = ['new_first', 'new_second']
ret_prepend = b._extend_value(value_list, new_value_list, prepend=True)
self.assertEqual(new_value_list + value_list, ret_prepend)
def test_extend_value_newlist_list(self):
b = base.Base()
value_list = ['first', 'second']
new_value_list = ['new_first', 'new_second']
ret = b._extend_value(new_value_list, value_list)
self.assertEqual(new_value_list + value_list, ret)
def test_extend_value_newlist_list_prepend(self):
b = base.Base()
value_list = ['first', 'second']
new_value_list = ['new_first', 'new_second']
ret = b._extend_value(new_value_list, value_list, prepend=True)
self.assertEqual(value_list + new_value_list, ret)
def test_extend_value_string_newlist(self):
b = base.Base()
some_string = 'some string'
new_value_list = ['new_first', 'new_second']
ret = b._extend_value(some_string, new_value_list)
self.assertEqual([some_string] + new_value_list, ret)
def test_extend_value_string_newstring(self):
b = base.Base()
some_string = 'some string'
new_value_string = 'this is the new values'
ret = b._extend_value(some_string, new_value_string)
self.assertEqual([some_string, new_value_string], ret)
def test_extend_value_list_newstring(self):
b = base.Base()
value_list = ['first', 'second']
new_value_string = 'this is the new values'
ret = b._extend_value(value_list, new_value_string)
self.assertEqual(value_list + [new_value_string], ret)
def test_extend_value_none_none(self):
b = base.Base()
ret = b._extend_value(None, None)
self.assertEqual(len(ret), 0)
self.assertFalse(ret)
def test_extend_value_none_list(self):
b = base.Base()
ret = b._extend_value(None, ['foo'])
self.assertEqual(ret, ['foo'])
class ExampleException(Exception): class ExampleException(Exception):
pass pass
@ -255,12 +182,7 @@ class ExampleParentBaseSubClass(base.Base):
test_attr_parent_string = FieldAttribute(isa='string', default='A string attr for a class that may be a parent for testing') test_attr_parent_string = FieldAttribute(isa='string', default='A string attr for a class that may be a parent for testing')
def __init__(self): def __init__(self):
super(ExampleParentBaseSubClass, self).__init__() super(ExampleParentBaseSubClass, self).__init__()
self._dep_chain = None
def get_dep_chain(self):
return self._dep_chain
class ExampleSubClass(base.Base): class ExampleSubClass(base.Base):
@ -281,7 +203,7 @@ class BaseSubClass(base.Base):
test_attr_list_no_listof = FieldAttribute(isa='list', always_post_validate=True) test_attr_list_no_listof = FieldAttribute(isa='list', always_post_validate=True)
test_attr_list_required = FieldAttribute(isa='list', listof=(str,), required=True, test_attr_list_required = FieldAttribute(isa='list', listof=(str,), required=True,
default=list, always_post_validate=True) default=list, always_post_validate=True)
test_attr_string = FieldAttribute(isa='string', default='the_test_attr_string_default_value') test_attr_string = FieldAttribute(isa='string', default='the_test_attr_string_default_value', always_post_validate=True)
test_attr_string_required = FieldAttribute(isa='string', required=True, test_attr_string_required = FieldAttribute(isa='string', required=True,
default='the_test_attr_string_default_value') default='the_test_attr_string_default_value')
test_attr_percent = FieldAttribute(isa='percent', always_post_validate=True) test_attr_percent = FieldAttribute(isa='percent', always_post_validate=True)
@ -299,9 +221,6 @@ class BaseSubClass(base.Base):
test_attr_method_missing = FieldAttribute(isa='string', default='some attr with a missing getter', test_attr_method_missing = FieldAttribute(isa='string', default='some attr with a missing getter',
always_post_validate=True) always_post_validate=True)
def _get_attr_test_attr_method(self):
return 'foo bar'
def _validate_test_attr_example(self, attr, name, value): def _validate_test_attr_example(self, attr, name, value):
if not isinstance(value, str): if not isinstance(value, str):
raise ExampleException('test_attr_example is not a string: %s type=%s' % (value, type(value))) raise ExampleException('test_attr_example is not a string: %s type=%s' % (value, type(value)))
@ -333,13 +252,6 @@ class TestBaseSubClass(TestBase):
bsc = self._base_validate(ds) bsc = self._base_validate(ds)
self.assertEqual(bsc.test_attr_int, MOST_RANDOM_NUMBER) self.assertEqual(bsc.test_attr_int, MOST_RANDOM_NUMBER)
def test_attr_int_del(self):
MOST_RANDOM_NUMBER = 37
ds = {'test_attr_int': MOST_RANDOM_NUMBER}
bsc = self._base_validate(ds)
del bsc.test_attr_int
self.assertNotIn('_test_attr_int', bsc.__dict__)
def test_attr_float(self): def test_attr_float(self):
roughly_pi = 4.0 roughly_pi = 4.0
ds = {'test_attr_float': roughly_pi} ds = {'test_attr_float': roughly_pi}
@ -446,7 +358,7 @@ class TestBaseSubClass(TestBase):
def test_attr_string_invalid_list(self): def test_attr_string_invalid_list(self):
ds = {'test_attr_string': ['The new test_attr_string', 'value, however in a list']} ds = {'test_attr_string': ['The new test_attr_string', 'value, however in a list']}
self.assertRaises(AnsibleParserError, self._base_validate, ds) self.assertRaises(AnsibleFieldAttributeError, self._base_validate, ds)
def test_attr_string_required(self): def test_attr_string_required(self):
the_string_value = "the new test_attr_string_required_value" the_string_value = "the new test_attr_string_required_value"
@ -512,12 +424,6 @@ class TestBaseSubClass(TestBase):
{'test_attr_unknown_isa': True} {'test_attr_unknown_isa': True}
) )
def test_attr_method(self):
ds = {'test_attr_method': 'value from the ds'}
bsc = self._base_validate(ds)
# The value returned by the subclasses _get_attr_test_attr_method
self.assertEqual(bsc.test_attr_method, 'foo bar')
def test_attr_method_missing(self): def test_attr_method_missing(self):
a_string = 'The value set from the ds' a_string = 'The value set from the ds'
ds = {'test_attr_method_missing': a_string} ds = {'test_attr_method_missing': a_string}
@ -525,10 +431,9 @@ class TestBaseSubClass(TestBase):
self.assertEqual(bsc.test_attr_method_missing, a_string) self.assertEqual(bsc.test_attr_method_missing, a_string)
def test_get_validated_value_string_preserve_tags(self): def test_get_validated_value_string_preserve_tags(self):
attribute = FieldAttribute(isa='string')
value = TrustedAsTemplate().tag('bar') value = TrustedAsTemplate().tag('bar')
templar = TemplateEngine(None) templar = TemplateEngine(None)
bsc = self.ClassUnderTest() bsc = self.ClassUnderTest()
result = bsc.get_validated_value('foo', attribute, value, templar) result = bsc.get_validated_value('test_attr_string', value, templar)
assert TrustedAsTemplate.is_tagged_on(result) assert TrustedAsTemplate.is_tagged_on(result)
assert result == 'bar' assert result == 'bar'

@ -32,7 +32,7 @@ from ansible.playbook.block import Block
from ansible.playbook.handler import Handler from ansible.playbook.handler import Handler
from ansible.playbook.task import Task from ansible.playbook.task import Task
from ansible.playbook.task_include import TaskInclude from ansible.playbook.task_include import TaskInclude
from ansible.playbook.role.include import RoleInclude from ansible.playbook.role.definition import RoleDefinition
class MixinForMocks(object): class MixinForMocks(object):
@ -317,7 +317,7 @@ class TestLoadListOfRoles(unittest.TestCase, MixinForMocks):
variable_manager=self.mock_variable_manager, loader=self.fake_role_loader) variable_manager=self.mock_variable_manager, loader=self.fake_role_loader)
self.assertIsInstance(res, list) self.assertIsInstance(res, list)
for r in res: for r in res:
self.assertIsInstance(r, RoleInclude) self.assertIsInstance(r, RoleDefinition)
def test_block_unknown_action(self): def test_block_unknown_action(self):
ds = [{ ds = [{
@ -328,7 +328,7 @@ class TestLoadListOfRoles(unittest.TestCase, MixinForMocks):
variable_manager=self.mock_variable_manager, loader=self.fake_role_loader) variable_manager=self.mock_variable_manager, loader=self.fake_role_loader)
self.assertIsInstance(res, list) self.assertIsInstance(res, list)
for r in res: for r in res:
self.assertIsInstance(r, RoleInclude) self.assertIsInstance(r, RoleDefinition)
@pytest.mark.usefixtures('collection_loader') @pytest.mark.usefixtures('collection_loader')

@ -50,17 +50,6 @@ def test_basic_play():
assert p.connection == 'local' assert p.connection == 'local'
def test_play_with_remote_user():
p = Play.load(dict(
name="test play",
hosts=['foo'],
user="testing",
gather_facts=False,
))
assert p.remote_user == "testing"
def test_play_with_user_conflict(): def test_play_with_user_conflict():
play_data = dict( play_data = dict(
name="test play", name="test play",
@ -101,7 +90,6 @@ def test_play_with_handlers():
)) ))
assert len(p.handlers) >= 1 assert len(p.handlers) >= 1
assert len(p.get_handlers()) >= 1
assert isinstance(p.handlers[0], Block) assert isinstance(p.handlers[0], Block)
assert p.handlers[0].has_tasks() is True assert p.handlers[0].has_tasks() is True
@ -118,9 +106,8 @@ def test_play_with_pre_tasks():
assert isinstance(p.pre_tasks[0], Block) assert isinstance(p.pre_tasks[0], Block)
assert p.pre_tasks[0].has_tasks() is True assert p.pre_tasks[0].has_tasks() is True
assert len(p.get_tasks()) >= 1 assert isinstance(p.pre_tasks[0].block[0], Task)
assert isinstance(p.get_tasks()[0][0], Task) assert p.pre_tasks[0].block[0].action == 'shell'
assert p.get_tasks()[0][0].action == 'shell'
def test_play_with_post_tasks(): def test_play_with_post_tasks():
@ -158,7 +145,7 @@ def test_play_with_roles(mocker):
blocks = p.compile() blocks = p.compile()
assert len(blocks) > 1 assert len(blocks) > 1
assert all(isinstance(block, Block) for block in blocks) assert all(isinstance(block, Block) for block in blocks)
assert isinstance(p.get_roles()[0], Role) assert isinstance(p.roles[0], Role)
def test_play_compile(): def test_play_compile():

@ -28,9 +28,10 @@ class TaggableTestObj(Taggable):
self._loader = DictDataLoader({}) self._loader = DictDataLoader({})
self.tags = [] self.tags = []
self._parent = None self._parent = None
self.play = None
def get_play(self): def finalized(self):
return None return False
class TestTaggable(unittest.TestCase): class TestTaggable(unittest.TestCase):

@ -56,7 +56,7 @@ class TestTask(unittest.TestCase):
p = dict(delay=delay) p = dict(delay=delay)
p.update(task_base) p.update(task_base)
t = Task().load_data(p) t = Task().load_data(p)
self.assertEqual(t.get_validated_value('delay', t.fattributes.get('delay'), delay, None), expected) self.assertEqual(t.get_validated_value('delay', delay, None), expected)
bad_params = [ bad_params = [
'E', 'E',
@ -69,7 +69,7 @@ class TestTask(unittest.TestCase):
p.update(task_base) p.update(task_base)
t = Task().load_data(p) t = Task().load_data(p)
with self.assertRaises(AnsibleError): with self.assertRaises(AnsibleError):
dummy = t.get_validated_value('delay', t.fattributes.get('delay'), delay, None) dummy = t.get_validated_value('delay', delay, None)
def test_task_auto_name_with_role(self): def test_task_auto_name_with_role(self):
pass pass

@ -84,7 +84,6 @@ class TestVariableManager(unittest.TestCase):
mock_play = MagicMock() mock_play = MagicMock()
mock_play.get_vars.return_value = dict(foo="bar") mock_play.get_vars.return_value = dict(foo="bar")
mock_play.get_roles.return_value = []
mock_play.get_vars_files.return_value = [] mock_play.get_vars_files.return_value = []
mock_inventory = MagicMock() mock_inventory = MagicMock()
@ -103,7 +102,6 @@ class TestVariableManager(unittest.TestCase):
mock_play = MagicMock() mock_play = MagicMock()
mock_play.get_vars.return_value = dict() mock_play.get_vars.return_value = dict()
mock_play.get_roles.return_value = []
mock_play.get_vars_files.return_value = [__file__] mock_play.get_vars_files.return_value = [__file__]
mock_inventory = MagicMock() mock_inventory = MagicMock()
@ -158,10 +156,25 @@ class TestVariableManager(unittest.TestCase):
# and role2 depend on common-role. Check that the tasks see # and role2 depend on common-role. Check that the tasks see
# different values of role_var. # different values of role_var.
blocks = play1.compile() blocks = play1.compile()
# compile returns the following layout of blocks
# (fact gathering is missing as that is added by PlayIterator):
# [TASK: meta (flush_handlers)]
# [TASK: common-role : debug]
# [TASK: meta (role_complete)]
# [TASK: meta (role_complete)]
# [TASK: common-role : debug]
# [TASK: meta (role_complete)]
# [TASK: meta (role_complete)]
# [TASK: meta (flush_handlers)]
# [TASK: meta (flush_handlers)]
task = blocks[1].block[0] task = blocks[1].block[0]
assert task.action == 'debug'
res = v.get_vars(play=play1, task=task) res = v.get_vars(play=play1, task=task)
self.assertEqual(res['role_var'], 'role_var_from_role1') assert res['role_var'] == 'role_var_from_role1'
task = blocks[2].block[0] task = blocks[4].block[0]
assert task.action == 'debug'
res = v.get_vars(play=play1, task=task) res = v.get_vars(play=play1, task=task)
self.assertEqual(res['role_var'], 'role_var_from_role2') assert res['role_var'] == 'role_var_from_role2'

Loading…
Cancel
Save