diff --git a/changelogs/fragments/simplify-dep_chain.yml b/changelogs/fragments/simplify-dep_chain.yml new file mode 100644 index 00000000000..17109108da2 --- /dev/null +++ b/changelogs/fragments/simplify-dep_chain.yml @@ -0,0 +1,2 @@ +minor_changes: + - internals - simplify ``dep_chain`` handling in playbook objects. diff --git a/lib/ansible/executor/task_executor.py b/lib/ansible/executor/task_executor.py index a9fa2c22110..d06e9c5e1d9 100644 --- a/lib/ansible/executor/task_executor.py +++ b/lib/ansible/executor/task_executor.py @@ -77,8 +77,6 @@ class TaskExecutor: self._loop_eval_error = None self._task_templar = TemplateEngine(loader=self._loader, variables=self._job_vars) - self._task.squash() - def run(self): """ The main executor entrypoint, where we determine if the specified @@ -369,7 +367,6 @@ class TaskExecutor: if self._task.loop_control and self._task.loop_control.break_when: break_when = self._task.loop_control.get_validated_value( 'break_when', - self._task.loop_control.fattributes.get('break_when'), self._task.loop_control.break_when, templar, ) diff --git a/lib/ansible/playbook/attribute.py b/lib/ansible/playbook/attribute.py index 3dbbef555ba..b055724aa20 100644 --- a/lib/ansible/playbook/attribute.py +++ b/lib/ansible/playbook/attribute.py @@ -90,41 +90,13 @@ class Attribute: def __set_name__(self, owner, name): self.name = name - def __eq__(self, other): - return other.priority == self.priority - - def __ne__(self, other): - return other.priority != self.priority - # NB: higher priority numbers sort first - + # __lt__ is sufficient for sorted() which is our only use case def __lt__(self, other): return other.priority < self.priority - def __gt__(self, other): - return other.priority > self.priority - - def __le__(self, other): - return other.priority <= self.priority - - def __ge__(self, other): - return other.priority >= self.priority - def __get__(self, obj: FieldAttributeBase, obj_type=None): - method = f'_get_attr_{self.name}' - if hasattr(obj, method): - # NOTE this appears to be not used in the codebase, - # _get_attr_connection has been replaced by ConnectionFieldAttribute. - # Leaving it here for test_attr_method from - # test/units/playbook/test_base.py to pass and for backwards compat. - if getattr(obj, '_squashed', False): - value = getattr(obj, f'_{self.name}', Sentinel) - else: - value = getattr(obj, method)() - else: - value = getattr(obj, f'_{self.name}', Sentinel) - - if value is Sentinel: + if (value := getattr(obj, f'_{self.name}', Sentinel)) is Sentinel: value = self.default if callable(value): value = value() @@ -137,17 +109,30 @@ class Attribute: if self.alias is not None: setattr(obj, f'_{self.alias}', value) - # NOTE this appears to be not needed in the codebase, - # leaving it here for test_attr_int_del from - # test/units/playbook/test_base.py to pass. - def __delete__(self, obj): - delattr(obj, f'_{self.name}') - class NonInheritableFieldAttribute(Attribute): ... +def _get_parent_static_chain(obj): + # NOTE similar code in Taggable + # FIXME this encapsulates the mess caused by not having proper parent chain on playbook objects + parent = getattr(obj, '_parent', None) + while parent: + # If parent is static, we can grab attrs from the parent + # otherwise, defer to the grandparent + if getattr(parent, 'statically_loaded', True): + yield parent + parent = getattr(parent, '_parent', None) + + if role := getattr(obj, '_role', None): + yield obj._role + if dep_chain := obj.get_dep_chain(): + yield from reversed(dep_chain) + + yield obj.play + + class FieldAttribute(Attribute): def __init__(self, extend=False, prepend=False, **kwargs): super().__init__(**kwargs) @@ -156,24 +141,27 @@ class FieldAttribute(Attribute): self.prepend = prepend def __get__(self, obj, obj_type=None): - if getattr(obj, '_squashed', False) or getattr(obj, '_finalized', False): - value = getattr(obj, f'_{self.name}', Sentinel) - else: - try: - value = obj._get_parent_attribute(self.name) - except AttributeError: - method = f'_get_attr_{self.name}' - if hasattr(obj, method): - # NOTE this appears to be not needed in the codebase, - # _get_attr_connection has been replaced by ConnectionFieldAttribute. - # Leaving it here for test_attr_method from - # test/units/playbook/test_base.py to pass and for backwards compat. - if getattr(obj, '_squashed', False): - value = getattr(obj, f'_{self.name}', Sentinel) - else: - value = getattr(obj, method)() - else: - value = getattr(obj, f'_{self.name}', Sentinel) + value = getattr(obj, f'_{self.name}', Sentinel) + if not obj.finalized: + if self.extend: + # NOTE implements the historic behavior of now removed Base._extend_value + all_values = value if value not in (Sentinel, None) else [] + _extend = all_values.extend + for parent in _get_parent_static_chain(obj): + if (parent_value := getattr(parent, f'_{self.name}', Sentinel)) not in (Sentinel, None): + if self.prepend: + all_values[:0] = parent_value + else: + _extend(parent_value) + + # some FAs contain non-hashable values, skip dedup in such a case + # FIXME deal with list of dicts differently? + value = all_values if self.name in ('module_defaults', 'environment') else list(dict.fromkeys(all_values)) + elif value is Sentinel: + for parent in _get_parent_static_chain(obj): + if (parent_value := getattr(parent, f'_{self.name}', Sentinel)) is not Sentinel: + value = parent_value + break if value is Sentinel: value = self.default diff --git a/lib/ansible/playbook/base.py b/lib/ansible/playbook/base.py index e1b2d4af12a..30edc79e0fc 100644 --- a/lib/ansible/playbook/base.py +++ b/lib/ansible/playbook/base.py @@ -5,7 +5,6 @@ from __future__ import annotations import decimal -import itertools import operator import os @@ -30,6 +29,9 @@ from ansible.utils.display import Display from ansible.utils.vars import combine_vars, get_unique_id, validate_variable_name from ansible._internal._templating._engine import TemplateEngine +if t.TYPE_CHECKING: + from ansible.playbook.role import Role + display = Display() @@ -112,13 +114,13 @@ class FieldAttributeBase: self._origin: Origin | None = None # other internal params - self._validated = False - self._squashed = False self._finalized = False # every object gets a random uuid: self._uuid = get_unique_id() + self._ds = None + @property def finalized(self): return self._finalized @@ -130,10 +132,8 @@ class FieldAttributeBase: display.debug("%s- %s (%s, id=%s)" % (" " * depth, self.__class__.__name__, self, id(self))) if hasattr(self, '_parent') and self._parent: self._parent.dump_me(depth + 2) - dep_chain = self._parent.get_dep_chain() - if dep_chain: - for dep in dep_chain: - dep.dump_me(depth + 2) + for dep in self._parent.get_dep_chain(): + dep.dump_me(depth + 2) if hasattr(self, '_play') and self._play: self._play.dump_me(depth + 2) @@ -148,7 +148,7 @@ class FieldAttributeBase: raise AnsibleAssertionError('ds (%s) should not be None but it is.' % ds) # cache the datastructure internally - setattr(self, '_ds', ds) + self._ds = ds # the variable manager class is used to manage and merge variables # down to a single dictionary for reference in templating, etc. @@ -185,10 +185,7 @@ class FieldAttributeBase: return self def get_ds(self): - try: - return getattr(self, '_ds') - except AttributeError: - return None + return self._ds def get_loader(self): return self._loader @@ -218,26 +215,14 @@ class FieldAttributeBase: if key not in valid_attrs: raise AnsibleParserError("'%s' is not a valid attribute for a %s" % (key, self.__class__.__name__), obj=key) - def validate(self, all_vars=None): + def validate(self): """ validation that is done at parse time, not load time """ - if not self._validated: - # walk all fields in the object - for (name, attribute) in self.fattributes.items(): - # run validator only if present - method = getattr(self, '_validate_%s' % name, None) - if method: - method(attribute, name, getattr(self, name)) - else: - # and make sure the attribute is of the type it should be - value = getattr(self, f'_{name}', Sentinel) - if value is not None: - if attribute.isa == 'string' and isinstance(value, (list, dict)): - raise AnsibleParserError( - "The field '%s' is supposed to be a string type," - " however the incoming data structure is a %s" % (name, type(value)), obj=self.get_ds() - ) + for name, attribute in self.fattributes.items(): + if (value := getattr(self, f'_{name}', Sentinel)) is Sentinel: + continue - self._validated = True + if method := getattr(self, f'_validate_{name}', None): + method(attribute, name, value) def _load_module_defaults(self, name, value): if value is None: @@ -297,8 +282,8 @@ class FieldAttributeBase: @property def play(self): - if hasattr(self, '_play'): - play = self._play + if _play := getattr(self, '_play', None): + play = _play elif hasattr(self, '_parent') and hasattr(self._parent, '_play'): play = self._parent._play else: @@ -410,17 +395,6 @@ class FieldAttributeBase: raise AnsibleParserError("Could not resolve action %s in module_defaults" % action_name) display.vvvvv("Could not resolve action %s in module_defaults" % action_name) - def squash(self): - """ - Evaluates all attributes and sets them to the evaluated version, - so that all future accesses of attributes do not need to evaluate - parent attributes. - """ - if not self._squashed: - for name in self.fattributes: - setattr(self, name, getattr(self, name)) - self._squashed = True - def copy(self): """ Create a copy of this object and return it. @@ -437,17 +411,17 @@ class FieldAttributeBase: new_me._loader = self._loader new_me._variable_manager = self._variable_manager new_me._origin = self._origin - new_me._validated = self._validated new_me._finalized = self._finalized new_me._uuid = self._uuid # if the ds value was set on the object, copy it to the new copy too - if hasattr(self, '_ds'): - new_me._ds = self._ds + if _ds := self.get_ds(): + new_me._ds = _ds return new_me - def get_validated_value(self, name, attribute, value, templar): + def get_validated_value(self, name: str, value: object, templar: TemplateEngine): + attribute: Attribute = self.fattributes[name] try: return self._get_validated_value(name, attribute, value, templar) except (TypeError, ValueError): @@ -455,6 +429,13 @@ class FieldAttributeBase: def _get_validated_value(self, name, attribute, value, templar): if attribute.isa == 'string': + if isinstance(value, (list, dict)): + # NOTE historically this check has been in validate() + raise AnsibleParserError( + message=f"The field {name!r} is supposed to be a string type, " + f"however the incoming data structure is a {type(value)}", + obj=self.get_ds(), + ) value = to_text(value) elif attribute.isa == 'int': if not isinstance(value, int): @@ -509,20 +490,8 @@ class FieldAttributeBase: def set_to_context(self, name: str) -> t.Any: """ set to parent inherited value or Sentinel as appropriate""" - - attribute = self.fattributes[name] - if isinstance(attribute, NonInheritableFieldAttribute): - # setting to sentinel will trigger 'default/default()' on getter - value = Sentinel - else: - try: - value = self._get_parent_attribute(name, omit=True) - except AttributeError: - # mostly playcontext as only tasks/handlers/blocks really resolve parent - value = Sentinel - - setattr(self, name, value) - return value + setattr(self, name, Sentinel) + return getattr(self, name, Sentinel) def post_validate(self, templar): """ @@ -541,33 +510,23 @@ class FieldAttributeBase: self._finalized = True def post_validate_attribute(self, name: str, *, templar: TemplateEngine): - attribute: FieldAttribute = self.fattributes[name] + attribute: Attribute = self.fattributes[name] - # DTFIX-FUTURE: this can probably be used in many getattr cases below, but the value may be out-of-date in some cases original_value = getattr(self, name) # we save this original (likely Origin-tagged) value to pass as `obj` for errors if attribute.static: - value = getattr(self, name) - # we don't template 'vars' but allow template as values for later use - if name not in ('vars',) and templar.is_template(value): + if name not in ('vars',) and templar.is_template(original_value): display.warning('"%s" is not templatable, but we found: %s, ' - 'it will not be templated and will be used "as is".' % (name, value)) + 'it will not be templated and will be used "as is".' % (name, original_value)) return Sentinel - if getattr(self, name) is None: + if original_value is None: if not attribute.required: return Sentinel raise AnsibleFieldAttributeError(f'The field {name!r} is required but was not set.', obj=self.get_ds()) - from .role_include import IncludeRole - - if not attribute.always_post_validate and isinstance(self, IncludeRole) and self.statically_loaded: # import_role - # normal field attributes should not go through post validation on import_role/import_tasks - # only import_role is checked here because import_tasks never reaches this point - return Sentinel - # Skip post validation unless always_post_validate is True, or the object requires post validation. if not attribute.always_post_validate and not self._post_validate_object: # Intermediate objects like Play() won't have their fields validated by @@ -581,13 +540,13 @@ class FieldAttributeBase: method = getattr(self, '_post_validate_%s' % name, None) if method: - value = method(attribute, getattr(self, name), templar) + value = method(attribute, original_value, templar) elif attribute.isa == 'class': - value = getattr(self, name) + value = original_value else: try: # if the attribute contains a variable, template it now - value = templar.template(getattr(self, name)) + value = templar.template(original_value) except AnsibleValueOmittedError: # If this evaluated to the omit value, set the value back to inherited by context # or default specified in the FieldAttribute and move on @@ -598,7 +557,7 @@ class FieldAttributeBase: # and make sure the attribute is of the type it should be if value is not None: - value = self.get_validated_value(name, attribute, value, templar) + value = self.get_validated_value(name, value, templar) # returning the value results in assigning the massaged value back to the attribute field return value @@ -627,31 +586,6 @@ class FieldAttributeBase: except TypeError as ex: raise AnsibleParserError(f"Invalid variable name in vars specified for {self.__class__.__name__}.", obj=ds) from ex - def _extend_value(self, value, new_value, prepend=False): - """ - Will extend the value given with new_value (and will turn both - into lists if they are not so already). The values are run through - a set to remove duplicate values. - """ - - if not isinstance(value, list): - value = [value] - if not isinstance(new_value, list): - new_value = [new_value] - - # Due to where _extend_value may run for some attributes - # it is possible to end up with Sentinel in the list of values - # ensure we strip them - value = [v for v in value if v is not Sentinel] - new_value = [v for v in new_value if v is not Sentinel] - - if prepend: - combined = new_value + value - else: - combined = value + new_value - - return [i for i, dummy in itertools.groupby(combined) if i is not None] - def dump_attrs(self): """ Dumps all attributes to a dictionary @@ -719,8 +653,9 @@ class Base(FieldAttributeBase): become_flags = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_flags')) become_exe = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_exe')) - # used to hold sudo/su stuff - DEPRECATED_ATTRIBUTES = [] # type: list[str] + def _validate_environment(self, attr, name, value): + if not isinstance(value, list): + setattr(self, name, [value]) def update_result_no_log(self, templar: TemplateEngine, result: dict[str, t.Any]) -> None: """Set the post-validated no_log value for the result, falling back to a default on validation/templating failure with a warning.""" @@ -761,12 +696,11 @@ class Base(FieldAttributeBase): return path - def get_dep_chain(self): - - if hasattr(self, '_parent') and self._parent: - return self._parent.get_dep_chain() + def get_dep_chain(self) -> list[Role]: + if role := getattr(self, '_role', None): + return role.get_dep_chain() + [role] else: - return None + return [] def get_search_path(self): """ @@ -775,9 +709,8 @@ class Base(FieldAttributeBase): """ path_stack = [] - dep_chain = self.get_dep_chain() # inside role: add the dependency chain from current to dependent - if dep_chain: + if dep_chain := self.get_dep_chain(): path_stack.extend(reversed([x._role_path for x in dep_chain if hasattr(x, '_role_path')])) # add path of task itself, unless it is already in the list diff --git a/lib/ansible/playbook/block.py b/lib/ansible/playbook/block.py index 81a2197e9de..dbeb128386f 100644 --- a/lib/ansible/playbook/block.py +++ b/lib/ansible/playbook/block.py @@ -18,7 +18,6 @@ from __future__ import annotations from ansible.errors import AnsibleParserError -from ansible.module_utils.common.sentinel import Sentinel from ansible.playbook.attribute import NonInheritableFieldAttribute from ansible.playbook.base import Base from ansible.playbook.conditional import Conditional @@ -40,13 +39,11 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab # similar to the 'else' clause for exceptions # otherwise = FieldAttribute(isa='list') - def __init__(self, play=None, parent_block=None, role=None, task_include=None, use_handlers=False, implicit=False): + def __init__(self, play=None, parent_block=None, role=None, task_include=None, use_handlers=False): self._play = play self._role = role self._parent = None - self._dep_chain = None self._use_handlers = use_handlers - self._implicit = implicit if task_include: self._parent = task_include @@ -77,25 +74,18 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab if self._parent: all_vars |= self._parent.get_vars() - all_vars |= self.vars.copy() + all_vars |= self.vars return all_vars @staticmethod def load(data, play=None, parent_block=None, role=None, task_include=None, use_handlers=False, variable_manager=None, loader=None): - implicit = not Block.is_block(data) - b = Block(play=play, parent_block=parent_block, role=role, task_include=task_include, use_handlers=use_handlers, implicit=implicit) + b = Block(play=play, parent_block=parent_block, role=role, task_include=task_include, use_handlers=use_handlers) return b.load_data(data, variable_manager=variable_manager, loader=loader) @staticmethod def is_block(ds): - is_block = False - if isinstance(ds, dict): - for attr in ('block', 'rescue', 'always'): - if attr in ds: - is_block = True - break - return is_block + return isinstance(ds, dict) and 'block' in ds def preprocess_data(self, ds): """ @@ -111,8 +101,6 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab return super(Block, self).preprocess_data(ds) - # FIXME: these do nothing but augment the exception message; DRY and nuke - def _load_block(self, attr, ds): try: return load_list_of_tasks( @@ -126,37 +114,9 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab use_handlers=self._use_handlers, ) except AssertionError as ex: - raise AnsibleParserError("A malformed block was encountered while loading a block", obj=self._ds) from ex + raise AnsibleParserError(f"A malformed block was encountered while loading a {attr}", obj=self._ds) from ex - def _load_rescue(self, attr, ds): - try: - return load_list_of_tasks( - ds, - play=self._play, - block=self, - role=self._role, - task_include=None, - variable_manager=self._variable_manager, - loader=self._loader, - use_handlers=self._use_handlers, - ) - except AssertionError as ex: - raise AnsibleParserError("A malformed block was encountered while loading rescue.", obj=self._ds) from ex - - def _load_always(self, attr, ds): - try: - return load_list_of_tasks( - ds, - play=self._play, - block=self, - role=self._role, - task_include=None, - variable_manager=self._variable_manager, - loader=self._loader, - use_handlers=self._use_handlers, - ) - except AssertionError as ex: - raise AnsibleParserError("A malformed block was encountered while loading always", obj=self._ds) from ex + _load_rescue = _load_always = _load_block def _validate_always(self, attr, name, value): if value and not self.block: @@ -164,15 +124,6 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab _validate_rescue = _validate_always - def get_dep_chain(self): - if self._dep_chain is None: - if self._parent: - return self._parent.get_dep_chain() - else: - return None - else: - return self._dep_chain[:] - def copy(self, exclude_parent=False, exclude_tasks=False): def _dupe_task_list(task_list, new_block): new_task_list = [] @@ -199,9 +150,6 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab new_me._play = self._play new_me._use_handlers = self._use_handlers - if self._dep_chain is not None: - new_me._dep_chain = self._dep_chain[:] - new_me._parent = None if self._parent and not exclude_parent: new_me._parent = self._parent.copy(exclude_tasks=True) @@ -211,11 +159,8 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab new_me.rescue = _dupe_task_list(self.rescue or [], new_me) new_me.always = _dupe_task_list(self.always or [], new_me) - new_me._role = None - if self._role: - new_me._role = self._role + new_me._role = self._role - new_me.validate() return new_me def set_loader(self, loader): @@ -225,83 +170,8 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab elif self._role: self._role.set_loader(loader) - dep_chain = self.get_dep_chain() - if dep_chain: - for dep in dep_chain: - dep.set_loader(loader) - - def _get_parent_attribute(self, attr, omit=False): - """ - Generic logic to get the attribute or parent attribute for a block value. - """ - fattr = self.fattributes[attr] - - extend = fattr.extend - prepend = fattr.prepend - - try: - # omit self, and only get parent values - if omit: - value = Sentinel - else: - value = getattr(self, f'_{attr}', Sentinel) - - # If parent is static, we can grab attrs from the parent - # otherwise, defer to the grandparent - if getattr(self._parent, 'statically_loaded', True): - _parent = self._parent - else: - _parent = self._parent._parent - - if _parent and (value is Sentinel or extend): - try: - if getattr(_parent, 'statically_loaded', True): - if hasattr(_parent, '_get_parent_attribute'): - parent_value = _parent._get_parent_attribute(attr) - else: - parent_value = getattr(_parent, f'_{attr}', Sentinel) - if extend: - value = self._extend_value(value, parent_value, prepend) - else: - value = parent_value - except AttributeError: - pass - if self._role and (value is Sentinel or extend): - try: - parent_value = getattr(self._role, f'_{attr}', Sentinel) - if extend: - value = self._extend_value(value, parent_value, prepend) - else: - value = parent_value - - dep_chain = self.get_dep_chain() - if dep_chain and (value is Sentinel or extend): - dep_chain.reverse() - for dep in dep_chain: - dep_value = getattr(dep, f'_{attr}', Sentinel) - if extend: - value = self._extend_value(value, dep_value, prepend) - else: - value = dep_value - - if value is not Sentinel and not extend: - break - except AttributeError: - pass - if self._play and (value is Sentinel or extend): - try: - play_value = getattr(self._play, f'_{attr}', Sentinel) - if play_value is not Sentinel: - if extend: - value = self._extend_value(value, play_value, prepend) - else: - value = play_value - except AttributeError: - pass - except KeyError: - pass - - return value + for dep in self.get_dep_chain(): + dep.set_loader(loader) def filter_tagged_tasks(self, all_vars): """ @@ -348,7 +218,7 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab return evaluate_block(self) def has_tasks(self): - return len(self.block) > 0 or len(self.rescue) > 0 or len(self.always) > 0 + return bool(self.block or self.rescue or self.always) def get_include_params(self): if self._parent: @@ -356,21 +226,6 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab else: return dict() - def all_parents_static(self): - """ - Determine if all of the parents of this block were statically loaded - or not. Since Task/TaskInclude objects may be in the chain, they simply - call their parents all_parents_static() method. Only Block objects in - the chain check the statically_loaded value of the parent. - """ - from ansible.playbook.task_include import TaskInclude - if self._parent: - if isinstance(self._parent, TaskInclude) and not self._parent.statically_loaded: - return False - return self._parent.all_parents_static() - - return True - def get_first_parent_include(self): from ansible.playbook.task_include import TaskInclude if self._parent: diff --git a/lib/ansible/playbook/collectionsearch.py b/lib/ansible/playbook/collectionsearch.py index b0036a5b9e6..c6064aa7f6a 100644 --- a/lib/ansible/playbook/collectionsearch.py +++ b/lib/ansible/playbook/collectionsearch.py @@ -5,9 +5,6 @@ from __future__ import annotations from ansible.playbook.attribute import FieldAttribute from ansible.utils.collection_loader import AnsibleCollectionConfig -from ansible.utils.display import Display - -display = Display() def _ensure_default_collection(collection_list=None): @@ -36,7 +33,7 @@ class CollectionSearch: def _load_collections(self, attr, ds): # We are always a mixin with Base, so we can validate this untemplated # field early on to guarantee we are dealing with a list. - ds = self.get_validated_value('collections', self.fattributes.get('collections'), ds, None) + ds = self.get_validated_value('collections', ds, None) # this will only be called if someone specified a value; call the shared value _ensure_default_collection(collection_list=ds) diff --git a/lib/ansible/playbook/conditional.py b/lib/ansible/playbook/conditional.py index ac59259acb3..2c02231528f 100644 --- a/lib/ansible/playbook/conditional.py +++ b/lib/ansible/playbook/conditional.py @@ -18,9 +18,6 @@ from __future__ import annotations from ansible.playbook.attribute import FieldAttribute -from ansible.utils.display import Display - -display = Display() class Conditional: @@ -31,9 +28,6 @@ class Conditional: when = FieldAttribute(isa='list', default=list, extend=True, prepend=True) - def __init__(self, *args, **kwargs): - super().__init__() - def _validate_when(self, attr, name, value): if not isinstance(value, list): setattr(self, name, [value]) diff --git a/lib/ansible/playbook/handler.py b/lib/ansible/playbook/handler.py index 5038d71258f..e34f5e9e339 100644 --- a/lib/ansible/playbook/handler.py +++ b/lib/ansible/playbook/handler.py @@ -38,7 +38,7 @@ class Handler(Task): return "HANDLER: %s" % self.get_name() def _validate_listen(self, attr, name, value): - new_value = self.get_validated_value(name, attr, value, None) + new_value = self.get_validated_value(name, value, None) if self._role is not None: for listener in new_value.copy(): new_value.extend([ diff --git a/lib/ansible/playbook/handler_task_include.py b/lib/ansible/playbook/handler_task_include.py index 2a0388191c9..4dd6fd513f7 100644 --- a/lib/ansible/playbook/handler_task_include.py +++ b/lib/ansible/playbook/handler_task_include.py @@ -17,7 +17,6 @@ from __future__ import annotations -# from ansible.inventory.host import Host from ansible.playbook.handler import Handler from ansible.playbook.task_include import TaskInclude diff --git a/lib/ansible/playbook/helpers.py b/lib/ansible/playbook/helpers.py index e3e9fab7bfc..2dd753f7b7f 100644 --- a/lib/ansible/playbook/helpers.py +++ b/lib/ansible/playbook/helpers.py @@ -291,7 +291,7 @@ def load_list_of_tasks(ds, play, block=None, role=None, task_include=None, use_h def load_list_of_roles(ds, play, current_role_path=None, variable_manager=None, loader=None, collection_search_list=None): """ - Loads and returns a list of RoleInclude objects from the ds list of role definitions + Loads and returns a list of RoleDefinition objects from the ds list of role definitions :param ds: list of roles to load :param play: calling Play object :param current_role_path: path of the owning role, if any @@ -301,15 +301,15 @@ def load_list_of_roles(ds, play, current_role_path=None, variable_manager=None, :return: """ # we import here to prevent a circular dependency with imports - from ansible.playbook.role.include import RoleInclude + from ansible.playbook.role.definition import RoleDefinition if not isinstance(ds, list): raise AnsibleAssertionError('ds (%s) should be a list but was a %s' % (ds, type(ds))) roles = [] for role_def in ds: - i = RoleInclude.load(role_def, play=play, current_role_path=current_role_path, variable_manager=variable_manager, - loader=loader, collection_list=collection_search_list) + i = RoleDefinition.load(role_def, play=play, current_role_path=current_role_path, variable_manager=variable_manager, + loader=loader, collection_list=collection_search_list) roles.append(i) return roles diff --git a/lib/ansible/playbook/loop_control.py b/lib/ansible/playbook/loop_control.py index c8e9af0e231..805b271e2c7 100644 --- a/lib/ansible/playbook/loop_control.py +++ b/lib/ansible/playbook/loop_control.py @@ -31,9 +31,6 @@ class LoopControl(FieldAttributeBase): extended_allitems = NonInheritableFieldAttribute(isa='bool', default=True, always_post_validate=True) break_when = NonInheritableFieldAttribute(isa='list', default=list) - def __init__(self): - super(LoopControl, self).__init__() - @staticmethod def load(data, variable_manager=None, loader=None): t = LoopControl() diff --git a/lib/ansible/playbook/play.py b/lib/ansible/playbook/play.py index 61592cd916e..426f7fd86a3 100644 --- a/lib/ansible/playbook/play.py +++ b/lib/ansible/playbook/play.py @@ -35,12 +35,9 @@ from ansible.playbook.role import Role from ansible.playbook.task import Task from ansible.playbook.taggable import Taggable from ansible.parsing.vault import EncryptedString -from ansible.utils.display import Display from ansible._internal._templating._engine import TemplateEngine as _TE -display = Display() - __all__ = ['Play'] @@ -73,7 +70,7 @@ class Play(Base, Taggable, CollectionSearch): validate_argspec = NonInheritableFieldAttribute(isa='string', always_post_validate=True) # Role Attributes - roles = NonInheritableFieldAttribute(isa='list', default=list, priority=90) + roles = NonInheritableFieldAttribute(isa='list', default=list, priority=-1) # Block (Task) Lists Attributes handlers = NonInheritableFieldAttribute(isa='list', default=list, priority=-1) @@ -156,27 +153,11 @@ class Play(Base, Taggable, CollectionSearch): """ Adjusts play datastructure to cleanup old/legacy items """ - if not isinstance(ds, dict): raise AnsibleAssertionError('while preprocessing data (%s), ds should be a dict but was a %s' % (ds, type(ds))) - # The use of 'user' in the Play datastructure was deprecated to - # line up with the same change for Tasks, due to the fact that - # 'user' conflicted with the user module. - if 'user' in ds: - # this should never happen, but error out with a helpful message - # to the user if it does... - if 'remote_user' in ds: - raise AnsibleParserError("both 'user' and 'remote_user' are set for this play. " - "The use of 'user' is deprecated, and should be removed", obj=ds) - - ds['remote_user'] = ds['user'] - del ds['user'] - return super(Play, self).preprocess_data(ds) - # DTFIX-FUTURE: these do nothing but augment the exception message; DRY and nuke - def _load_tasks(self, attr, ds): """ Loads a list of blocks from a list which may be mixed tasks/blocks. @@ -185,27 +166,9 @@ class Play(Base, Taggable, CollectionSearch): try: return load_list_of_blocks(ds=ds, play=self, variable_manager=self._variable_manager, loader=self._loader) except AssertionError as ex: - raise AnsibleParserError("A malformed block was encountered while loading tasks.", obj=self._ds) from ex + raise AnsibleParserError(f"A malformed block was encountered while loading {attr}.", obj=self._ds) from ex - def _load_pre_tasks(self, attr, ds): - """ - Loads a list of blocks from a list which may be mixed tasks/blocks. - Bare tasks outside of a block are given an implicit block. - """ - try: - return load_list_of_blocks(ds=ds, play=self, variable_manager=self._variable_manager, loader=self._loader) - except AssertionError as ex: - raise AnsibleParserError("A malformed block was encountered while loading pre_tasks.", obj=self._ds) from ex - - def _load_post_tasks(self, attr, ds): - """ - Loads a list of blocks from a list which may be mixed tasks/blocks. - Bare tasks outside of a block are given an implicit block. - """ - try: - return load_list_of_blocks(ds=ds, play=self, variable_manager=self._variable_manager, loader=self._loader) - except AssertionError as ex: - raise AnsibleParserError("A malformed block was encountered while loading post_tasks.", obj=self._ds) from ex + _load_pre_tasks = _load_post_tasks = _load_tasks def _load_handlers(self, attr, ds): """ @@ -213,17 +176,13 @@ class Play(Base, Taggable, CollectionSearch): Bare handlers outside of a block are given an implicit block. """ try: - return self._extend_value( - self.handlers, - load_list_of_blocks(ds=ds, play=self, use_handlers=True, variable_manager=self._variable_manager, loader=self._loader), - prepend=True - ) + return load_list_of_blocks(ds=ds, play=self, use_handlers=True, variable_manager=self._variable_manager, loader=self._loader) + self.handlers except AssertionError as ex: raise AnsibleParserError("A malformed block was encountered while loading handlers.", obj=self._ds) from ex def _load_roles(self, attr, ds): """ - Loads and returns a list of RoleInclude objects from the datastructure + Loads and returns a list of RoleDefinition objects from the datastructure list of role definitions and creates the Role from those objects """ @@ -382,21 +341,6 @@ class Play(Base, Taggable, CollectionSearch): return [self.vars_files] return self.vars_files - def get_handlers(self): - return self.handlers[:] - - def get_roles(self): - return self.roles[:] - - def get_tasks(self): - tasklist = [] - for task in self.pre_tasks + self.tasks + self.post_tasks: - if isinstance(task, Block): - tasklist.append(task.block + task.rescue + task.always) - else: - tasklist.append(task) - return tasklist - def copy(self): new_me = super(Play, self).copy() new_me.role_cache = self.role_cache.copy() diff --git a/lib/ansible/playbook/role/__init__.py b/lib/ansible/playbook/role/__init__.py index ab79c55765b..a3ade809824 100644 --- a/lib/ansible/playbook/role/__init__.py +++ b/lib/ansible/playbook/role/__init__.py @@ -41,12 +41,6 @@ from ansible.utils.display import Display from ansible.utils.path import is_subpath from ansible.utils.vars import combine_vars -# NOTE: This import is only needed for the type-checking in __init__. While there's an alternative -# available by using forward references this seems not to work well with commonly used IDEs. -# Therefore the TYPE_CHECKING hack seems to be a more universal approach, even if not being very elegant. -# References: -# * https://stackoverflow.com/q/39740632/199513 -# * https://peps.python.org/pep-0484/#forward-references if _t.TYPE_CHECKING: from ansible.playbook.block import Block from ansible.playbook.play import Play @@ -155,6 +149,8 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable): self._completed: dict[str, bool] = dict() self._should_validate: bool = validate + self._dep_chain: list[Role] | None = None + if from_files is None: from_files = {} self._from_files: dict[str, list[str]] = from_files @@ -206,7 +202,7 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable): return self._get_hash_dict() == other._get_hash_dict() @staticmethod - def load(role_include, play, parent_role=None, from_files=None, from_include=False, validate=True, public=None, static=True, rescuable=True): + def load(role_definition, play, parent_role=None, from_files=None, from_include=False, validate=True, public=None, static=True, rescuable=True): if from_files is None: from_files = {} try: @@ -215,7 +211,7 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable): # that role?) # see https://github.com/ansible/ansible/issues/61527 r = Role(play=play, from_files=from_files, from_include=from_include, validate=validate, public=public, static=static, rescuable=rescuable) - r._load_role_data(role_include, parent_role=parent_role) + r._load_role_data(role_definition, parent_role=parent_role) role_path = r.get_role_path() if role_path not in play.role_cache: @@ -230,23 +226,23 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable): except RecursionError as ex: raise AnsibleError("A recursion loop was detected with the roles specified. Make sure child roles do not have dependencies on parent roles", - obj=role_include._ds) from ex + obj=role_definition._ds) from ex - def _load_role_data(self, role_include, parent_role=None): - self._role_name = role_include.role - self._role_path = role_include.get_role_path() - self._role_collection = role_include._role_collection - self._role_params = role_include.get_role_params() - self._variable_manager = role_include.get_variable_manager() - self._loader = role_include.get_loader() + def _load_role_data(self, role_definition, parent_role=None): + self._role_name = role_definition.role + self._role_path = role_definition.get_role_path() + self._role_collection = role_definition._role_collection + self._role_params = role_definition.get_role_params() + self._variable_manager = role_definition.get_variable_manager() + self._loader = role_definition.get_loader() if parent_role: self.add_parent(parent_role) - # copy over all field attributes from the RoleInclude + # copy over all field attributes from the RoleDefinition # update self._attr directly, to avoid squashing for attr_name in self.fattributes: - setattr(self, f'_{attr_name}', getattr(role_include, f'_{attr_name}', Sentinel)) + setattr(self, f'_{attr_name}', getattr(role_definition, f'_{attr_name}', Sentinel)) # vars and default vars are regular dictionaries self._role_vars = self._load_role_yaml('vars', main=self._from_files.get('vars'), allow_dir=True) @@ -466,14 +462,12 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable): """ deps = [] - for role_include in self._metadata.dependencies: - r = Role.load(role_include, play=self._play, parent_role=self, static=self.static) + for role_definition in self._metadata.dependencies: + r = Role.load(role_definition, play=self._play, parent_role=self, static=self.static) deps.append(r) return deps - # other functions - def add_parent(self, parent_role): """ adds a role to the list of this roles parents """ if not isinstance(parent_role, Role): @@ -485,12 +479,15 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable): def get_parents(self): return self._parents - def get_dep_chain(self): - dep_chain = [] - for parent in self._parents: - dep_chain.extend(parent.get_dep_chain()) - dep_chain.append(parent) - return dep_chain + def get_dep_chain(self) -> list[Role]: + """Returns a copy of the parent chain list.""" + if self._dep_chain is None: + dep_chain = [] + for parent in self._parents: + dep_chain.extend(parent.get_dep_chain()) + dep_chain.append(parent) + self._dep_chain = dep_chain + return self._dep_chain[:] def get_default_vars(self, dep_chain=None): dep_chain = [] if dep_chain is None else dep_chain @@ -529,8 +526,6 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable): def get_vars(self, dep_chain=None, include_params=True, only_exports=False): dep_chain = [] if dep_chain is None else dep_chain - all_vars = {} - # get role_vars: from parent objects # TODO: is this right precedence for inherited role_vars? all_vars = self.get_inherited_vars(dep_chain, only_exports=only_exports) @@ -586,23 +581,17 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable): # # ``get_handler_blocks`` may be called when handling ``import_role`` during parsing # as well as with ``Play.compile_roles_handlers`` from ``TaskExecutor`` + # FIXME deprecate unused dep_chain parameter if self._compiled_handler_blocks: return self._compiled_handler_blocks self._compiled_handler_blocks = block_list = [] - # update the dependency chain here - if dep_chain is None: - dep_chain = [] - new_dep_chain = dep_chain + [self] - for dep in self.get_direct_dependencies(): - dep_blocks = dep.get_handler_blocks(play=play, dep_chain=new_dep_chain) - block_list.extend(dep_blocks) + block_list.extend(dep.get_handler_blocks(play=play)) for task_block in self._handler_blocks: new_task_block = task_block.copy() - new_task_block._dep_chain = new_dep_chain new_task_block._play = play block_list.append(new_task_block) @@ -626,24 +615,18 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable): with each task, so tasks know by which route they were found, and can correctly take their parent's tags/conditionals into account. """ + # FIXME deprecate unused dep_chain parameter from ansible.playbook.block import Block from ansible.playbook.task import Task block_list = [] - # update the dependency chain here - if dep_chain is None: - dep_chain = [] - new_dep_chain = dep_chain + [self] - - deps = self.get_direct_dependencies() - for dep in deps: - dep_blocks = dep.compile(play=play, dep_chain=new_dep_chain) + for dep in self.get_direct_dependencies(): + dep_blocks = dep.compile(play=play) block_list.extend(dep_blocks) for task_block in self._task_blocks: new_task_block = task_block.copy() - new_task_block._dep_chain = new_dep_chain new_task_block._play = play block_list.append(new_task_block) diff --git a/lib/ansible/playbook/role/definition.py b/lib/ansible/playbook/role/definition.py index 017344062eb..2a14058ba18 100644 --- a/lib/ansible/playbook/role/definition.py +++ b/lib/ansible/playbook/role/definition.py @@ -20,12 +20,13 @@ from __future__ import annotations import os from ansible import constants as C -from ansible.errors import AnsibleError, AnsibleAssertionError +from ansible.errors import AnsibleError, AnsibleAssertionError, AnsibleParserError from ansible.module_utils._internal._datatag import AnsibleTagHelper from ansible.playbook.attribute import NonInheritableFieldAttribute from ansible.playbook.base import Base from ansible.playbook.collectionsearch import CollectionSearch from ansible.playbook.conditional import Conditional +from ansible.playbook.delegatable import Delegatable from ansible.playbook.taggable import Taggable from ansible._internal._templating._engine import TemplateEngine from ansible.utils.collection_loader import AnsibleCollectionRef @@ -38,13 +39,12 @@ __all__ = ['RoleDefinition'] display = Display() -class RoleDefinition(Base, Conditional, Taggable, CollectionSearch): +class RoleDefinition(Base, Conditional, Taggable, Delegatable, CollectionSearch): role = NonInheritableFieldAttribute(isa='string') def __init__(self, play=None, role_basedir=None, variable_manager=None, loader=None, collection_list=None): - - super(RoleDefinition, self).__init__() + super().__init__() self._play = play self._variable_manager = variable_manager @@ -56,12 +56,16 @@ class RoleDefinition(Base, Conditional, Taggable, CollectionSearch): self._role_params = dict() self._collection_list = collection_list - # def __repr__(self): - # return 'ROLEDEF: ' + self._attributes.get('role', '') - @staticmethod - def load(data, variable_manager=None, loader=None): - raise AnsibleError("not implemented") + def load(data, play, current_role_path=None, parent_role=None, variable_manager=None, loader=None, collection_list=None): + if not (isinstance(data, str) or isinstance(data, dict)): + raise AnsibleParserError("Invalid role definition.", obj=data) + + if isinstance(data, str) and ',' in data: + raise AnsibleError("Invalid old style role requirement: %s" % data) + + rd = RoleDefinition(play=play, role_basedir=current_role_path, variable_manager=variable_manager, loader=loader, collection_list=collection_list) + return rd.load_data(data, variable_manager=variable_manager, loader=loader) def preprocess_data(self, ds): # role names that are simply numbers can be parsed by PyYAML diff --git a/lib/ansible/playbook/role/include.py b/lib/ansible/playbook/role/include.py deleted file mode 100644 index a9eaeb9f12f..00000000000 --- a/lib/ansible/playbook/role/include.py +++ /dev/null @@ -1,49 +0,0 @@ -# (c) 2014 Michael DeHaan, -# -# This file is part of Ansible -# -# Ansible is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# Ansible is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with Ansible. If not, see . - -from __future__ import annotations - -from ansible.errors import AnsibleError, AnsibleParserError -from ansible.playbook.delegatable import Delegatable -from ansible.playbook.role.definition import RoleDefinition - - -__all__ = ['RoleInclude'] - - -class RoleInclude(RoleDefinition, Delegatable): - - """ - A derivative of RoleDefinition, used by playbook code when a role - is included for execution in a play. - """ - - def __init__(self, play=None, role_basedir=None, variable_manager=None, loader=None, collection_list=None): - super(RoleInclude, self).__init__(play=play, role_basedir=role_basedir, variable_manager=variable_manager, - loader=loader, collection_list=collection_list) - - @staticmethod - def load(data, play, current_role_path=None, parent_role=None, variable_manager=None, loader=None, collection_list=None): - - if not (isinstance(data, str) or isinstance(data, dict)): - raise AnsibleParserError("Invalid role definition.", obj=data) - - if isinstance(data, str) and ',' in data: - raise AnsibleError("Invalid old style role requirement: %s" % data) - - ri = RoleInclude(play=play, role_basedir=current_role_path, variable_manager=variable_manager, loader=loader, collection_list=collection_list) - return ri.load_data(data, variable_manager=variable_manager, loader=loader) diff --git a/lib/ansible/playbook/role/metadata.py b/lib/ansible/playbook/role/metadata.py index e4567d7269d..c856954fcc1 100644 --- a/lib/ansible/playbook/role/metadata.py +++ b/lib/ansible/playbook/role/metadata.py @@ -21,7 +21,7 @@ import os from ansible.errors import AnsibleParserError, AnsibleError from ansible.playbook.attribute import NonInheritableFieldAttribute -from ansible.playbook.base import Base +from ansible.playbook.base import FieldAttributeBase from ansible.playbook.collectionsearch import CollectionSearch from ansible.playbook.helpers import load_list_of_roles from ansible.playbook.role.requirement import RoleRequirement @@ -29,7 +29,7 @@ from ansible.playbook.role.requirement import RoleRequirement __all__ = ['RoleMetadata'] -class RoleMetadata(Base, CollectionSearch): +class RoleMetadata(FieldAttributeBase, CollectionSearch): """ This class wraps the parsing and validation of the optional metadata within each Role (meta/main.yml). @@ -59,7 +59,7 @@ class RoleMetadata(Base, CollectionSearch): def _load_dependencies(self, attr, ds): """ This is a helper loading function for the dependencies list, - which returns a list of RoleInclude objects + which returns a list of RoleDefinition objects """ roles = [] diff --git a/lib/ansible/playbook/role/requirement.py b/lib/ansible/playbook/role/requirement.py index 716ad51b233..9acbb5807d7 100644 --- a/lib/ansible/playbook/role/requirement.py +++ b/lib/ansible/playbook/role/requirement.py @@ -18,8 +18,6 @@ from __future__ import annotations from ansible.errors import AnsibleError -from ansible.playbook.role.definition import RoleDefinition -from ansible.utils.display import Display from ansible.utils.galaxy import scm_archive_resource __all__ = ['RoleRequirement'] @@ -32,19 +30,13 @@ VALID_SPEC_KEYS = [ 'version', ] -display = Display() - - -class RoleRequirement(RoleDefinition): +class RoleRequirement: """ Helper class for Galaxy, which is used to parse both dependencies specified in meta/main.yml and requirements.yml files. """ - def __init__(self): - pass - @staticmethod def repo_url_to_role_name(repo_url): # gets the role name out of a repo like diff --git a/lib/ansible/playbook/role_include.py b/lib/ansible/playbook/role_include.py index 3c44ed340ab..91382d5f58a 100644 --- a/lib/ansible/playbook/role_include.py +++ b/lib/ansible/playbook/role_include.py @@ -21,14 +21,11 @@ from ansible.errors import AnsibleError, AnsibleParserError from ansible.playbook.attribute import NonInheritableFieldAttribute from ansible.playbook.task_include import TaskInclude from ansible.playbook.role import Role -from ansible.playbook.role.include import RoleInclude -from ansible.utils.display import Display +from ansible.playbook.role.definition import RoleDefinition from ansible._internal._templating._engine import TemplateEngine __all__ = ['IncludeRole'] -display = Display() - class IncludeRole(TaskInclude): @@ -59,6 +56,11 @@ class IncludeRole(TaskInclude): self._parent_role = role self._role_name = None self._role_path = None + self.statically_loaded = False + + @property + def _post_validate_object(self): + return not self.statically_loaded def get_name(self): """ return the name of the task """ @@ -73,13 +75,13 @@ class IncludeRole(TaskInclude): myplay = play try: - ri = RoleInclude.load(self._role_name, play=myplay, variable_manager=variable_manager, loader=loader, collection_list=self.collections) + rd = RoleDefinition.load(self._role_name, play=myplay, variable_manager=variable_manager, loader=loader, collection_list=self.collections) except AnsibleError as e: if not self.rescuable: raise AnsibleParserError("Could not include role.") from e raise - ri.vars |= self.vars + rd.vars |= self.vars if variable_manager is not None: available_variables = variable_manager.get_vars(play=myplay, task=self) @@ -89,7 +91,7 @@ class IncludeRole(TaskInclude): from_files = templar.template(self._from_files) # build role - actual_role = Role.load(ri, myplay, parent_role=self._parent_role, from_files=from_files, from_include=True, + actual_role = Role.load(rd, myplay, parent_role=self._parent_role, from_files=from_files, from_include=True, validate=self.rolespec_validate, public=self.public, static=self.statically_loaded, rescuable=self.rescuable) actual_role._metadata.allow_duplicates = self.allow_duplicates @@ -99,23 +101,19 @@ class IncludeRole(TaskInclude): # save this for later use self._role_path = actual_role._role_path - # compile role with parent roles as dependencies to ensure they inherit - # variables - dep_chain = actual_role.get_dep_chain() - p_block = self.build_parent_block() # collections value is not inherited; override with the value we calculated during role setup p_block.collections = actual_role.collections - blocks = actual_role.compile(play=myplay, dep_chain=dep_chain) + blocks = actual_role.compile(play=myplay) for b in blocks: b._parent = p_block # HACK: parent inheritance doesn't seem to have a way to handle this intermediate override until squashed/finalized b.collections = actual_role.collections # updated available handlers in play - handlers = actual_role.get_handler_blocks(play=myplay, dep_chain=dep_chain) + handlers = actual_role.get_handler_blocks(play=myplay) for h in handlers: h._parent = p_block myplay.handlers = myplay.handlers + handlers diff --git a/lib/ansible/playbook/taggable.py b/lib/ansible/playbook/taggable.py index 98d91cac65e..3b3256264a1 100644 --- a/lib/ansible/playbook/taggable.py +++ b/lib/ansible/playbook/taggable.py @@ -71,7 +71,7 @@ class Taggable: obj = obj._parent - yield self.get_play() + yield self.play def evaluate_tags(self, only_tags, skip_tags, all_vars): """Check if the current item should be executed depending on the specified tags. diff --git a/lib/ansible/playbook/task.py b/lib/ansible/playbook/task.py index 0876c0fa0d6..ed58c3a1f68 100644 --- a/lib/ansible/playbook/task.py +++ b/lib/ansible/playbook/task.py @@ -20,7 +20,6 @@ from __future__ import annotations import typing as t from ansible import constants as C -from ansible.module_utils.common.sentinel import Sentinel from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleAssertionError, AnsibleValueOmittedError from ansible.executor.module_common import _get_action_arg_defaults from ansible.module_utils.common.text.converters import to_native @@ -30,7 +29,6 @@ from ansible.plugins.action import ActionBase from ansible.plugins.loader import action_loader, module_loader, lookup_loader from ansible.playbook.attribute import NonInheritableFieldAttribute from ansible.playbook.base import Base -from ansible.playbook.block import Block from ansible.playbook.collectionsearch import CollectionSearch from ansible.playbook.conditional import Conditional from ansible.playbook.delegatable import Delegatable @@ -97,7 +95,6 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl """ constructors a task, without the Task.load classmethod, it will be pretty blank """ self._role = role - self._parent = None self.implicit = False self._resolved_action: str | None = None @@ -156,20 +153,6 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl else: return "%s" % (self.action,) - def _merge_kv(self, ds): - if ds is None: - return "" - elif isinstance(ds, str): - return ds - elif isinstance(ds, dict): - buf = "" - for (k, v) in ds.items(): - if k.startswith('_'): - continue - buf = buf + "%s=%s " % (k, v) - buf = buf.strip() - return buf - @staticmethod def load(data, block=None, role=None, task_include=None, variable_manager=None, loader=None): task = Task(block=block, role=role, task_include=task_include) @@ -283,7 +266,7 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl else: # Validate this untemplated field early on to guarantee we are dealing with a list. # This is also done in CollectionSearch._load_collections() but this runs before that call. - collections_list = self.get_validated_value('collections', self.fattributes.get('collections'), collections_list, None) + collections_list = self.get_validated_value('collections', collections_list, None) if default_collection and not self._role: # FIXME: and not a collections role if collections_list: @@ -375,17 +358,6 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl except Exception as ex: raise AnsibleParserError("Invalid 'register' specified.", obj=value) from ex - def post_validate(self, templar): - """ - Override of base class post_validate, to also do final validation on - the block and task include (if any) to which this task belongs. - """ - - if self._parent: - self._parent.post_validate(templar) - - super(Task, self).post_validate(templar) - def _post_validate_loop(self, attr, value, templar): """ Override post validation for the loop field, which is templated @@ -425,7 +397,6 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl return raise - # NB: the environment FieldAttribute definition ensures that value is always a list for env_item in value: if isinstance(env_item, dict): for k in env_item: @@ -471,10 +442,8 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl all_vars |= self.vars - if 'tags' in all_vars: - del all_vars['tags'] - if 'when' in all_vars: - del all_vars['when'] + all_vars.pop('tags', None) + all_vars.pop('when', None) return all_vars @@ -482,8 +451,6 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl all_vars = dict() if self._parent: all_vars |= self._parent.get_include_params() - if self.action in C._ACTION_ALL_INCLUDES: - all_vars |= self.vars return all_vars def copy(self, exclude_parent: bool = False, exclude_tasks: bool = False) -> Task: @@ -493,13 +460,10 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl if self._parent and not exclude_parent: new_me._parent = self._parent.copy(exclude_tasks=exclude_tasks) - new_me._role = None - if self._role: - new_me._role = self._role + new_me._role = self._role new_me.implicit = self.implicit new_me._resolved_action = self._resolved_action - new_me._uuid = self._uuid return new_me @@ -515,51 +479,6 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl if self._parent: self._parent.set_loader(loader) - def _get_parent_attribute(self, attr, omit=False): - """ - Generic logic to get the attribute or parent attribute for a task value. - """ - fattr = self.fattributes[attr] - - extend = fattr.extend - prepend = fattr.prepend - - try: - # omit self, and only get parent values - if omit: - value = Sentinel - else: - value = getattr(self, f'_{attr}', Sentinel) - - # If parent is static, we can grab attrs from the parent - # otherwise, defer to the grandparent - if getattr(self._parent, 'statically_loaded', True): - _parent = self._parent - else: - _parent = self._parent._parent - - if _parent and (value is Sentinel or extend): - if getattr(_parent, 'statically_loaded', True): - # vars are always inheritable, other attributes might not be for the parent but still should be for other ancestors - if attr != 'vars' and hasattr(_parent, '_get_parent_attribute'): - parent_value = _parent._get_parent_attribute(attr) - else: - parent_value = getattr(_parent, f'_{attr}', Sentinel) - - if extend: - value = self._extend_value(value, parent_value, prepend) - else: - value = parent_value - except KeyError: - pass - - return value - - def all_parents_static(self): - if self._parent: - return self._parent.all_parents_static() - return True - def get_first_parent_include(self): from ansible.playbook.task_include import TaskInclude if self._parent: @@ -568,12 +487,6 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl return self._parent.get_first_parent_include() return None - def get_play(self): - parent = self._parent - while not isinstance(parent, Block): - parent = parent._parent - return parent._play - def dump_attrs(self): """Override to smuggle important non-FieldAttribute values back to the controller.""" attrs = super().dump_attrs() @@ -585,10 +498,9 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl # from_attrs is only used to create a finalized task # from attrs from the Worker/TaskExecutor - # Those attrs are finalized and squashed in the TE + # Those attrs are finalized in the TE # and controller side use needs to reflect that self._finalized = True - self._squashed = True def _resolve_conditional( self, diff --git a/lib/ansible/playbook/task_include.py b/lib/ansible/playbook/task_include.py index 4bb0f2114a6..3b6c7786598 100644 --- a/lib/ansible/playbook/task_include.py +++ b/lib/ansible/playbook/task_include.py @@ -123,3 +123,8 @@ class TaskInclude(Task): p_block = self return p_block + + def get_include_params(self): + v = super().get_include_params() + v |= self.vars + return v diff --git a/lib/ansible/vars/manager.py b/lib/ansible/vars/manager.py index fb4970cd749..8d083b9cc19 100644 --- a/lib/ansible/vars/manager.py +++ b/lib/ansible/vars/manager.py @@ -542,7 +542,7 @@ class VariableManager: delegated_vars['ansible_delegated_vars'] = { delegated_host_name: self.get_vars( - play=task.get_play(), + play=task.play, host=delegated_host, task=task, include_hostvars=True, diff --git a/test/integration/targets/playbook/remote_user_and_user.yml b/test/integration/targets/playbook/remote_user_and_user.yml deleted file mode 100644 index c9e2389dfe5..00000000000 --- a/test/integration/targets/playbook/remote_user_and_user.yml +++ /dev/null @@ -1,6 +0,0 @@ -- hosts: localhost - remote_user: a - user: b - tasks: - - debug: - msg: did not run diff --git a/test/integration/targets/playbook/runme.sh b/test/integration/targets/playbook/runme.sh index bf4f1769b54..7eebb3d68cc 100755 --- a/test/integration/targets/playbook/runme.sh +++ b/test/integration/targets/playbook/runme.sh @@ -8,20 +8,10 @@ ansible-playbook -i ../../inventory types.yml -v "$@" # test timeout ansible-playbook -i ../../inventory timeout.yml -v "$@" -# our Play class allows for 'user' or 'remote_user', but not both. -# first test that both user and remote_user work individually set +e result="$(ansible-playbook -i ../../inventory user.yml -v "$@" 2>&1)" set -e -grep -q "worked with user" <<< "$result" -grep -q "worked with remote_user" <<< "$result" - -# then test that the play errors if user and remote_user both exist -echo "EXPECTED ERROR: Ensure we fail properly if a play has both user and remote_user." -set +e -result="$(ansible-playbook -i ../../inventory remote_user_and_user.yml -v "$@" 2>&1)" -set -e -grep -q "both 'user' and 'remote_user' are set for this play." <<< "$result" +grep -q "is not a valid attribute for a Play" <<< "$result" # test that playbook errors if len(plays) == 0 echo "EXPECTED ERROR: Ensure we fail properly if a playbook is an empty list." diff --git a/test/integration/targets/playbook/user.yml b/test/integration/targets/playbook/user.yml index 8b4029b82d3..510d4198274 100644 --- a/test/integration/targets/playbook/user.yml +++ b/test/integration/targets/playbook/user.yml @@ -13,11 +13,6 @@ - hosts: localhost user: "{{ me }}" tasks: - - debug: + - name: should not happen + debug: msg: worked with user ({{ me }}) - -- hosts: localhost - remote_user: "{{ me }}" - tasks: - - debug: - msg: worked with remote_user ({{ me }}) diff --git a/test/sanity/ignore.txt b/test/sanity/ignore.txt index bb296a812d9..c37f7c26270 100644 --- a/test/sanity/ignore.txt +++ b/test/sanity/ignore.txt @@ -50,7 +50,6 @@ lib/ansible/module_utils/six/__init__.py pylint:trailing-comma-tuple lib/ansible/module_utils/six/__init__.py pylint:unidiomatic-typecheck lib/ansible/module_utils/six/__init__.py replace-urlopen lib/ansible/module_utils/urls.py replace-urlopen -lib/ansible/playbook/role/include.py pylint:arguments-renamed lib/ansible/plugins/action/normal.py action-plugin-docs # default action plugin for modules without a dedicated action plugin lib/ansible/plugins/cache/base.py ansible-doc!skip # not a plugin, but a stub for backwards compatibility lib/ansible/plugins/callback/__init__.py pylint:arguments-renamed diff --git a/test/units/playbook/role/test_role.py b/test/units/playbook/role/test_role.py index cbfe776357e..1e9681104a9 100644 --- a/test/units/playbook/role/test_role.py +++ b/test/units/playbook/role/test_role.py @@ -31,7 +31,7 @@ from units.mock.loader import DictDataLoader from units.mock.path import mock_unfrackpath_noop from ansible.playbook.role import Role -from ansible.playbook.role.include import RoleInclude +from ansible.playbook.role.definition import RoleDefinition from ansible.playbook.role import hash_params @@ -168,7 +168,7 @@ class TestRole(unittest.TestCase): mock_play = MagicMock() mock_play.role_cache = {} - i = RoleInclude.load('foo_tasks', play=mock_play, loader=fake_loader) + i = RoleDefinition.load('foo_tasks', play=mock_play, loader=fake_loader) r = Role.load(i, play=mock_play) self.assertEqual(str(r), 'foo_tasks') @@ -190,7 +190,7 @@ class TestRole(unittest.TestCase): mock_play = MagicMock() mock_play.role_cache = {} - i = RoleInclude.load('foo_tasks', play=mock_play, loader=fake_loader) + i = RoleDefinition.load('foo_tasks', play=mock_play, loader=fake_loader) r = Role.load(i, play=mock_play, from_files=dict(tasks='custom_main')) self.assertEqual(r._task_blocks[0]._ds[0]['command'], 'baz') @@ -208,7 +208,7 @@ class TestRole(unittest.TestCase): mock_play = MagicMock() mock_play.role_cache = {} - i = RoleInclude.load('foo_handlers', play=mock_play, loader=fake_loader) + i = RoleDefinition.load('foo_handlers', play=mock_play, loader=fake_loader) r = Role.load(i, play=mock_play) self.assertEqual(len(r._handler_blocks), 1) @@ -229,7 +229,7 @@ class TestRole(unittest.TestCase): mock_play = MagicMock() mock_play.role_cache = {} - i = RoleInclude.load('foo_vars', play=mock_play, loader=fake_loader) + i = RoleDefinition.load('foo_vars', play=mock_play, loader=fake_loader) r = Role.load(i, play=mock_play) self.assertEqual(r._default_vars, dict(foo='bar')) @@ -250,7 +250,7 @@ class TestRole(unittest.TestCase): mock_play = MagicMock() mock_play.role_cache = {} - i = RoleInclude.load('foo_vars', play=mock_play, loader=fake_loader) + i = RoleDefinition.load('foo_vars', play=mock_play, loader=fake_loader) r = Role.load(i, play=mock_play) self.assertEqual(r._default_vars, dict(foo='bar')) @@ -271,7 +271,7 @@ class TestRole(unittest.TestCase): mock_play = MagicMock() mock_play.role_cache = {} - i = RoleInclude.load('foo_vars', play=mock_play, loader=fake_loader) + i = RoleDefinition.load('foo_vars', play=mock_play, loader=fake_loader) r = Role.load(i, play=mock_play) self.assertEqual(r._default_vars, dict(foo='bar')) @@ -294,7 +294,7 @@ class TestRole(unittest.TestCase): mock_play = MagicMock() mock_play.role_cache = {} - i = RoleInclude.load('foo_vars', play=mock_play, loader=fake_loader) + i = RoleDefinition.load('foo_vars', play=mock_play, loader=fake_loader) r = Role.load(i, play=mock_play) self.assertEqual(r._default_vars, dict(foo='bar', a=1, b=2)) @@ -314,7 +314,7 @@ class TestRole(unittest.TestCase): mock_play = MagicMock() mock_play.role_cache = {} - i = RoleInclude.load('foo_vars', play=mock_play, loader=fake_loader) + i = RoleDefinition.load('foo_vars', play=mock_play, loader=fake_loader) r = Role.load(i, play=mock_play) self.assertEqual(r._role_vars, dict(foo='bam')) @@ -361,7 +361,7 @@ class TestRole(unittest.TestCase): mock_play.collections = None mock_play.role_cache = {} - i = RoleInclude.load('foo_metadata', play=mock_play, loader=fake_loader) + i = RoleDefinition.load('foo_metadata', play=mock_play, loader=fake_loader) r = Role.load(i, play=mock_play) role_deps = r.get_direct_dependencies() @@ -379,16 +379,16 @@ class TestRole(unittest.TestCase): self.assertEqual(all_deps[1].get_name(), 'baz_metadata') self.assertEqual(all_deps[2].get_name(), 'bar_metadata') - i = RoleInclude.load('bad1_metadata', play=mock_play, loader=fake_loader) + i = RoleDefinition.load('bad1_metadata', play=mock_play, loader=fake_loader) self.assertRaises(AnsibleParserError, Role.load, i, play=mock_play) - i = RoleInclude.load('bad2_metadata', play=mock_play, loader=fake_loader) + i = RoleDefinition.load('bad2_metadata', play=mock_play, loader=fake_loader) self.assertRaises(AnsibleParserError, Role.load, i, play=mock_play) # TODO: re-enable this test once Ansible has proper role dep cycle detection # that doesn't rely on stack overflows being recoverable (as they aren't in Py3.7+) # see https://github.com/ansible/ansible/issues/61527 - # i = RoleInclude.load('recursive1_metadata', play=mock_play, loader=fake_loader) + # i = RoleDefinition.load('recursive1_metadata', play=mock_play, loader=fake_loader) # self.assertRaises(AnsibleError, Role.load, i, play=mock_play) @patch('ansible.playbook.role.definition.unfrackpath', mock_unfrackpath_noop) @@ -406,7 +406,7 @@ class TestRole(unittest.TestCase): mock_play = MagicMock() mock_play.role_cache = {} - i = RoleInclude.load(dict(role='foo_complex'), play=mock_play, loader=fake_loader) + i = RoleDefinition.load(dict(role='foo_complex'), play=mock_play, loader=fake_loader) r = Role.load(i, play=mock_play) self.assertEqual(r.get_name(), "foo_complex") diff --git a/test/units/playbook/test_attribute.py b/test/units/playbook/test_attribute.py index 14c4807dd03..54d0bb5e713 100644 --- a/test/units/playbook/test_attribute.py +++ b/test/units/playbook/test_attribute.py @@ -27,30 +27,7 @@ class TestAttribute(unittest.TestCase): self.one = Attribute(priority=100) self.two = Attribute(priority=0) - def test_eq(self): - self.assertTrue(self.one == self.one) - self.assertFalse(self.one == self.two) - - def test_ne(self): - self.assertFalse(self.one != self.one) - self.assertTrue(self.one != self.two) - def test_lt(self): self.assertFalse(self.one < self.one) self.assertTrue(self.one < self.two) self.assertFalse(self.two < self.one) - - def test_gt(self): - self.assertFalse(self.one > self.one) - self.assertFalse(self.one > self.two) - self.assertTrue(self.two > self.one) - - def test_le(self): - self.assertTrue(self.one <= self.one) - self.assertTrue(self.one <= self.two) - self.assertFalse(self.two <= self.one) - - def test_ge(self): - self.assertTrue(self.one >= self.one) - self.assertFalse(self.one >= self.two) - self.assertTrue(self.two >= self.one) diff --git a/test/units/playbook/test_base.py b/test/units/playbook/test_base.py index a3fed9cf65f..25a90496e66 100644 --- a/test/units/playbook/test_base.py +++ b/test/units/playbook/test_base.py @@ -47,8 +47,6 @@ class TestBase(unittest.TestCase): bsc = self.ClassUnderTest() parent = ExampleParentBaseSubClass() bsc._parent = parent - bsc._dep_chain = [parent] - parent._dep_chain = None bsc.load_data(ds) fake_loader = DictDataLoader({}) templar = TemplateEngine(loader=fake_loader) @@ -110,11 +108,8 @@ class TestBase(unittest.TestCase): def test_load_data_invalid_attr_type(self): ds = {'environment': True} - - # environment is supposed to be a list. This - # seems like it shouldn't work? ret = self.b.load_data(ds) - self.assertEqual(True, ret._environment) + self.assertEqual([True], ret._environment) def test_post_validate(self): ds = {'environment': [], @@ -170,10 +165,6 @@ class TestBase(unittest.TestCase): b = self._base_validate(ds) self.assertEqual(b.vars, {}) - def test_validate_empty(self): - self.b.validate() - self.assertTrue(self.b._validated) - def test_getters(self): # not sure why these exist, but here are tests anyway loader = self.b.get_loader() @@ -182,70 +173,6 @@ class TestBase(unittest.TestCase): self.assertEqual(variable_manager, self.b._variable_manager) -class TestExtendValue(unittest.TestCase): - # _extend_value could be a module or staticmethod but since its - # not, the test is here. - def test_extend_value_list_newlist(self): - b = base.Base() - value_list = ['first', 'second'] - new_value_list = ['new_first', 'new_second'] - ret = b._extend_value(value_list, new_value_list) - self.assertEqual(value_list + new_value_list, ret) - - def test_extend_value_list_newlist_prepend(self): - b = base.Base() - value_list = ['first', 'second'] - new_value_list = ['new_first', 'new_second'] - ret_prepend = b._extend_value(value_list, new_value_list, prepend=True) - self.assertEqual(new_value_list + value_list, ret_prepend) - - def test_extend_value_newlist_list(self): - b = base.Base() - value_list = ['first', 'second'] - new_value_list = ['new_first', 'new_second'] - ret = b._extend_value(new_value_list, value_list) - self.assertEqual(new_value_list + value_list, ret) - - def test_extend_value_newlist_list_prepend(self): - b = base.Base() - value_list = ['first', 'second'] - new_value_list = ['new_first', 'new_second'] - ret = b._extend_value(new_value_list, value_list, prepend=True) - self.assertEqual(value_list + new_value_list, ret) - - def test_extend_value_string_newlist(self): - b = base.Base() - some_string = 'some string' - new_value_list = ['new_first', 'new_second'] - ret = b._extend_value(some_string, new_value_list) - self.assertEqual([some_string] + new_value_list, ret) - - def test_extend_value_string_newstring(self): - b = base.Base() - some_string = 'some string' - new_value_string = 'this is the new values' - ret = b._extend_value(some_string, new_value_string) - self.assertEqual([some_string, new_value_string], ret) - - def test_extend_value_list_newstring(self): - b = base.Base() - value_list = ['first', 'second'] - new_value_string = 'this is the new values' - ret = b._extend_value(value_list, new_value_string) - self.assertEqual(value_list + [new_value_string], ret) - - def test_extend_value_none_none(self): - b = base.Base() - ret = b._extend_value(None, None) - self.assertEqual(len(ret), 0) - self.assertFalse(ret) - - def test_extend_value_none_list(self): - b = base.Base() - ret = b._extend_value(None, ['foo']) - self.assertEqual(ret, ['foo']) - - class ExampleException(Exception): pass @@ -255,12 +182,7 @@ class ExampleParentBaseSubClass(base.Base): test_attr_parent_string = FieldAttribute(isa='string', default='A string attr for a class that may be a parent for testing') def __init__(self): - super(ExampleParentBaseSubClass, self).__init__() - self._dep_chain = None - - def get_dep_chain(self): - return self._dep_chain class ExampleSubClass(base.Base): @@ -281,7 +203,7 @@ class BaseSubClass(base.Base): test_attr_list_no_listof = FieldAttribute(isa='list', always_post_validate=True) test_attr_list_required = FieldAttribute(isa='list', listof=(str,), required=True, default=list, always_post_validate=True) - test_attr_string = FieldAttribute(isa='string', default='the_test_attr_string_default_value') + test_attr_string = FieldAttribute(isa='string', default='the_test_attr_string_default_value', always_post_validate=True) test_attr_string_required = FieldAttribute(isa='string', required=True, default='the_test_attr_string_default_value') test_attr_percent = FieldAttribute(isa='percent', always_post_validate=True) @@ -299,9 +221,6 @@ class BaseSubClass(base.Base): test_attr_method_missing = FieldAttribute(isa='string', default='some attr with a missing getter', always_post_validate=True) - def _get_attr_test_attr_method(self): - return 'foo bar' - def _validate_test_attr_example(self, attr, name, value): if not isinstance(value, str): raise ExampleException('test_attr_example is not a string: %s type=%s' % (value, type(value))) @@ -333,13 +252,6 @@ class TestBaseSubClass(TestBase): bsc = self._base_validate(ds) self.assertEqual(bsc.test_attr_int, MOST_RANDOM_NUMBER) - def test_attr_int_del(self): - MOST_RANDOM_NUMBER = 37 - ds = {'test_attr_int': MOST_RANDOM_NUMBER} - bsc = self._base_validate(ds) - del bsc.test_attr_int - self.assertNotIn('_test_attr_int', bsc.__dict__) - def test_attr_float(self): roughly_pi = 4.0 ds = {'test_attr_float': roughly_pi} @@ -446,7 +358,7 @@ class TestBaseSubClass(TestBase): def test_attr_string_invalid_list(self): ds = {'test_attr_string': ['The new test_attr_string', 'value, however in a list']} - self.assertRaises(AnsibleParserError, self._base_validate, ds) + self.assertRaises(AnsibleFieldAttributeError, self._base_validate, ds) def test_attr_string_required(self): the_string_value = "the new test_attr_string_required_value" @@ -512,12 +424,6 @@ class TestBaseSubClass(TestBase): {'test_attr_unknown_isa': True} ) - def test_attr_method(self): - ds = {'test_attr_method': 'value from the ds'} - bsc = self._base_validate(ds) - # The value returned by the subclasses _get_attr_test_attr_method - self.assertEqual(bsc.test_attr_method, 'foo bar') - def test_attr_method_missing(self): a_string = 'The value set from the ds' ds = {'test_attr_method_missing': a_string} @@ -525,10 +431,9 @@ class TestBaseSubClass(TestBase): self.assertEqual(bsc.test_attr_method_missing, a_string) def test_get_validated_value_string_preserve_tags(self): - attribute = FieldAttribute(isa='string') value = TrustedAsTemplate().tag('bar') templar = TemplateEngine(None) bsc = self.ClassUnderTest() - result = bsc.get_validated_value('foo', attribute, value, templar) + result = bsc.get_validated_value('test_attr_string', value, templar) assert TrustedAsTemplate.is_tagged_on(result) assert result == 'bar' diff --git a/test/units/playbook/test_helpers.py b/test/units/playbook/test_helpers.py index a96ce8f9231..9bee074e446 100644 --- a/test/units/playbook/test_helpers.py +++ b/test/units/playbook/test_helpers.py @@ -32,7 +32,7 @@ from ansible.playbook.block import Block from ansible.playbook.handler import Handler from ansible.playbook.task import Task from ansible.playbook.task_include import TaskInclude -from ansible.playbook.role.include import RoleInclude +from ansible.playbook.role.definition import RoleDefinition class MixinForMocks(object): @@ -317,7 +317,7 @@ class TestLoadListOfRoles(unittest.TestCase, MixinForMocks): variable_manager=self.mock_variable_manager, loader=self.fake_role_loader) self.assertIsInstance(res, list) for r in res: - self.assertIsInstance(r, RoleInclude) + self.assertIsInstance(r, RoleDefinition) def test_block_unknown_action(self): ds = [{ @@ -328,7 +328,7 @@ class TestLoadListOfRoles(unittest.TestCase, MixinForMocks): variable_manager=self.mock_variable_manager, loader=self.fake_role_loader) self.assertIsInstance(res, list) for r in res: - self.assertIsInstance(r, RoleInclude) + self.assertIsInstance(r, RoleDefinition) @pytest.mark.usefixtures('collection_loader') diff --git a/test/units/playbook/test_play.py b/test/units/playbook/test_play.py index 272744d5780..2107d55a9ea 100644 --- a/test/units/playbook/test_play.py +++ b/test/units/playbook/test_play.py @@ -50,17 +50,6 @@ def test_basic_play(): assert p.connection == 'local' -def test_play_with_remote_user(): - p = Play.load(dict( - name="test play", - hosts=['foo'], - user="testing", - gather_facts=False, - )) - - assert p.remote_user == "testing" - - def test_play_with_user_conflict(): play_data = dict( name="test play", @@ -101,7 +90,6 @@ def test_play_with_handlers(): )) assert len(p.handlers) >= 1 - assert len(p.get_handlers()) >= 1 assert isinstance(p.handlers[0], Block) assert p.handlers[0].has_tasks() is True @@ -118,9 +106,8 @@ def test_play_with_pre_tasks(): assert isinstance(p.pre_tasks[0], Block) assert p.pre_tasks[0].has_tasks() is True - assert len(p.get_tasks()) >= 1 - assert isinstance(p.get_tasks()[0][0], Task) - assert p.get_tasks()[0][0].action == 'shell' + assert isinstance(p.pre_tasks[0].block[0], Task) + assert p.pre_tasks[0].block[0].action == 'shell' def test_play_with_post_tasks(): @@ -158,7 +145,7 @@ def test_play_with_roles(mocker): blocks = p.compile() assert len(blocks) > 1 assert all(isinstance(block, Block) for block in blocks) - assert isinstance(p.get_roles()[0], Role) + assert isinstance(p.roles[0], Role) def test_play_compile(): diff --git a/test/units/playbook/test_taggable.py b/test/units/playbook/test_taggable.py index fe713ba9533..5b5c013486a 100644 --- a/test/units/playbook/test_taggable.py +++ b/test/units/playbook/test_taggable.py @@ -28,9 +28,10 @@ class TaggableTestObj(Taggable): self._loader = DictDataLoader({}) self.tags = [] self._parent = None + self.play = None - def get_play(self): - return None + def finalized(self): + return False class TestTaggable(unittest.TestCase): diff --git a/test/units/playbook/test_task.py b/test/units/playbook/test_task.py index 87c2a3a9308..09850aa880f 100644 --- a/test/units/playbook/test_task.py +++ b/test/units/playbook/test_task.py @@ -56,7 +56,7 @@ class TestTask(unittest.TestCase): p = dict(delay=delay) p.update(task_base) t = Task().load_data(p) - self.assertEqual(t.get_validated_value('delay', t.fattributes.get('delay'), delay, None), expected) + self.assertEqual(t.get_validated_value('delay', delay, None), expected) bad_params = [ 'E', @@ -69,7 +69,7 @@ class TestTask(unittest.TestCase): p.update(task_base) t = Task().load_data(p) with self.assertRaises(AnsibleError): - dummy = t.get_validated_value('delay', t.fattributes.get('delay'), delay, None) + dummy = t.get_validated_value('delay', delay, None) def test_task_auto_name_with_role(self): pass diff --git a/test/units/vars/test_variable_manager.py b/test/units/vars/test_variable_manager.py index 813767bb289..732f5512d93 100644 --- a/test/units/vars/test_variable_manager.py +++ b/test/units/vars/test_variable_manager.py @@ -84,7 +84,6 @@ class TestVariableManager(unittest.TestCase): mock_play = MagicMock() mock_play.get_vars.return_value = dict(foo="bar") - mock_play.get_roles.return_value = [] mock_play.get_vars_files.return_value = [] mock_inventory = MagicMock() @@ -103,7 +102,6 @@ class TestVariableManager(unittest.TestCase): mock_play = MagicMock() mock_play.get_vars.return_value = dict() - mock_play.get_roles.return_value = [] mock_play.get_vars_files.return_value = [__file__] mock_inventory = MagicMock() @@ -158,10 +156,25 @@ class TestVariableManager(unittest.TestCase): # and role2 depend on common-role. Check that the tasks see # different values of role_var. blocks = play1.compile() + + # compile returns the following layout of blocks + # (fact gathering is missing as that is added by PlayIterator): + # [TASK: meta (flush_handlers)] + # [TASK: common-role : debug] + # [TASK: meta (role_complete)] + # [TASK: meta (role_complete)] + # [TASK: common-role : debug] + # [TASK: meta (role_complete)] + # [TASK: meta (role_complete)] + # [TASK: meta (flush_handlers)] + # [TASK: meta (flush_handlers)] + task = blocks[1].block[0] + assert task.action == 'debug' res = v.get_vars(play=play1, task=task) - self.assertEqual(res['role_var'], 'role_var_from_role1') + assert res['role_var'] == 'role_var_from_role1' - task = blocks[2].block[0] + task = blocks[4].block[0] + assert task.action == 'debug' res = v.get_vars(play=play1, task=task) - self.assertEqual(res['role_var'], 'role_var_from_role2') + assert res['role_var'] == 'role_var_from_role2'