Move resolving parent attributes to FA + cleanup

ci_complete
pull/86158/head
Martin Krizek 3 weeks ago
parent 6bb7bd760f
commit 45d777f6ff

@ -0,0 +1,2 @@
minor_changes:
- internals - simplify ``dep_chain`` handling in playbook objects.

@ -77,8 +77,6 @@ class TaskExecutor:
self._loop_eval_error = None
self._task_templar = TemplateEngine(loader=self._loader, variables=self._job_vars)
self._task.squash()
def run(self):
"""
The main executor entrypoint, where we determine if the specified
@ -369,7 +367,6 @@ class TaskExecutor:
if self._task.loop_control and self._task.loop_control.break_when:
break_when = self._task.loop_control.get_validated_value(
'break_when',
self._task.loop_control.fattributes.get('break_when'),
self._task.loop_control.break_when,
templar,
)

@ -90,41 +90,13 @@ class Attribute:
def __set_name__(self, owner, name):
self.name = name
def __eq__(self, other):
return other.priority == self.priority
def __ne__(self, other):
return other.priority != self.priority
# NB: higher priority numbers sort first
# __lt__ is sufficient for sorted() which is our only use case
def __lt__(self, other):
return other.priority < self.priority
def __gt__(self, other):
return other.priority > self.priority
def __le__(self, other):
return other.priority <= self.priority
def __ge__(self, other):
return other.priority >= self.priority
def __get__(self, obj: FieldAttributeBase, obj_type=None):
method = f'_get_attr_{self.name}'
if hasattr(obj, method):
# NOTE this appears to be not used in the codebase,
# _get_attr_connection has been replaced by ConnectionFieldAttribute.
# Leaving it here for test_attr_method from
# test/units/playbook/test_base.py to pass and for backwards compat.
if getattr(obj, '_squashed', False):
value = getattr(obj, f'_{self.name}', Sentinel)
else:
value = getattr(obj, method)()
else:
value = getattr(obj, f'_{self.name}', Sentinel)
if value is Sentinel:
if (value := getattr(obj, f'_{self.name}', Sentinel)) is Sentinel:
value = self.default
if callable(value):
value = value()
@ -137,17 +109,30 @@ class Attribute:
if self.alias is not None:
setattr(obj, f'_{self.alias}', value)
# NOTE this appears to be not needed in the codebase,
# leaving it here for test_attr_int_del from
# test/units/playbook/test_base.py to pass.
def __delete__(self, obj):
delattr(obj, f'_{self.name}')
class NonInheritableFieldAttribute(Attribute):
...
def _get_parent_static_chain(obj):
# NOTE similar code in Taggable
# FIXME this encapsulates the mess caused by not having proper parent chain on playbook objects
parent = getattr(obj, '_parent', None)
while parent:
# If parent is static, we can grab attrs from the parent
# otherwise, defer to the grandparent
if getattr(parent, 'statically_loaded', True):
yield parent
parent = getattr(parent, '_parent', None)
if role := getattr(obj, '_role', None):
yield obj._role
if dep_chain := obj.get_dep_chain():
yield from reversed(dep_chain)
yield obj.play
class FieldAttribute(Attribute):
def __init__(self, extend=False, prepend=False, **kwargs):
super().__init__(**kwargs)
@ -156,24 +141,27 @@ class FieldAttribute(Attribute):
self.prepend = prepend
def __get__(self, obj, obj_type=None):
if getattr(obj, '_squashed', False) or getattr(obj, '_finalized', False):
value = getattr(obj, f'_{self.name}', Sentinel)
else:
try:
value = obj._get_parent_attribute(self.name)
except AttributeError:
method = f'_get_attr_{self.name}'
if hasattr(obj, method):
# NOTE this appears to be not needed in the codebase,
# _get_attr_connection has been replaced by ConnectionFieldAttribute.
# Leaving it here for test_attr_method from
# test/units/playbook/test_base.py to pass and for backwards compat.
if getattr(obj, '_squashed', False):
value = getattr(obj, f'_{self.name}', Sentinel)
if not obj.finalized:
if self.extend:
# NOTE implements the historic behavior of now removed Base._extend_value
all_values = value if value not in (Sentinel, None) else []
_extend = all_values.extend
for parent in _get_parent_static_chain(obj):
if (parent_value := getattr(parent, f'_{self.name}', Sentinel)) not in (Sentinel, None):
if self.prepend:
all_values[:0] = parent_value
else:
value = getattr(obj, method)()
else:
value = getattr(obj, f'_{self.name}', Sentinel)
_extend(parent_value)
# some FAs contain non-hashable values, skip dedup in such a case
# FIXME deal with list of dicts differently?
value = all_values if self.name in ('module_defaults', 'environment') else list(dict.fromkeys(all_values))
elif value is Sentinel:
for parent in _get_parent_static_chain(obj):
if (parent_value := getattr(parent, f'_{self.name}', Sentinel)) is not Sentinel:
value = parent_value
break
if value is Sentinel:
value = self.default

@ -5,7 +5,6 @@
from __future__ import annotations
import decimal
import itertools
import operator
import os
@ -30,6 +29,9 @@ from ansible.utils.display import Display
from ansible.utils.vars import combine_vars, get_unique_id, validate_variable_name
from ansible._internal._templating._engine import TemplateEngine
if t.TYPE_CHECKING:
from ansible.playbook.role import Role
display = Display()
@ -112,13 +114,13 @@ class FieldAttributeBase:
self._origin: Origin | None = None
# other internal params
self._validated = False
self._squashed = False
self._finalized = False
# every object gets a random uuid:
self._uuid = get_unique_id()
self._ds = None
@property
def finalized(self):
return self._finalized
@ -130,9 +132,7 @@ class FieldAttributeBase:
display.debug("%s- %s (%s, id=%s)" % (" " * depth, self.__class__.__name__, self, id(self)))
if hasattr(self, '_parent') and self._parent:
self._parent.dump_me(depth + 2)
dep_chain = self._parent.get_dep_chain()
if dep_chain:
for dep in dep_chain:
for dep in self._parent.get_dep_chain():
dep.dump_me(depth + 2)
if hasattr(self, '_play') and self._play:
self._play.dump_me(depth + 2)
@ -148,7 +148,7 @@ class FieldAttributeBase:
raise AnsibleAssertionError('ds (%s) should not be None but it is.' % ds)
# cache the datastructure internally
setattr(self, '_ds', ds)
self._ds = ds
# the variable manager class is used to manage and merge variables
# down to a single dictionary for reference in templating, etc.
@ -185,10 +185,7 @@ class FieldAttributeBase:
return self
def get_ds(self):
try:
return getattr(self, '_ds')
except AttributeError:
return None
return self._ds
def get_loader(self):
return self._loader
@ -218,26 +215,14 @@ class FieldAttributeBase:
if key not in valid_attrs:
raise AnsibleParserError("'%s' is not a valid attribute for a %s" % (key, self.__class__.__name__), obj=key)
def validate(self, all_vars=None):
def validate(self):
""" validation that is done at parse time, not load time """
if not self._validated:
# walk all fields in the object
for (name, attribute) in self.fattributes.items():
# run validator only if present
method = getattr(self, '_validate_%s' % name, None)
if method:
method(attribute, name, getattr(self, name))
else:
# and make sure the attribute is of the type it should be
value = getattr(self, f'_{name}', Sentinel)
if value is not None:
if attribute.isa == 'string' and isinstance(value, (list, dict)):
raise AnsibleParserError(
"The field '%s' is supposed to be a string type,"
" however the incoming data structure is a %s" % (name, type(value)), obj=self.get_ds()
)
for name, attribute in self.fattributes.items():
if (value := getattr(self, f'_{name}', Sentinel)) is Sentinel:
continue
self._validated = True
if method := getattr(self, f'_validate_{name}', None):
method(attribute, name, value)
def _load_module_defaults(self, name, value):
if value is None:
@ -297,8 +282,8 @@ class FieldAttributeBase:
@property
def play(self):
if hasattr(self, '_play'):
play = self._play
if _play := getattr(self, '_play', None):
play = _play
elif hasattr(self, '_parent') and hasattr(self._parent, '_play'):
play = self._parent._play
else:
@ -410,17 +395,6 @@ class FieldAttributeBase:
raise AnsibleParserError("Could not resolve action %s in module_defaults" % action_name)
display.vvvvv("Could not resolve action %s in module_defaults" % action_name)
def squash(self):
"""
Evaluates all attributes and sets them to the evaluated version,
so that all future accesses of attributes do not need to evaluate
parent attributes.
"""
if not self._squashed:
for name in self.fattributes:
setattr(self, name, getattr(self, name))
self._squashed = True
def copy(self):
"""
Create a copy of this object and return it.
@ -437,17 +411,17 @@ class FieldAttributeBase:
new_me._loader = self._loader
new_me._variable_manager = self._variable_manager
new_me._origin = self._origin
new_me._validated = self._validated
new_me._finalized = self._finalized
new_me._uuid = self._uuid
# if the ds value was set on the object, copy it to the new copy too
if hasattr(self, '_ds'):
new_me._ds = self._ds
if _ds := self.get_ds():
new_me._ds = _ds
return new_me
def get_validated_value(self, name, attribute, value, templar):
def get_validated_value(self, name: str, value: object, templar: TemplateEngine):
attribute: Attribute = self.fattributes[name]
try:
return self._get_validated_value(name, attribute, value, templar)
except (TypeError, ValueError):
@ -455,6 +429,13 @@ class FieldAttributeBase:
def _get_validated_value(self, name, attribute, value, templar):
if attribute.isa == 'string':
if isinstance(value, (list, dict)):
# NOTE historically this check has been in validate()
raise AnsibleParserError(
message=f"The field {name!r} is supposed to be a string type, "
f"however the incoming data structure is a {type(value)}",
obj=self.get_ds(),
)
value = to_text(value)
elif attribute.isa == 'int':
if not isinstance(value, int):
@ -509,20 +490,8 @@ class FieldAttributeBase:
def set_to_context(self, name: str) -> t.Any:
""" set to parent inherited value or Sentinel as appropriate"""
attribute = self.fattributes[name]
if isinstance(attribute, NonInheritableFieldAttribute):
# setting to sentinel will trigger 'default/default()' on getter
value = Sentinel
else:
try:
value = self._get_parent_attribute(name, omit=True)
except AttributeError:
# mostly playcontext as only tasks/handlers/blocks really resolve parent
value = Sentinel
setattr(self, name, value)
return value
setattr(self, name, Sentinel)
return getattr(self, name, Sentinel)
def post_validate(self, templar):
"""
@ -541,33 +510,23 @@ class FieldAttributeBase:
self._finalized = True
def post_validate_attribute(self, name: str, *, templar: TemplateEngine):
attribute: FieldAttribute = self.fattributes[name]
attribute: Attribute = self.fattributes[name]
# DTFIX-FUTURE: this can probably be used in many getattr cases below, but the value may be out-of-date in some cases
original_value = getattr(self, name) # we save this original (likely Origin-tagged) value to pass as `obj` for errors
if attribute.static:
value = getattr(self, name)
# we don't template 'vars' but allow template as values for later use
if name not in ('vars',) and templar.is_template(value):
if name not in ('vars',) and templar.is_template(original_value):
display.warning('"%s" is not templatable, but we found: %s, '
'it will not be templated and will be used "as is".' % (name, value))
'it will not be templated and will be used "as is".' % (name, original_value))
return Sentinel
if getattr(self, name) is None:
if original_value is None:
if not attribute.required:
return Sentinel
raise AnsibleFieldAttributeError(f'The field {name!r} is required but was not set.', obj=self.get_ds())
from .role_include import IncludeRole
if not attribute.always_post_validate and isinstance(self, IncludeRole) and self.statically_loaded: # import_role
# normal field attributes should not go through post validation on import_role/import_tasks
# only import_role is checked here because import_tasks never reaches this point
return Sentinel
# Skip post validation unless always_post_validate is True, or the object requires post validation.
if not attribute.always_post_validate and not self._post_validate_object:
# Intermediate objects like Play() won't have their fields validated by
@ -581,13 +540,13 @@ class FieldAttributeBase:
method = getattr(self, '_post_validate_%s' % name, None)
if method:
value = method(attribute, getattr(self, name), templar)
value = method(attribute, original_value, templar)
elif attribute.isa == 'class':
value = getattr(self, name)
value = original_value
else:
try:
# if the attribute contains a variable, template it now
value = templar.template(getattr(self, name))
value = templar.template(original_value)
except AnsibleValueOmittedError:
# If this evaluated to the omit value, set the value back to inherited by context
# or default specified in the FieldAttribute and move on
@ -598,7 +557,7 @@ class FieldAttributeBase:
# and make sure the attribute is of the type it should be
if value is not None:
value = self.get_validated_value(name, attribute, value, templar)
value = self.get_validated_value(name, value, templar)
# returning the value results in assigning the massaged value back to the attribute field
return value
@ -627,31 +586,6 @@ class FieldAttributeBase:
except TypeError as ex:
raise AnsibleParserError(f"Invalid variable name in vars specified for {self.__class__.__name__}.", obj=ds) from ex
def _extend_value(self, value, new_value, prepend=False):
"""
Will extend the value given with new_value (and will turn both
into lists if they are not so already). The values are run through
a set to remove duplicate values.
"""
if not isinstance(value, list):
value = [value]
if not isinstance(new_value, list):
new_value = [new_value]
# Due to where _extend_value may run for some attributes
# it is possible to end up with Sentinel in the list of values
# ensure we strip them
value = [v for v in value if v is not Sentinel]
new_value = [v for v in new_value if v is not Sentinel]
if prepend:
combined = new_value + value
else:
combined = value + new_value
return [i for i, dummy in itertools.groupby(combined) if i is not None]
def dump_attrs(self):
"""
Dumps all attributes to a dictionary
@ -719,8 +653,9 @@ class Base(FieldAttributeBase):
become_flags = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_flags'))
become_exe = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_exe'))
# used to hold sudo/su stuff
DEPRECATED_ATTRIBUTES = [] # type: list[str]
def _validate_environment(self, attr, name, value):
if not isinstance(value, list):
setattr(self, name, [value])
def update_result_no_log(self, templar: TemplateEngine, result: dict[str, t.Any]) -> None:
"""Set the post-validated no_log value for the result, falling back to a default on validation/templating failure with a warning."""
@ -761,12 +696,11 @@ class Base(FieldAttributeBase):
return path
def get_dep_chain(self):
if hasattr(self, '_parent') and self._parent:
return self._parent.get_dep_chain()
def get_dep_chain(self) -> list[Role]:
if role := getattr(self, '_role', None):
return role.get_dep_chain() + [role]
else:
return None
return []
def get_search_path(self):
"""
@ -775,9 +709,8 @@ class Base(FieldAttributeBase):
"""
path_stack = []
dep_chain = self.get_dep_chain()
# inside role: add the dependency chain from current to dependent
if dep_chain:
if dep_chain := self.get_dep_chain():
path_stack.extend(reversed([x._role_path for x in dep_chain if hasattr(x, '_role_path')]))
# add path of task itself, unless it is already in the list

@ -18,7 +18,6 @@
from __future__ import annotations
from ansible.errors import AnsibleParserError
from ansible.module_utils.common.sentinel import Sentinel
from ansible.playbook.attribute import NonInheritableFieldAttribute
from ansible.playbook.base import Base
from ansible.playbook.conditional import Conditional
@ -40,13 +39,11 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab
# similar to the 'else' clause for exceptions
# otherwise = FieldAttribute(isa='list')
def __init__(self, play=None, parent_block=None, role=None, task_include=None, use_handlers=False, implicit=False):
def __init__(self, play=None, parent_block=None, role=None, task_include=None, use_handlers=False):
self._play = play
self._role = role
self._parent = None
self._dep_chain = None
self._use_handlers = use_handlers
self._implicit = implicit
if task_include:
self._parent = task_include
@ -77,25 +74,18 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab
if self._parent:
all_vars |= self._parent.get_vars()
all_vars |= self.vars.copy()
all_vars |= self.vars
return all_vars
@staticmethod
def load(data, play=None, parent_block=None, role=None, task_include=None, use_handlers=False, variable_manager=None, loader=None):
implicit = not Block.is_block(data)
b = Block(play=play, parent_block=parent_block, role=role, task_include=task_include, use_handlers=use_handlers, implicit=implicit)
b = Block(play=play, parent_block=parent_block, role=role, task_include=task_include, use_handlers=use_handlers)
return b.load_data(data, variable_manager=variable_manager, loader=loader)
@staticmethod
def is_block(ds):
is_block = False
if isinstance(ds, dict):
for attr in ('block', 'rescue', 'always'):
if attr in ds:
is_block = True
break
return is_block
return isinstance(ds, dict) and 'block' in ds
def preprocess_data(self, ds):
"""
@ -111,8 +101,6 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab
return super(Block, self).preprocess_data(ds)
# FIXME: these do nothing but augment the exception message; DRY and nuke
def _load_block(self, attr, ds):
try:
return load_list_of_tasks(
@ -126,37 +114,9 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab
use_handlers=self._use_handlers,
)
except AssertionError as ex:
raise AnsibleParserError("A malformed block was encountered while loading a block", obj=self._ds) from ex
raise AnsibleParserError(f"A malformed block was encountered while loading a {attr}", obj=self._ds) from ex
def _load_rescue(self, attr, ds):
try:
return load_list_of_tasks(
ds,
play=self._play,
block=self,
role=self._role,
task_include=None,
variable_manager=self._variable_manager,
loader=self._loader,
use_handlers=self._use_handlers,
)
except AssertionError as ex:
raise AnsibleParserError("A malformed block was encountered while loading rescue.", obj=self._ds) from ex
def _load_always(self, attr, ds):
try:
return load_list_of_tasks(
ds,
play=self._play,
block=self,
role=self._role,
task_include=None,
variable_manager=self._variable_manager,
loader=self._loader,
use_handlers=self._use_handlers,
)
except AssertionError as ex:
raise AnsibleParserError("A malformed block was encountered while loading always", obj=self._ds) from ex
_load_rescue = _load_always = _load_block
def _validate_always(self, attr, name, value):
if value and not self.block:
@ -164,15 +124,6 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab
_validate_rescue = _validate_always
def get_dep_chain(self):
if self._dep_chain is None:
if self._parent:
return self._parent.get_dep_chain()
else:
return None
else:
return self._dep_chain[:]
def copy(self, exclude_parent=False, exclude_tasks=False):
def _dupe_task_list(task_list, new_block):
new_task_list = []
@ -199,9 +150,6 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab
new_me._play = self._play
new_me._use_handlers = self._use_handlers
if self._dep_chain is not None:
new_me._dep_chain = self._dep_chain[:]
new_me._parent = None
if self._parent and not exclude_parent:
new_me._parent = self._parent.copy(exclude_tasks=True)
@ -211,11 +159,8 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab
new_me.rescue = _dupe_task_list(self.rescue or [], new_me)
new_me.always = _dupe_task_list(self.always or [], new_me)
new_me._role = None
if self._role:
new_me._role = self._role
new_me.validate()
return new_me
def set_loader(self, loader):
@ -225,84 +170,9 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab
elif self._role:
self._role.set_loader(loader)
dep_chain = self.get_dep_chain()
if dep_chain:
for dep in dep_chain:
for dep in self.get_dep_chain():
dep.set_loader(loader)
def _get_parent_attribute(self, attr, omit=False):
"""
Generic logic to get the attribute or parent attribute for a block value.
"""
fattr = self.fattributes[attr]
extend = fattr.extend
prepend = fattr.prepend
try:
# omit self, and only get parent values
if omit:
value = Sentinel
else:
value = getattr(self, f'_{attr}', Sentinel)
# If parent is static, we can grab attrs from the parent
# otherwise, defer to the grandparent
if getattr(self._parent, 'statically_loaded', True):
_parent = self._parent
else:
_parent = self._parent._parent
if _parent and (value is Sentinel or extend):
try:
if getattr(_parent, 'statically_loaded', True):
if hasattr(_parent, '_get_parent_attribute'):
parent_value = _parent._get_parent_attribute(attr)
else:
parent_value = getattr(_parent, f'_{attr}', Sentinel)
if extend:
value = self._extend_value(value, parent_value, prepend)
else:
value = parent_value
except AttributeError:
pass
if self._role and (value is Sentinel or extend):
try:
parent_value = getattr(self._role, f'_{attr}', Sentinel)
if extend:
value = self._extend_value(value, parent_value, prepend)
else:
value = parent_value
dep_chain = self.get_dep_chain()
if dep_chain and (value is Sentinel or extend):
dep_chain.reverse()
for dep in dep_chain:
dep_value = getattr(dep, f'_{attr}', Sentinel)
if extend:
value = self._extend_value(value, dep_value, prepend)
else:
value = dep_value
if value is not Sentinel and not extend:
break
except AttributeError:
pass
if self._play and (value is Sentinel or extend):
try:
play_value = getattr(self._play, f'_{attr}', Sentinel)
if play_value is not Sentinel:
if extend:
value = self._extend_value(value, play_value, prepend)
else:
value = play_value
except AttributeError:
pass
except KeyError:
pass
return value
def filter_tagged_tasks(self, all_vars):
"""
Creates a new block, with task lists filtered based on the tags.
@ -348,7 +218,7 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab
return evaluate_block(self)
def has_tasks(self):
return len(self.block) > 0 or len(self.rescue) > 0 or len(self.always) > 0
return bool(self.block or self.rescue or self.always)
def get_include_params(self):
if self._parent:
@ -356,21 +226,6 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab
else:
return dict()
def all_parents_static(self):
"""
Determine if all of the parents of this block were statically loaded
or not. Since Task/TaskInclude objects may be in the chain, they simply
call their parents all_parents_static() method. Only Block objects in
the chain check the statically_loaded value of the parent.
"""
from ansible.playbook.task_include import TaskInclude
if self._parent:
if isinstance(self._parent, TaskInclude) and not self._parent.statically_loaded:
return False
return self._parent.all_parents_static()
return True
def get_first_parent_include(self):
from ansible.playbook.task_include import TaskInclude
if self._parent:

@ -5,9 +5,6 @@ from __future__ import annotations
from ansible.playbook.attribute import FieldAttribute
from ansible.utils.collection_loader import AnsibleCollectionConfig
from ansible.utils.display import Display
display = Display()
def _ensure_default_collection(collection_list=None):
@ -36,7 +33,7 @@ class CollectionSearch:
def _load_collections(self, attr, ds):
# We are always a mixin with Base, so we can validate this untemplated
# field early on to guarantee we are dealing with a list.
ds = self.get_validated_value('collections', self.fattributes.get('collections'), ds, None)
ds = self.get_validated_value('collections', ds, None)
# this will only be called if someone specified a value; call the shared value
_ensure_default_collection(collection_list=ds)

@ -18,9 +18,6 @@
from __future__ import annotations
from ansible.playbook.attribute import FieldAttribute
from ansible.utils.display import Display
display = Display()
class Conditional:
@ -31,9 +28,6 @@ class Conditional:
when = FieldAttribute(isa='list', default=list, extend=True, prepend=True)
def __init__(self, *args, **kwargs):
super().__init__()
def _validate_when(self, attr, name, value):
if not isinstance(value, list):
setattr(self, name, [value])

@ -38,7 +38,7 @@ class Handler(Task):
return "HANDLER: %s" % self.get_name()
def _validate_listen(self, attr, name, value):
new_value = self.get_validated_value(name, attr, value, None)
new_value = self.get_validated_value(name, value, None)
if self._role is not None:
for listener in new_value.copy():
new_value.extend([

@ -17,7 +17,6 @@
from __future__ import annotations
# from ansible.inventory.host import Host
from ansible.playbook.handler import Handler
from ansible.playbook.task_include import TaskInclude

@ -291,7 +291,7 @@ def load_list_of_tasks(ds, play, block=None, role=None, task_include=None, use_h
def load_list_of_roles(ds, play, current_role_path=None, variable_manager=None, loader=None, collection_search_list=None):
"""
Loads and returns a list of RoleInclude objects from the ds list of role definitions
Loads and returns a list of RoleDefinition objects from the ds list of role definitions
:param ds: list of roles to load
:param play: calling Play object
:param current_role_path: path of the owning role, if any
@ -301,14 +301,14 @@ def load_list_of_roles(ds, play, current_role_path=None, variable_manager=None,
:return:
"""
# we import here to prevent a circular dependency with imports
from ansible.playbook.role.include import RoleInclude
from ansible.playbook.role.definition import RoleDefinition
if not isinstance(ds, list):
raise AnsibleAssertionError('ds (%s) should be a list but was a %s' % (ds, type(ds)))
roles = []
for role_def in ds:
i = RoleInclude.load(role_def, play=play, current_role_path=current_role_path, variable_manager=variable_manager,
i = RoleDefinition.load(role_def, play=play, current_role_path=current_role_path, variable_manager=variable_manager,
loader=loader, collection_list=collection_search_list)
roles.append(i)

@ -31,9 +31,6 @@ class LoopControl(FieldAttributeBase):
extended_allitems = NonInheritableFieldAttribute(isa='bool', default=True, always_post_validate=True)
break_when = NonInheritableFieldAttribute(isa='list', default=list)
def __init__(self):
super(LoopControl, self).__init__()
@staticmethod
def load(data, variable_manager=None, loader=None):
t = LoopControl()

@ -35,12 +35,9 @@ from ansible.playbook.role import Role
from ansible.playbook.task import Task
from ansible.playbook.taggable import Taggable
from ansible.parsing.vault import EncryptedString
from ansible.utils.display import Display
from ansible._internal._templating._engine import TemplateEngine as _TE
display = Display()
__all__ = ['Play']
@ -73,7 +70,7 @@ class Play(Base, Taggable, CollectionSearch):
validate_argspec = NonInheritableFieldAttribute(isa='string', always_post_validate=True)
# Role Attributes
roles = NonInheritableFieldAttribute(isa='list', default=list, priority=90)
roles = NonInheritableFieldAttribute(isa='list', default=list, priority=-1)
# Block (Task) Lists Attributes
handlers = NonInheritableFieldAttribute(isa='list', default=list, priority=-1)
@ -156,27 +153,11 @@ class Play(Base, Taggable, CollectionSearch):
"""
Adjusts play datastructure to cleanup old/legacy items
"""
if not isinstance(ds, dict):
raise AnsibleAssertionError('while preprocessing data (%s), ds should be a dict but was a %s' % (ds, type(ds)))
# The use of 'user' in the Play datastructure was deprecated to
# line up with the same change for Tasks, due to the fact that
# 'user' conflicted with the user module.
if 'user' in ds:
# this should never happen, but error out with a helpful message
# to the user if it does...
if 'remote_user' in ds:
raise AnsibleParserError("both 'user' and 'remote_user' are set for this play. "
"The use of 'user' is deprecated, and should be removed", obj=ds)
ds['remote_user'] = ds['user']
del ds['user']
return super(Play, self).preprocess_data(ds)
# DTFIX-FUTURE: these do nothing but augment the exception message; DRY and nuke
def _load_tasks(self, attr, ds):
"""
Loads a list of blocks from a list which may be mixed tasks/blocks.
@ -185,27 +166,9 @@ class Play(Base, Taggable, CollectionSearch):
try:
return load_list_of_blocks(ds=ds, play=self, variable_manager=self._variable_manager, loader=self._loader)
except AssertionError as ex:
raise AnsibleParserError("A malformed block was encountered while loading tasks.", obj=self._ds) from ex
raise AnsibleParserError(f"A malformed block was encountered while loading {attr}.", obj=self._ds) from ex
def _load_pre_tasks(self, attr, ds):
"""
Loads a list of blocks from a list which may be mixed tasks/blocks.
Bare tasks outside of a block are given an implicit block.
"""
try:
return load_list_of_blocks(ds=ds, play=self, variable_manager=self._variable_manager, loader=self._loader)
except AssertionError as ex:
raise AnsibleParserError("A malformed block was encountered while loading pre_tasks.", obj=self._ds) from ex
def _load_post_tasks(self, attr, ds):
"""
Loads a list of blocks from a list which may be mixed tasks/blocks.
Bare tasks outside of a block are given an implicit block.
"""
try:
return load_list_of_blocks(ds=ds, play=self, variable_manager=self._variable_manager, loader=self._loader)
except AssertionError as ex:
raise AnsibleParserError("A malformed block was encountered while loading post_tasks.", obj=self._ds) from ex
_load_pre_tasks = _load_post_tasks = _load_tasks
def _load_handlers(self, attr, ds):
"""
@ -213,17 +176,13 @@ class Play(Base, Taggable, CollectionSearch):
Bare handlers outside of a block are given an implicit block.
"""
try:
return self._extend_value(
self.handlers,
load_list_of_blocks(ds=ds, play=self, use_handlers=True, variable_manager=self._variable_manager, loader=self._loader),
prepend=True
)
return load_list_of_blocks(ds=ds, play=self, use_handlers=True, variable_manager=self._variable_manager, loader=self._loader) + self.handlers
except AssertionError as ex:
raise AnsibleParserError("A malformed block was encountered while loading handlers.", obj=self._ds) from ex
def _load_roles(self, attr, ds):
"""
Loads and returns a list of RoleInclude objects from the datastructure
Loads and returns a list of RoleDefinition objects from the datastructure
list of role definitions and creates the Role from those objects
"""
@ -382,21 +341,6 @@ class Play(Base, Taggable, CollectionSearch):
return [self.vars_files]
return self.vars_files
def get_handlers(self):
return self.handlers[:]
def get_roles(self):
return self.roles[:]
def get_tasks(self):
tasklist = []
for task in self.pre_tasks + self.tasks + self.post_tasks:
if isinstance(task, Block):
tasklist.append(task.block + task.rescue + task.always)
else:
tasklist.append(task)
return tasklist
def copy(self):
new_me = super(Play, self).copy()
new_me.role_cache = self.role_cache.copy()

@ -41,12 +41,6 @@ from ansible.utils.display import Display
from ansible.utils.path import is_subpath
from ansible.utils.vars import combine_vars
# NOTE: This import is only needed for the type-checking in __init__. While there's an alternative
# available by using forward references this seems not to work well with commonly used IDEs.
# Therefore the TYPE_CHECKING hack seems to be a more universal approach, even if not being very elegant.
# References:
# * https://stackoverflow.com/q/39740632/199513
# * https://peps.python.org/pep-0484/#forward-references
if _t.TYPE_CHECKING:
from ansible.playbook.block import Block
from ansible.playbook.play import Play
@ -155,6 +149,8 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable):
self._completed: dict[str, bool] = dict()
self._should_validate: bool = validate
self._dep_chain: list[Role] | None = None
if from_files is None:
from_files = {}
self._from_files: dict[str, list[str]] = from_files
@ -206,7 +202,7 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable):
return self._get_hash_dict() == other._get_hash_dict()
@staticmethod
def load(role_include, play, parent_role=None, from_files=None, from_include=False, validate=True, public=None, static=True, rescuable=True):
def load(role_definition, play, parent_role=None, from_files=None, from_include=False, validate=True, public=None, static=True, rescuable=True):
if from_files is None:
from_files = {}
try:
@ -215,7 +211,7 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable):
# that role?)
# see https://github.com/ansible/ansible/issues/61527
r = Role(play=play, from_files=from_files, from_include=from_include, validate=validate, public=public, static=static, rescuable=rescuable)
r._load_role_data(role_include, parent_role=parent_role)
r._load_role_data(role_definition, parent_role=parent_role)
role_path = r.get_role_path()
if role_path not in play.role_cache:
@ -230,23 +226,23 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable):
except RecursionError as ex:
raise AnsibleError("A recursion loop was detected with the roles specified. Make sure child roles do not have dependencies on parent roles",
obj=role_include._ds) from ex
obj=role_definition._ds) from ex
def _load_role_data(self, role_include, parent_role=None):
self._role_name = role_include.role
self._role_path = role_include.get_role_path()
self._role_collection = role_include._role_collection
self._role_params = role_include.get_role_params()
self._variable_manager = role_include.get_variable_manager()
self._loader = role_include.get_loader()
def _load_role_data(self, role_definition, parent_role=None):
self._role_name = role_definition.role
self._role_path = role_definition.get_role_path()
self._role_collection = role_definition._role_collection
self._role_params = role_definition.get_role_params()
self._variable_manager = role_definition.get_variable_manager()
self._loader = role_definition.get_loader()
if parent_role:
self.add_parent(parent_role)
# copy over all field attributes from the RoleInclude
# copy over all field attributes from the RoleDefinition
# update self._attr directly, to avoid squashing
for attr_name in self.fattributes:
setattr(self, f'_{attr_name}', getattr(role_include, f'_{attr_name}', Sentinel))
setattr(self, f'_{attr_name}', getattr(role_definition, f'_{attr_name}', Sentinel))
# vars and default vars are regular dictionaries
self._role_vars = self._load_role_yaml('vars', main=self._from_files.get('vars'), allow_dir=True)
@ -466,14 +462,12 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable):
"""
deps = []
for role_include in self._metadata.dependencies:
r = Role.load(role_include, play=self._play, parent_role=self, static=self.static)
for role_definition in self._metadata.dependencies:
r = Role.load(role_definition, play=self._play, parent_role=self, static=self.static)
deps.append(r)
return deps
# other functions
def add_parent(self, parent_role):
""" adds a role to the list of this roles parents """
if not isinstance(parent_role, Role):
@ -485,12 +479,15 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable):
def get_parents(self):
return self._parents
def get_dep_chain(self):
def get_dep_chain(self) -> list[Role]:
"""Returns a copy of the parent chain list."""
if self._dep_chain is None:
dep_chain = []
for parent in self._parents:
dep_chain.extend(parent.get_dep_chain())
dep_chain.append(parent)
return dep_chain
self._dep_chain = dep_chain
return self._dep_chain[:]
def get_default_vars(self, dep_chain=None):
dep_chain = [] if dep_chain is None else dep_chain
@ -529,8 +526,6 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable):
def get_vars(self, dep_chain=None, include_params=True, only_exports=False):
dep_chain = [] if dep_chain is None else dep_chain
all_vars = {}
# get role_vars: from parent objects
# TODO: is this right precedence for inherited role_vars?
all_vars = self.get_inherited_vars(dep_chain, only_exports=only_exports)
@ -586,23 +581,17 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable):
#
# ``get_handler_blocks`` may be called when handling ``import_role`` during parsing
# as well as with ``Play.compile_roles_handlers`` from ``TaskExecutor``
# FIXME deprecate unused dep_chain parameter
if self._compiled_handler_blocks:
return self._compiled_handler_blocks
self._compiled_handler_blocks = block_list = []
# update the dependency chain here
if dep_chain is None:
dep_chain = []
new_dep_chain = dep_chain + [self]
for dep in self.get_direct_dependencies():
dep_blocks = dep.get_handler_blocks(play=play, dep_chain=new_dep_chain)
block_list.extend(dep_blocks)
block_list.extend(dep.get_handler_blocks(play=play))
for task_block in self._handler_blocks:
new_task_block = task_block.copy()
new_task_block._dep_chain = new_dep_chain
new_task_block._play = play
block_list.append(new_task_block)
@ -626,24 +615,18 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable):
with each task, so tasks know by which route they were found, and
can correctly take their parent's tags/conditionals into account.
"""
# FIXME deprecate unused dep_chain parameter
from ansible.playbook.block import Block
from ansible.playbook.task import Task
block_list = []
# update the dependency chain here
if dep_chain is None:
dep_chain = []
new_dep_chain = dep_chain + [self]
deps = self.get_direct_dependencies()
for dep in deps:
dep_blocks = dep.compile(play=play, dep_chain=new_dep_chain)
for dep in self.get_direct_dependencies():
dep_blocks = dep.compile(play=play)
block_list.extend(dep_blocks)
for task_block in self._task_blocks:
new_task_block = task_block.copy()
new_task_block._dep_chain = new_dep_chain
new_task_block._play = play
block_list.append(new_task_block)

@ -20,12 +20,13 @@ from __future__ import annotations
import os
from ansible import constants as C
from ansible.errors import AnsibleError, AnsibleAssertionError
from ansible.errors import AnsibleError, AnsibleAssertionError, AnsibleParserError
from ansible.module_utils._internal._datatag import AnsibleTagHelper
from ansible.playbook.attribute import NonInheritableFieldAttribute
from ansible.playbook.base import Base
from ansible.playbook.collectionsearch import CollectionSearch
from ansible.playbook.conditional import Conditional
from ansible.playbook.delegatable import Delegatable
from ansible.playbook.taggable import Taggable
from ansible._internal._templating._engine import TemplateEngine
from ansible.utils.collection_loader import AnsibleCollectionRef
@ -38,13 +39,12 @@ __all__ = ['RoleDefinition']
display = Display()
class RoleDefinition(Base, Conditional, Taggable, CollectionSearch):
class RoleDefinition(Base, Conditional, Taggable, Delegatable, CollectionSearch):
role = NonInheritableFieldAttribute(isa='string')
def __init__(self, play=None, role_basedir=None, variable_manager=None, loader=None, collection_list=None):
super(RoleDefinition, self).__init__()
super().__init__()
self._play = play
self._variable_manager = variable_manager
@ -56,12 +56,16 @@ class RoleDefinition(Base, Conditional, Taggable, CollectionSearch):
self._role_params = dict()
self._collection_list = collection_list
# def __repr__(self):
# return 'ROLEDEF: ' + self._attributes.get('role', '<no name set>')
@staticmethod
def load(data, variable_manager=None, loader=None):
raise AnsibleError("not implemented")
def load(data, play, current_role_path=None, parent_role=None, variable_manager=None, loader=None, collection_list=None):
if not (isinstance(data, str) or isinstance(data, dict)):
raise AnsibleParserError("Invalid role definition.", obj=data)
if isinstance(data, str) and ',' in data:
raise AnsibleError("Invalid old style role requirement: %s" % data)
rd = RoleDefinition(play=play, role_basedir=current_role_path, variable_manager=variable_manager, loader=loader, collection_list=collection_list)
return rd.load_data(data, variable_manager=variable_manager, loader=loader)
def preprocess_data(self, ds):
# role names that are simply numbers can be parsed by PyYAML

@ -1,49 +0,0 @@
# (c) 2014 Michael DeHaan, <michael@ansible.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations
from ansible.errors import AnsibleError, AnsibleParserError
from ansible.playbook.delegatable import Delegatable
from ansible.playbook.role.definition import RoleDefinition
__all__ = ['RoleInclude']
class RoleInclude(RoleDefinition, Delegatable):
"""
A derivative of RoleDefinition, used by playbook code when a role
is included for execution in a play.
"""
def __init__(self, play=None, role_basedir=None, variable_manager=None, loader=None, collection_list=None):
super(RoleInclude, self).__init__(play=play, role_basedir=role_basedir, variable_manager=variable_manager,
loader=loader, collection_list=collection_list)
@staticmethod
def load(data, play, current_role_path=None, parent_role=None, variable_manager=None, loader=None, collection_list=None):
if not (isinstance(data, str) or isinstance(data, dict)):
raise AnsibleParserError("Invalid role definition.", obj=data)
if isinstance(data, str) and ',' in data:
raise AnsibleError("Invalid old style role requirement: %s" % data)
ri = RoleInclude(play=play, role_basedir=current_role_path, variable_manager=variable_manager, loader=loader, collection_list=collection_list)
return ri.load_data(data, variable_manager=variable_manager, loader=loader)

@ -21,7 +21,7 @@ import os
from ansible.errors import AnsibleParserError, AnsibleError
from ansible.playbook.attribute import NonInheritableFieldAttribute
from ansible.playbook.base import Base
from ansible.playbook.base import FieldAttributeBase
from ansible.playbook.collectionsearch import CollectionSearch
from ansible.playbook.helpers import load_list_of_roles
from ansible.playbook.role.requirement import RoleRequirement
@ -29,7 +29,7 @@ from ansible.playbook.role.requirement import RoleRequirement
__all__ = ['RoleMetadata']
class RoleMetadata(Base, CollectionSearch):
class RoleMetadata(FieldAttributeBase, CollectionSearch):
"""
This class wraps the parsing and validation of the optional metadata
within each Role (meta/main.yml).
@ -59,7 +59,7 @@ class RoleMetadata(Base, CollectionSearch):
def _load_dependencies(self, attr, ds):
"""
This is a helper loading function for the dependencies list,
which returns a list of RoleInclude objects
which returns a list of RoleDefinition objects
"""
roles = []

@ -18,8 +18,6 @@
from __future__ import annotations
from ansible.errors import AnsibleError
from ansible.playbook.role.definition import RoleDefinition
from ansible.utils.display import Display
from ansible.utils.galaxy import scm_archive_resource
__all__ = ['RoleRequirement']
@ -32,19 +30,13 @@ VALID_SPEC_KEYS = [
'version',
]
display = Display()
class RoleRequirement(RoleDefinition):
class RoleRequirement:
"""
Helper class for Galaxy, which is used to parse both dependencies
specified in meta/main.yml and requirements.yml files.
"""
def __init__(self):
pass
@staticmethod
def repo_url_to_role_name(repo_url):
# gets the role name out of a repo like

@ -21,14 +21,11 @@ from ansible.errors import AnsibleError, AnsibleParserError
from ansible.playbook.attribute import NonInheritableFieldAttribute
from ansible.playbook.task_include import TaskInclude
from ansible.playbook.role import Role
from ansible.playbook.role.include import RoleInclude
from ansible.utils.display import Display
from ansible.playbook.role.definition import RoleDefinition
from ansible._internal._templating._engine import TemplateEngine
__all__ = ['IncludeRole']
display = Display()
class IncludeRole(TaskInclude):
@ -59,6 +56,11 @@ class IncludeRole(TaskInclude):
self._parent_role = role
self._role_name = None
self._role_path = None
self.statically_loaded = False
@property
def _post_validate_object(self):
return not self.statically_loaded
def get_name(self):
""" return the name of the task """
@ -73,13 +75,13 @@ class IncludeRole(TaskInclude):
myplay = play
try:
ri = RoleInclude.load(self._role_name, play=myplay, variable_manager=variable_manager, loader=loader, collection_list=self.collections)
rd = RoleDefinition.load(self._role_name, play=myplay, variable_manager=variable_manager, loader=loader, collection_list=self.collections)
except AnsibleError as e:
if not self.rescuable:
raise AnsibleParserError("Could not include role.") from e
raise
ri.vars |= self.vars
rd.vars |= self.vars
if variable_manager is not None:
available_variables = variable_manager.get_vars(play=myplay, task=self)
@ -89,7 +91,7 @@ class IncludeRole(TaskInclude):
from_files = templar.template(self._from_files)
# build role
actual_role = Role.load(ri, myplay, parent_role=self._parent_role, from_files=from_files, from_include=True,
actual_role = Role.load(rd, myplay, parent_role=self._parent_role, from_files=from_files, from_include=True,
validate=self.rolespec_validate, public=self.public, static=self.statically_loaded, rescuable=self.rescuable)
actual_role._metadata.allow_duplicates = self.allow_duplicates
@ -99,23 +101,19 @@ class IncludeRole(TaskInclude):
# save this for later use
self._role_path = actual_role._role_path
# compile role with parent roles as dependencies to ensure they inherit
# variables
dep_chain = actual_role.get_dep_chain()
p_block = self.build_parent_block()
# collections value is not inherited; override with the value we calculated during role setup
p_block.collections = actual_role.collections
blocks = actual_role.compile(play=myplay, dep_chain=dep_chain)
blocks = actual_role.compile(play=myplay)
for b in blocks:
b._parent = p_block
# HACK: parent inheritance doesn't seem to have a way to handle this intermediate override until squashed/finalized
b.collections = actual_role.collections
# updated available handlers in play
handlers = actual_role.get_handler_blocks(play=myplay, dep_chain=dep_chain)
handlers = actual_role.get_handler_blocks(play=myplay)
for h in handlers:
h._parent = p_block
myplay.handlers = myplay.handlers + handlers

@ -71,7 +71,7 @@ class Taggable:
obj = obj._parent
yield self.get_play()
yield self.play
def evaluate_tags(self, only_tags, skip_tags, all_vars):
"""Check if the current item should be executed depending on the specified tags.

@ -20,7 +20,6 @@ from __future__ import annotations
import typing as t
from ansible import constants as C
from ansible.module_utils.common.sentinel import Sentinel
from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleAssertionError, AnsibleValueOmittedError
from ansible.executor.module_common import _get_action_arg_defaults
from ansible.module_utils.common.text.converters import to_native
@ -30,7 +29,6 @@ from ansible.plugins.action import ActionBase
from ansible.plugins.loader import action_loader, module_loader, lookup_loader
from ansible.playbook.attribute import NonInheritableFieldAttribute
from ansible.playbook.base import Base
from ansible.playbook.block import Block
from ansible.playbook.collectionsearch import CollectionSearch
from ansible.playbook.conditional import Conditional
from ansible.playbook.delegatable import Delegatable
@ -97,7 +95,6 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl
""" constructors a task, without the Task.load classmethod, it will be pretty blank """
self._role = role
self._parent = None
self.implicit = False
self._resolved_action: str | None = None
@ -156,20 +153,6 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl
else:
return "%s" % (self.action,)
def _merge_kv(self, ds):
if ds is None:
return ""
elif isinstance(ds, str):
return ds
elif isinstance(ds, dict):
buf = ""
for (k, v) in ds.items():
if k.startswith('_'):
continue
buf = buf + "%s=%s " % (k, v)
buf = buf.strip()
return buf
@staticmethod
def load(data, block=None, role=None, task_include=None, variable_manager=None, loader=None):
task = Task(block=block, role=role, task_include=task_include)
@ -283,7 +266,7 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl
else:
# Validate this untemplated field early on to guarantee we are dealing with a list.
# This is also done in CollectionSearch._load_collections() but this runs before that call.
collections_list = self.get_validated_value('collections', self.fattributes.get('collections'), collections_list, None)
collections_list = self.get_validated_value('collections', collections_list, None)
if default_collection and not self._role: # FIXME: and not a collections role
if collections_list:
@ -375,17 +358,6 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl
except Exception as ex:
raise AnsibleParserError("Invalid 'register' specified.", obj=value) from ex
def post_validate(self, templar):
"""
Override of base class post_validate, to also do final validation on
the block and task include (if any) to which this task belongs.
"""
if self._parent:
self._parent.post_validate(templar)
super(Task, self).post_validate(templar)
def _post_validate_loop(self, attr, value, templar):
"""
Override post validation for the loop field, which is templated
@ -425,7 +397,6 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl
return
raise
# NB: the environment FieldAttribute definition ensures that value is always a list
for env_item in value:
if isinstance(env_item, dict):
for k in env_item:
@ -471,10 +442,8 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl
all_vars |= self.vars
if 'tags' in all_vars:
del all_vars['tags']
if 'when' in all_vars:
del all_vars['when']
all_vars.pop('tags', None)
all_vars.pop('when', None)
return all_vars
@ -482,8 +451,6 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl
all_vars = dict()
if self._parent:
all_vars |= self._parent.get_include_params()
if self.action in C._ACTION_ALL_INCLUDES:
all_vars |= self.vars
return all_vars
def copy(self, exclude_parent: bool = False, exclude_tasks: bool = False) -> Task:
@ -493,13 +460,10 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl
if self._parent and not exclude_parent:
new_me._parent = self._parent.copy(exclude_tasks=exclude_tasks)
new_me._role = None
if self._role:
new_me._role = self._role
new_me.implicit = self.implicit
new_me._resolved_action = self._resolved_action
new_me._uuid = self._uuid
return new_me
@ -515,51 +479,6 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl
if self._parent:
self._parent.set_loader(loader)
def _get_parent_attribute(self, attr, omit=False):
"""
Generic logic to get the attribute or parent attribute for a task value.
"""
fattr = self.fattributes[attr]
extend = fattr.extend
prepend = fattr.prepend
try:
# omit self, and only get parent values
if omit:
value = Sentinel
else:
value = getattr(self, f'_{attr}', Sentinel)
# If parent is static, we can grab attrs from the parent
# otherwise, defer to the grandparent
if getattr(self._parent, 'statically_loaded', True):
_parent = self._parent
else:
_parent = self._parent._parent
if _parent and (value is Sentinel or extend):
if getattr(_parent, 'statically_loaded', True):
# vars are always inheritable, other attributes might not be for the parent but still should be for other ancestors
if attr != 'vars' and hasattr(_parent, '_get_parent_attribute'):
parent_value = _parent._get_parent_attribute(attr)
else:
parent_value = getattr(_parent, f'_{attr}', Sentinel)
if extend:
value = self._extend_value(value, parent_value, prepend)
else:
value = parent_value
except KeyError:
pass
return value
def all_parents_static(self):
if self._parent:
return self._parent.all_parents_static()
return True
def get_first_parent_include(self):
from ansible.playbook.task_include import TaskInclude
if self._parent:
@ -568,12 +487,6 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl
return self._parent.get_first_parent_include()
return None
def get_play(self):
parent = self._parent
while not isinstance(parent, Block):
parent = parent._parent
return parent._play
def dump_attrs(self):
"""Override to smuggle important non-FieldAttribute values back to the controller."""
attrs = super().dump_attrs()
@ -585,10 +498,9 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl
# from_attrs is only used to create a finalized task
# from attrs from the Worker/TaskExecutor
# Those attrs are finalized and squashed in the TE
# Those attrs are finalized in the TE
# and controller side use needs to reflect that
self._finalized = True
self._squashed = True
def _resolve_conditional(
self,

@ -123,3 +123,8 @@ class TaskInclude(Task):
p_block = self
return p_block
def get_include_params(self):
v = super().get_include_params()
v |= self.vars
return v

@ -542,7 +542,7 @@ class VariableManager:
delegated_vars['ansible_delegated_vars'] = {
delegated_host_name: self.get_vars(
play=task.get_play(),
play=task.play,
host=delegated_host,
task=task,
include_hostvars=True,

@ -1,6 +0,0 @@
- hosts: localhost
remote_user: a
user: b
tasks:
- debug:
msg: did not run

@ -8,20 +8,10 @@ ansible-playbook -i ../../inventory types.yml -v "$@"
# test timeout
ansible-playbook -i ../../inventory timeout.yml -v "$@"
# our Play class allows for 'user' or 'remote_user', but not both.
# first test that both user and remote_user work individually
set +e
result="$(ansible-playbook -i ../../inventory user.yml -v "$@" 2>&1)"
set -e
grep -q "worked with user" <<< "$result"
grep -q "worked with remote_user" <<< "$result"
# then test that the play errors if user and remote_user both exist
echo "EXPECTED ERROR: Ensure we fail properly if a play has both user and remote_user."
set +e
result="$(ansible-playbook -i ../../inventory remote_user_and_user.yml -v "$@" 2>&1)"
set -e
grep -q "both 'user' and 'remote_user' are set for this play." <<< "$result"
grep -q "is not a valid attribute for a Play" <<< "$result"
# test that playbook errors if len(plays) == 0
echo "EXPECTED ERROR: Ensure we fail properly if a playbook is an empty list."

@ -13,11 +13,6 @@
- hosts: localhost
user: "{{ me }}"
tasks:
- debug:
- name: should not happen
debug:
msg: worked with user ({{ me }})
- hosts: localhost
remote_user: "{{ me }}"
tasks:
- debug:
msg: worked with remote_user ({{ me }})

@ -50,7 +50,6 @@ lib/ansible/module_utils/six/__init__.py pylint:trailing-comma-tuple
lib/ansible/module_utils/six/__init__.py pylint:unidiomatic-typecheck
lib/ansible/module_utils/six/__init__.py replace-urlopen
lib/ansible/module_utils/urls.py replace-urlopen
lib/ansible/playbook/role/include.py pylint:arguments-renamed
lib/ansible/plugins/action/normal.py action-plugin-docs # default action plugin for modules without a dedicated action plugin
lib/ansible/plugins/cache/base.py ansible-doc!skip # not a plugin, but a stub for backwards compatibility
lib/ansible/plugins/callback/__init__.py pylint:arguments-renamed

@ -31,7 +31,7 @@ from units.mock.loader import DictDataLoader
from units.mock.path import mock_unfrackpath_noop
from ansible.playbook.role import Role
from ansible.playbook.role.include import RoleInclude
from ansible.playbook.role.definition import RoleDefinition
from ansible.playbook.role import hash_params
@ -168,7 +168,7 @@ class TestRole(unittest.TestCase):
mock_play = MagicMock()
mock_play.role_cache = {}
i = RoleInclude.load('foo_tasks', play=mock_play, loader=fake_loader)
i = RoleDefinition.load('foo_tasks', play=mock_play, loader=fake_loader)
r = Role.load(i, play=mock_play)
self.assertEqual(str(r), 'foo_tasks')
@ -190,7 +190,7 @@ class TestRole(unittest.TestCase):
mock_play = MagicMock()
mock_play.role_cache = {}
i = RoleInclude.load('foo_tasks', play=mock_play, loader=fake_loader)
i = RoleDefinition.load('foo_tasks', play=mock_play, loader=fake_loader)
r = Role.load(i, play=mock_play, from_files=dict(tasks='custom_main'))
self.assertEqual(r._task_blocks[0]._ds[0]['command'], 'baz')
@ -208,7 +208,7 @@ class TestRole(unittest.TestCase):
mock_play = MagicMock()
mock_play.role_cache = {}
i = RoleInclude.load('foo_handlers', play=mock_play, loader=fake_loader)
i = RoleDefinition.load('foo_handlers', play=mock_play, loader=fake_loader)
r = Role.load(i, play=mock_play)
self.assertEqual(len(r._handler_blocks), 1)
@ -229,7 +229,7 @@ class TestRole(unittest.TestCase):
mock_play = MagicMock()
mock_play.role_cache = {}
i = RoleInclude.load('foo_vars', play=mock_play, loader=fake_loader)
i = RoleDefinition.load('foo_vars', play=mock_play, loader=fake_loader)
r = Role.load(i, play=mock_play)
self.assertEqual(r._default_vars, dict(foo='bar'))
@ -250,7 +250,7 @@ class TestRole(unittest.TestCase):
mock_play = MagicMock()
mock_play.role_cache = {}
i = RoleInclude.load('foo_vars', play=mock_play, loader=fake_loader)
i = RoleDefinition.load('foo_vars', play=mock_play, loader=fake_loader)
r = Role.load(i, play=mock_play)
self.assertEqual(r._default_vars, dict(foo='bar'))
@ -271,7 +271,7 @@ class TestRole(unittest.TestCase):
mock_play = MagicMock()
mock_play.role_cache = {}
i = RoleInclude.load('foo_vars', play=mock_play, loader=fake_loader)
i = RoleDefinition.load('foo_vars', play=mock_play, loader=fake_loader)
r = Role.load(i, play=mock_play)
self.assertEqual(r._default_vars, dict(foo='bar'))
@ -294,7 +294,7 @@ class TestRole(unittest.TestCase):
mock_play = MagicMock()
mock_play.role_cache = {}
i = RoleInclude.load('foo_vars', play=mock_play, loader=fake_loader)
i = RoleDefinition.load('foo_vars', play=mock_play, loader=fake_loader)
r = Role.load(i, play=mock_play)
self.assertEqual(r._default_vars, dict(foo='bar', a=1, b=2))
@ -314,7 +314,7 @@ class TestRole(unittest.TestCase):
mock_play = MagicMock()
mock_play.role_cache = {}
i = RoleInclude.load('foo_vars', play=mock_play, loader=fake_loader)
i = RoleDefinition.load('foo_vars', play=mock_play, loader=fake_loader)
r = Role.load(i, play=mock_play)
self.assertEqual(r._role_vars, dict(foo='bam'))
@ -361,7 +361,7 @@ class TestRole(unittest.TestCase):
mock_play.collections = None
mock_play.role_cache = {}
i = RoleInclude.load('foo_metadata', play=mock_play, loader=fake_loader)
i = RoleDefinition.load('foo_metadata', play=mock_play, loader=fake_loader)
r = Role.load(i, play=mock_play)
role_deps = r.get_direct_dependencies()
@ -379,16 +379,16 @@ class TestRole(unittest.TestCase):
self.assertEqual(all_deps[1].get_name(), 'baz_metadata')
self.assertEqual(all_deps[2].get_name(), 'bar_metadata')
i = RoleInclude.load('bad1_metadata', play=mock_play, loader=fake_loader)
i = RoleDefinition.load('bad1_metadata', play=mock_play, loader=fake_loader)
self.assertRaises(AnsibleParserError, Role.load, i, play=mock_play)
i = RoleInclude.load('bad2_metadata', play=mock_play, loader=fake_loader)
i = RoleDefinition.load('bad2_metadata', play=mock_play, loader=fake_loader)
self.assertRaises(AnsibleParserError, Role.load, i, play=mock_play)
# TODO: re-enable this test once Ansible has proper role dep cycle detection
# that doesn't rely on stack overflows being recoverable (as they aren't in Py3.7+)
# see https://github.com/ansible/ansible/issues/61527
# i = RoleInclude.load('recursive1_metadata', play=mock_play, loader=fake_loader)
# i = RoleDefinition.load('recursive1_metadata', play=mock_play, loader=fake_loader)
# self.assertRaises(AnsibleError, Role.load, i, play=mock_play)
@patch('ansible.playbook.role.definition.unfrackpath', mock_unfrackpath_noop)
@ -406,7 +406,7 @@ class TestRole(unittest.TestCase):
mock_play = MagicMock()
mock_play.role_cache = {}
i = RoleInclude.load(dict(role='foo_complex'), play=mock_play, loader=fake_loader)
i = RoleDefinition.load(dict(role='foo_complex'), play=mock_play, loader=fake_loader)
r = Role.load(i, play=mock_play)
self.assertEqual(r.get_name(), "foo_complex")

@ -27,30 +27,7 @@ class TestAttribute(unittest.TestCase):
self.one = Attribute(priority=100)
self.two = Attribute(priority=0)
def test_eq(self):
self.assertTrue(self.one == self.one)
self.assertFalse(self.one == self.two)
def test_ne(self):
self.assertFalse(self.one != self.one)
self.assertTrue(self.one != self.two)
def test_lt(self):
self.assertFalse(self.one < self.one)
self.assertTrue(self.one < self.two)
self.assertFalse(self.two < self.one)
def test_gt(self):
self.assertFalse(self.one > self.one)
self.assertFalse(self.one > self.two)
self.assertTrue(self.two > self.one)
def test_le(self):
self.assertTrue(self.one <= self.one)
self.assertTrue(self.one <= self.two)
self.assertFalse(self.two <= self.one)
def test_ge(self):
self.assertTrue(self.one >= self.one)
self.assertFalse(self.one >= self.two)
self.assertTrue(self.two >= self.one)

@ -47,8 +47,6 @@ class TestBase(unittest.TestCase):
bsc = self.ClassUnderTest()
parent = ExampleParentBaseSubClass()
bsc._parent = parent
bsc._dep_chain = [parent]
parent._dep_chain = None
bsc.load_data(ds)
fake_loader = DictDataLoader({})
templar = TemplateEngine(loader=fake_loader)
@ -110,11 +108,8 @@ class TestBase(unittest.TestCase):
def test_load_data_invalid_attr_type(self):
ds = {'environment': True}
# environment is supposed to be a list. This
# seems like it shouldn't work?
ret = self.b.load_data(ds)
self.assertEqual(True, ret._environment)
self.assertEqual([True], ret._environment)
def test_post_validate(self):
ds = {'environment': [],
@ -170,10 +165,6 @@ class TestBase(unittest.TestCase):
b = self._base_validate(ds)
self.assertEqual(b.vars, {})
def test_validate_empty(self):
self.b.validate()
self.assertTrue(self.b._validated)
def test_getters(self):
# not sure why these exist, but here are tests anyway
loader = self.b.get_loader()
@ -182,70 +173,6 @@ class TestBase(unittest.TestCase):
self.assertEqual(variable_manager, self.b._variable_manager)
class TestExtendValue(unittest.TestCase):
# _extend_value could be a module or staticmethod but since its
# not, the test is here.
def test_extend_value_list_newlist(self):
b = base.Base()
value_list = ['first', 'second']
new_value_list = ['new_first', 'new_second']
ret = b._extend_value(value_list, new_value_list)
self.assertEqual(value_list + new_value_list, ret)
def test_extend_value_list_newlist_prepend(self):
b = base.Base()
value_list = ['first', 'second']
new_value_list = ['new_first', 'new_second']
ret_prepend = b._extend_value(value_list, new_value_list, prepend=True)
self.assertEqual(new_value_list + value_list, ret_prepend)
def test_extend_value_newlist_list(self):
b = base.Base()
value_list = ['first', 'second']
new_value_list = ['new_first', 'new_second']
ret = b._extend_value(new_value_list, value_list)
self.assertEqual(new_value_list + value_list, ret)
def test_extend_value_newlist_list_prepend(self):
b = base.Base()
value_list = ['first', 'second']
new_value_list = ['new_first', 'new_second']
ret = b._extend_value(new_value_list, value_list, prepend=True)
self.assertEqual(value_list + new_value_list, ret)
def test_extend_value_string_newlist(self):
b = base.Base()
some_string = 'some string'
new_value_list = ['new_first', 'new_second']
ret = b._extend_value(some_string, new_value_list)
self.assertEqual([some_string] + new_value_list, ret)
def test_extend_value_string_newstring(self):
b = base.Base()
some_string = 'some string'
new_value_string = 'this is the new values'
ret = b._extend_value(some_string, new_value_string)
self.assertEqual([some_string, new_value_string], ret)
def test_extend_value_list_newstring(self):
b = base.Base()
value_list = ['first', 'second']
new_value_string = 'this is the new values'
ret = b._extend_value(value_list, new_value_string)
self.assertEqual(value_list + [new_value_string], ret)
def test_extend_value_none_none(self):
b = base.Base()
ret = b._extend_value(None, None)
self.assertEqual(len(ret), 0)
self.assertFalse(ret)
def test_extend_value_none_list(self):
b = base.Base()
ret = b._extend_value(None, ['foo'])
self.assertEqual(ret, ['foo'])
class ExampleException(Exception):
pass
@ -255,12 +182,7 @@ class ExampleParentBaseSubClass(base.Base):
test_attr_parent_string = FieldAttribute(isa='string', default='A string attr for a class that may be a parent for testing')
def __init__(self):
super(ExampleParentBaseSubClass, self).__init__()
self._dep_chain = None
def get_dep_chain(self):
return self._dep_chain
class ExampleSubClass(base.Base):
@ -281,7 +203,7 @@ class BaseSubClass(base.Base):
test_attr_list_no_listof = FieldAttribute(isa='list', always_post_validate=True)
test_attr_list_required = FieldAttribute(isa='list', listof=(str,), required=True,
default=list, always_post_validate=True)
test_attr_string = FieldAttribute(isa='string', default='the_test_attr_string_default_value')
test_attr_string = FieldAttribute(isa='string', default='the_test_attr_string_default_value', always_post_validate=True)
test_attr_string_required = FieldAttribute(isa='string', required=True,
default='the_test_attr_string_default_value')
test_attr_percent = FieldAttribute(isa='percent', always_post_validate=True)
@ -299,9 +221,6 @@ class BaseSubClass(base.Base):
test_attr_method_missing = FieldAttribute(isa='string', default='some attr with a missing getter',
always_post_validate=True)
def _get_attr_test_attr_method(self):
return 'foo bar'
def _validate_test_attr_example(self, attr, name, value):
if not isinstance(value, str):
raise ExampleException('test_attr_example is not a string: %s type=%s' % (value, type(value)))
@ -333,13 +252,6 @@ class TestBaseSubClass(TestBase):
bsc = self._base_validate(ds)
self.assertEqual(bsc.test_attr_int, MOST_RANDOM_NUMBER)
def test_attr_int_del(self):
MOST_RANDOM_NUMBER = 37
ds = {'test_attr_int': MOST_RANDOM_NUMBER}
bsc = self._base_validate(ds)
del bsc.test_attr_int
self.assertNotIn('_test_attr_int', bsc.__dict__)
def test_attr_float(self):
roughly_pi = 4.0
ds = {'test_attr_float': roughly_pi}
@ -446,7 +358,7 @@ class TestBaseSubClass(TestBase):
def test_attr_string_invalid_list(self):
ds = {'test_attr_string': ['The new test_attr_string', 'value, however in a list']}
self.assertRaises(AnsibleParserError, self._base_validate, ds)
self.assertRaises(AnsibleFieldAttributeError, self._base_validate, ds)
def test_attr_string_required(self):
the_string_value = "the new test_attr_string_required_value"
@ -512,12 +424,6 @@ class TestBaseSubClass(TestBase):
{'test_attr_unknown_isa': True}
)
def test_attr_method(self):
ds = {'test_attr_method': 'value from the ds'}
bsc = self._base_validate(ds)
# The value returned by the subclasses _get_attr_test_attr_method
self.assertEqual(bsc.test_attr_method, 'foo bar')
def test_attr_method_missing(self):
a_string = 'The value set from the ds'
ds = {'test_attr_method_missing': a_string}
@ -525,10 +431,9 @@ class TestBaseSubClass(TestBase):
self.assertEqual(bsc.test_attr_method_missing, a_string)
def test_get_validated_value_string_preserve_tags(self):
attribute = FieldAttribute(isa='string')
value = TrustedAsTemplate().tag('bar')
templar = TemplateEngine(None)
bsc = self.ClassUnderTest()
result = bsc.get_validated_value('foo', attribute, value, templar)
result = bsc.get_validated_value('test_attr_string', value, templar)
assert TrustedAsTemplate.is_tagged_on(result)
assert result == 'bar'

@ -32,7 +32,7 @@ from ansible.playbook.block import Block
from ansible.playbook.handler import Handler
from ansible.playbook.task import Task
from ansible.playbook.task_include import TaskInclude
from ansible.playbook.role.include import RoleInclude
from ansible.playbook.role.definition import RoleDefinition
class MixinForMocks(object):
@ -317,7 +317,7 @@ class TestLoadListOfRoles(unittest.TestCase, MixinForMocks):
variable_manager=self.mock_variable_manager, loader=self.fake_role_loader)
self.assertIsInstance(res, list)
for r in res:
self.assertIsInstance(r, RoleInclude)
self.assertIsInstance(r, RoleDefinition)
def test_block_unknown_action(self):
ds = [{
@ -328,7 +328,7 @@ class TestLoadListOfRoles(unittest.TestCase, MixinForMocks):
variable_manager=self.mock_variable_manager, loader=self.fake_role_loader)
self.assertIsInstance(res, list)
for r in res:
self.assertIsInstance(r, RoleInclude)
self.assertIsInstance(r, RoleDefinition)
@pytest.mark.usefixtures('collection_loader')

@ -50,17 +50,6 @@ def test_basic_play():
assert p.connection == 'local'
def test_play_with_remote_user():
p = Play.load(dict(
name="test play",
hosts=['foo'],
user="testing",
gather_facts=False,
))
assert p.remote_user == "testing"
def test_play_with_user_conflict():
play_data = dict(
name="test play",
@ -101,7 +90,6 @@ def test_play_with_handlers():
))
assert len(p.handlers) >= 1
assert len(p.get_handlers()) >= 1
assert isinstance(p.handlers[0], Block)
assert p.handlers[0].has_tasks() is True
@ -118,9 +106,8 @@ def test_play_with_pre_tasks():
assert isinstance(p.pre_tasks[0], Block)
assert p.pre_tasks[0].has_tasks() is True
assert len(p.get_tasks()) >= 1
assert isinstance(p.get_tasks()[0][0], Task)
assert p.get_tasks()[0][0].action == 'shell'
assert isinstance(p.pre_tasks[0].block[0], Task)
assert p.pre_tasks[0].block[0].action == 'shell'
def test_play_with_post_tasks():
@ -158,7 +145,7 @@ def test_play_with_roles(mocker):
blocks = p.compile()
assert len(blocks) > 1
assert all(isinstance(block, Block) for block in blocks)
assert isinstance(p.get_roles()[0], Role)
assert isinstance(p.roles[0], Role)
def test_play_compile():

@ -28,9 +28,10 @@ class TaggableTestObj(Taggable):
self._loader = DictDataLoader({})
self.tags = []
self._parent = None
self.play = None
def get_play(self):
return None
def finalized(self):
return False
class TestTaggable(unittest.TestCase):

@ -56,7 +56,7 @@ class TestTask(unittest.TestCase):
p = dict(delay=delay)
p.update(task_base)
t = Task().load_data(p)
self.assertEqual(t.get_validated_value('delay', t.fattributes.get('delay'), delay, None), expected)
self.assertEqual(t.get_validated_value('delay', delay, None), expected)
bad_params = [
'E',
@ -69,7 +69,7 @@ class TestTask(unittest.TestCase):
p.update(task_base)
t = Task().load_data(p)
with self.assertRaises(AnsibleError):
dummy = t.get_validated_value('delay', t.fattributes.get('delay'), delay, None)
dummy = t.get_validated_value('delay', delay, None)
def test_task_auto_name_with_role(self):
pass

@ -84,7 +84,6 @@ class TestVariableManager(unittest.TestCase):
mock_play = MagicMock()
mock_play.get_vars.return_value = dict(foo="bar")
mock_play.get_roles.return_value = []
mock_play.get_vars_files.return_value = []
mock_inventory = MagicMock()
@ -103,7 +102,6 @@ class TestVariableManager(unittest.TestCase):
mock_play = MagicMock()
mock_play.get_vars.return_value = dict()
mock_play.get_roles.return_value = []
mock_play.get_vars_files.return_value = [__file__]
mock_inventory = MagicMock()
@ -158,10 +156,25 @@ class TestVariableManager(unittest.TestCase):
# and role2 depend on common-role. Check that the tasks see
# different values of role_var.
blocks = play1.compile()
# compile returns the following layout of blocks
# (fact gathering is missing as that is added by PlayIterator):
# [TASK: meta (flush_handlers)]
# [TASK: common-role : debug]
# [TASK: meta (role_complete)]
# [TASK: meta (role_complete)]
# [TASK: common-role : debug]
# [TASK: meta (role_complete)]
# [TASK: meta (role_complete)]
# [TASK: meta (flush_handlers)]
# [TASK: meta (flush_handlers)]
task = blocks[1].block[0]
assert task.action == 'debug'
res = v.get_vars(play=play1, task=task)
self.assertEqual(res['role_var'], 'role_var_from_role1')
assert res['role_var'] == 'role_var_from_role1'
task = blocks[2].block[0]
task = blocks[4].block[0]
assert task.action == 'debug'
res = v.get_vars(play=play1, task=task)
self.assertEqual(res['role_var'], 'role_var_from_role2')
assert res['role_var'] == 'role_var_from_role2'

Loading…
Cancel
Save