`FieldAttribute`s as descriptors (#73908)

pull/78170/head
Martin Krizek 3 years ago committed by GitHub
parent 4c9385dab7
commit 43153c5831
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -50,7 +50,7 @@ def extract_keywords(keyword_definitions):
# Maintain order of the actual class names for our output # Maintain order of the actual class names for our output
# Build up a mapping of playbook classes to the attributes that they hold # Build up a mapping of playbook classes to the attributes that they hold
pb_keywords[pb_class_name] = {k: v for (k, v) in playbook_class._valid_attrs.items() pb_keywords[pb_class_name] = {k: v for (k, v) in playbook_class.fattributes.items()
# Filter private attributes as they're not usable in playbooks # Filter private attributes as they're not usable in playbooks
if not v.private} if not v.private}
@ -60,7 +60,7 @@ def extract_keywords(keyword_definitions):
pb_keywords[pb_class_name][keyword] = keyword_definitions[keyword] pb_keywords[pb_class_name][keyword] = keyword_definitions[keyword]
else: else:
# check if there is an alias, otherwise undocumented # check if there is an alias, otherwise undocumented
alias = getattr(getattr(playbook_class, '_%s' % keyword), 'alias', None) alias = getattr(playbook_class.fattributes.get(keyword), 'alias', None)
if alias and alias in keyword_definitions: if alias and alias in keyword_definitions:
pb_keywords[pb_class_name][alias] = keyword_definitions[alias] pb_keywords[pb_class_name][alias] = keyword_definitions[alias]
del pb_keywords[pb_class_name][keyword] del pb_keywords[pb_class_name][keyword]

@ -590,13 +590,13 @@ class DocCLI(CLI, RoleMixin):
loaded_class = importlib.import_module(obj_class) loaded_class = importlib.import_module(obj_class)
PB_LOADED[pobj] = getattr(loaded_class, pobj, None) PB_LOADED[pobj] = getattr(loaded_class, pobj, None)
if keyword in PB_LOADED[pobj]._valid_attrs: if keyword in PB_LOADED[pobj].fattributes:
kdata['applies_to'].append(pobj) kdata['applies_to'].append(pobj)
# we should only need these once # we should only need these once
if 'type' not in kdata: if 'type' not in kdata:
fa = getattr(PB_LOADED[pobj], '_%s' % keyword) fa = PB_LOADED[pobj].fattributes.get(keyword)
if getattr(fa, 'private'): if getattr(fa, 'private'):
kdata = {} kdata = {}
raise KeyError raise KeyError

@ -114,8 +114,8 @@ class ModuleArgsParser:
from ansible.playbook.task import Task from ansible.playbook.task import Task
from ansible.playbook.handler import Handler from ansible.playbook.handler import Handler
# store the valid Task/Handler attrs for quick access # store the valid Task/Handler attrs for quick access
self._task_attrs = set(Task._valid_attrs.keys()) self._task_attrs = set(Task.fattributes)
self._task_attrs.update(set(Handler._valid_attrs.keys())) self._task_attrs.update(set(Handler.fattributes))
# HACK: why are these not FieldAttributes on task with a post-validate to check usage? # HACK: why are these not FieldAttributes on task with a post-validate to check usage?
self._task_attrs.update(['local_action', 'static']) self._task_attrs.update(['local_action', 'static'])
self._task_attrs = frozenset(self._task_attrs) self._task_attrs = frozenset(self._task_attrs)

@ -21,6 +21,7 @@ __metaclass__ = type
from copy import copy, deepcopy from copy import copy, deepcopy
from ansible.utils.sentinel import Sentinel
_CONTAINERS = frozenset(('list', 'dict', 'set')) _CONTAINERS = frozenset(('list', 'dict', 'set'))
@ -37,10 +38,7 @@ class Attribute:
priority=0, priority=0,
class_type=None, class_type=None,
always_post_validate=False, always_post_validate=False,
inherit=True,
alias=None, alias=None,
extend=False,
prepend=False,
static=False, static=False,
): ):
@ -70,9 +68,6 @@ class Attribute:
the field will be an instance of that class. the field will be an instance of that class.
:kwarg always_post_validate: Controls whether a field should be post :kwarg always_post_validate: Controls whether a field should be post
validated or not (default: False). validated or not (default: False).
:kwarg inherit: A boolean value, which controls whether the object
containing this field should attempt to inherit the value from its
parent object if the local value is None.
:kwarg alias: An alias to use for the attribute name, for situations where :kwarg alias: An alias to use for the attribute name, for situations where
the attribute name may conflict with a Python reserved word. the attribute name may conflict with a Python reserved word.
""" """
@ -85,15 +80,15 @@ class Attribute:
self.priority = priority self.priority = priority
self.class_type = class_type self.class_type = class_type
self.always_post_validate = always_post_validate self.always_post_validate = always_post_validate
self.inherit = inherit
self.alias = alias self.alias = alias
self.extend = extend
self.prepend = prepend
self.static = static self.static = static
if default is not None and self.isa in _CONTAINERS and not callable(default): if default is not None and self.isa in _CONTAINERS and not callable(default):
raise TypeError('defaults for FieldAttribute may not be mutable, please provide a callable instead') raise TypeError('defaults for FieldAttribute may not be mutable, please provide a callable instead')
def __set_name__(self, owner, name):
self.name = name
def __eq__(self, other): def __eq__(self, other):
return other.priority == self.priority return other.priority == self.priority
@ -114,6 +109,94 @@ class Attribute:
def __ge__(self, other): def __ge__(self, other):
return other.priority >= self.priority return other.priority >= self.priority
def __get__(self, obj, obj_type=None):
method = f'_get_attr_{self.name}'
if hasattr(obj, method):
# NOTE this appears to be not used in the codebase,
# _get_attr_connection has been replaced by ConnectionFieldAttribute.
# Leaving it here for test_attr_method from
# test/units/playbook/test_base.py to pass and for backwards compat.
if getattr(obj, '_squashed', False):
value = getattr(obj, f'_{self.name}', Sentinel)
else:
value = getattr(obj, method)()
else:
value = getattr(obj, f'_{self.name}', Sentinel)
if value is Sentinel:
value = self.default
if callable(value):
value = value()
setattr(obj, f'_{self.name}', value)
return value
def __set__(self, obj, value):
setattr(obj, f'_{self.name}', value)
if self.alias is not None:
setattr(obj, f'_{self.alias}', value)
# NOTE this appears to be not needed in the codebase,
# leaving it here for test_attr_int_del from
# test/units/playbook/test_base.py to pass.
def __delete__(self, obj):
delattr(obj, f'_{self.name}')
class NonInheritableFieldAttribute(Attribute):
...
class FieldAttribute(Attribute): class FieldAttribute(Attribute):
pass def __init__(self, extend=False, prepend=False, **kwargs):
super().__init__(**kwargs)
self.extend = extend
self.prepend = prepend
def __get__(self, obj, obj_type=None):
if getattr(obj, '_squashed', False) or getattr(obj, '_finalized', False):
value = getattr(obj, f'_{self.name}', Sentinel)
else:
try:
value = obj._get_parent_attribute(self.name)
except AttributeError:
method = f'_get_attr_{self.name}'
if hasattr(obj, method):
# NOTE this appears to be not needed in the codebase,
# _get_attr_connection has been replaced by ConnectionFieldAttribute.
# Leaving it here for test_attr_method from
# test/units/playbook/test_base.py to pass and for backwards compat.
if getattr(obj, '_squashed', False):
value = getattr(obj, f'_{self.name}', Sentinel)
else:
value = getattr(obj, method)()
else:
value = getattr(obj, f'_{self.name}', Sentinel)
if value is Sentinel:
value = self.default
if callable(value):
value = value()
setattr(obj, f'_{self.name}', value)
return value
class ConnectionFieldAttribute(FieldAttribute):
def __get__(self, obj, obj_type=None):
from ansible.module_utils.compat.paramiko import paramiko
from ansible.utils.ssh_functions import check_for_controlpersist
value = super().__get__(obj, obj_type)
if value == 'smart':
value = 'ssh'
# see if SSH can support ControlPersist if not use paramiko
if not check_for_controlpersist('ssh') and paramiko is not None:
value = "paramiko"
# if someone did `connection: persistent`, default it to using a persistent paramiko connection to avoid problems
elif value == 'persistent' and paramiko is not None:
value = 'paramiko'
return value

@ -10,7 +10,6 @@ import operator
import os import os
from copy import copy as shallowcopy from copy import copy as shallowcopy
from functools import partial
from jinja2.exceptions import UndefinedError from jinja2.exceptions import UndefinedError
@ -21,7 +20,7 @@ from ansible.module_utils.six import string_types
from ansible.module_utils.parsing.convert_bool import boolean from ansible.module_utils.parsing.convert_bool import boolean
from ansible.module_utils._text import to_text, to_native from ansible.module_utils._text import to_text, to_native
from ansible.parsing.dataloader import DataLoader from ansible.parsing.dataloader import DataLoader
from ansible.playbook.attribute import Attribute, FieldAttribute from ansible.playbook.attribute import Attribute, FieldAttribute, ConnectionFieldAttribute, NonInheritableFieldAttribute
from ansible.plugins.loader import module_loader, action_loader from ansible.plugins.loader import module_loader, action_loader
from ansible.utils.collection_loader._collection_finder import _get_collection_metadata, AnsibleCollectionRef from ansible.utils.collection_loader._collection_finder import _get_collection_metadata, AnsibleCollectionRef
from ansible.utils.display import Display from ansible.utils.display import Display
@ -31,54 +30,6 @@ from ansible.utils.vars import combine_vars, isidentifier, get_unique_id
display = Display() display = Display()
def _generic_g(prop_name, self):
try:
value = self._attributes[prop_name]
except KeyError:
raise AttributeError("'%s' does not have the keyword '%s'" % (self.__class__.__name__, prop_name))
if value is Sentinel:
value = self._attr_defaults[prop_name]
return value
def _generic_g_method(prop_name, self):
try:
if self._squashed:
return self._attributes[prop_name]
method = "_get_attr_%s" % prop_name
return getattr(self, method)()
except KeyError:
raise AttributeError("'%s' does not support the keyword '%s'" % (self.__class__.__name__, prop_name))
def _generic_g_parent(prop_name, self):
try:
if self._squashed or self._finalized:
value = self._attributes[prop_name]
else:
try:
value = self._get_parent_attribute(prop_name)
except AttributeError:
value = self._attributes[prop_name]
except KeyError:
raise AttributeError("'%s' nor it's parents support the keyword '%s'" % (self.__class__.__name__, prop_name))
if value is Sentinel:
value = self._attr_defaults[prop_name]
return value
def _generic_s(prop_name, self, value):
self._attributes[prop_name] = value
def _generic_d(prop_name, self):
del self._attributes[prop_name]
def _validate_action_group_metadata(action, found_group_metadata, fq_group_name): def _validate_action_group_metadata(action, found_group_metadata, fq_group_name):
valid_metadata = { valid_metadata = {
'extend_group': { 'extend_group': {
@ -118,83 +69,30 @@ def _validate_action_group_metadata(action, found_group_metadata, fq_group_name)
display.warning(" ".join(metadata_warnings)) display.warning(" ".join(metadata_warnings))
class BaseMeta(type): # FIXME use @property and @classmethod together which is possible since Python 3.9
class _FABMeta(type):
"""
Metaclass for the Base object, which is used to construct the class
attributes based on the FieldAttributes available.
"""
def __new__(cls, name, parents, dct): @property
def _create_attrs(src_dict, dst_dict): def fattributes(cls):
''' # FIXME is this worth caching?
Helper method which creates the attributes based on those in the fattributes = {}
source dictionary of attributes. This also populates the other for class_obj in reversed(cls.__mro__):
attributes used to keep track of these attributes and via the for name, attr in list(class_obj.__dict__.items()):
getter/setter/deleter methods. if not isinstance(attr, Attribute):
''' continue
keys = list(src_dict.keys()) fattributes[name] = attr
for attr_name in keys: if attr.alias:
value = src_dict[attr_name] setattr(class_obj, attr.alias, attr)
if isinstance(value, Attribute): fattributes[attr.alias] = attr
if attr_name.startswith('_'): return fattributes
attr_name = attr_name[1:]
# here we selectively assign the getter based on a few
# things, such as whether we have a _get_attr_<name>
# method, or if the attribute is marked as not inheriting
# its value from a parent object
method = "_get_attr_%s" % attr_name
try:
if method in src_dict or method in dst_dict:
getter = partial(_generic_g_method, attr_name)
elif ('_get_parent_attribute' in dst_dict or '_get_parent_attribute' in src_dict) and value.inherit:
getter = partial(_generic_g_parent, attr_name)
else:
getter = partial(_generic_g, attr_name)
except AttributeError as e:
raise AnsibleParserError("Invalid playbook definition: %s" % to_native(e), orig_exc=e)
setter = partial(_generic_s, attr_name)
deleter = partial(_generic_d, attr_name)
dst_dict[attr_name] = property(getter, setter, deleter)
dst_dict['_valid_attrs'][attr_name] = value
dst_dict['_attributes'][attr_name] = Sentinel
dst_dict['_attr_defaults'][attr_name] = value.default
if value.alias is not None:
dst_dict[value.alias] = property(getter, setter, deleter)
dst_dict['_valid_attrs'][value.alias] = value
dst_dict['_alias_attrs'][value.alias] = attr_name
def _process_parents(parents, dst_dict):
'''
Helper method which creates attributes from all parent objects
recursively on through grandparent objects
'''
for parent in parents:
if hasattr(parent, '__dict__'):
_create_attrs(parent.__dict__, dst_dict)
new_dst_dict = parent.__dict__.copy()
new_dst_dict.update(dst_dict)
_process_parents(parent.__bases__, new_dst_dict)
# create some additional class attributes
dct['_attributes'] = {}
dct['_attr_defaults'] = {}
dct['_valid_attrs'] = {}
dct['_alias_attrs'] = {}
# now create the attributes based on the FieldAttributes
# available, including from parent (and grandparent) objects
_create_attrs(dct, dct)
_process_parents(parents, dct)
return super(BaseMeta, cls).__new__(cls, name, parents, dct)
class FieldAttributeBase(metaclass=_FABMeta):
class FieldAttributeBase(metaclass=BaseMeta): # FIXME use @property and @classmethod together which is possible since Python 3.9
@property
def fattributes(self):
return self.__class__.fattributes
def __init__(self): def __init__(self):
@ -211,17 +109,7 @@ class FieldAttributeBase(metaclass=BaseMeta):
# every object gets a random uuid: # every object gets a random uuid:
self._uuid = get_unique_id() self._uuid = get_unique_id()
# we create a copy of the attributes here due to the fact that # init vars, avoid using defaults in field declaration as it lives across plays
# it was initialized as a class param in the meta class, so we
# need a unique object here (all members contained within are
# unique already).
self._attributes = self.__class__._attributes.copy()
self._attr_defaults = self.__class__._attr_defaults.copy()
for key, value in self._attr_defaults.items():
if callable(value):
self._attr_defaults[key] = value()
# and init vars, avoid using defaults in field declaration as it lives across plays
self.vars = dict() self.vars = dict()
@property @property
@ -273,17 +161,14 @@ class FieldAttributeBase(metaclass=BaseMeta):
# Walk all attributes in the class. We sort them based on their priority # Walk all attributes in the class. We sort them based on their priority
# so that certain fields can be loaded before others, if they are dependent. # so that certain fields can be loaded before others, if they are dependent.
for name, attr in sorted(self._valid_attrs.items(), key=operator.itemgetter(1)): for name, attr in sorted(self.fattributes.items(), key=operator.itemgetter(1)):
# copy the value over unless a _load_field method is defined # copy the value over unless a _load_field method is defined
target_name = name
if name in self._alias_attrs:
target_name = self._alias_attrs[name]
if name in ds: if name in ds:
method = getattr(self, '_load_%s' % name, None) method = getattr(self, '_load_%s' % name, None)
if method: if method:
self._attributes[target_name] = method(name, ds[name]) setattr(self, name, method(name, ds[name]))
else: else:
self._attributes[target_name] = ds[name] setattr(self, name, ds[name])
# run early, non-critical validation # run early, non-critical validation
self.validate() self.validate()
@ -316,7 +201,7 @@ class FieldAttributeBase(metaclass=BaseMeta):
not map to attributes for this object. not map to attributes for this object.
''' '''
valid_attrs = frozenset(self._valid_attrs.keys()) valid_attrs = frozenset(self.fattributes)
for key in ds: for key in ds:
if key not in valid_attrs: if key not in valid_attrs:
raise AnsibleParserError("'%s' is not a valid attribute for a %s" % (key, self.__class__.__name__), obj=ds) raise AnsibleParserError("'%s' is not a valid attribute for a %s" % (key, self.__class__.__name__), obj=ds)
@ -327,18 +212,14 @@ class FieldAttributeBase(metaclass=BaseMeta):
if not self._validated: if not self._validated:
# walk all fields in the object # walk all fields in the object
for (name, attribute) in self._valid_attrs.items(): for (name, attribute) in self.fattributes.items():
if name in self._alias_attrs:
name = self._alias_attrs[name]
# run validator only if present # run validator only if present
method = getattr(self, '_validate_%s' % name, None) method = getattr(self, '_validate_%s' % name, None)
if method: if method:
method(attribute, name, getattr(self, name)) method(attribute, name, getattr(self, name))
else: else:
# and make sure the attribute is of the type it should be # and make sure the attribute is of the type it should be
value = self._attributes[name] value = getattr(self, name)
if value is not None: if value is not None:
if attribute.isa == 'string' and isinstance(value, (list, dict)): if attribute.isa == 'string' and isinstance(value, (list, dict)):
raise AnsibleParserError( raise AnsibleParserError(
@ -528,8 +409,8 @@ class FieldAttributeBase(metaclass=BaseMeta):
parent attributes. parent attributes.
''' '''
if not self._squashed: if not self._squashed:
for name in self._valid_attrs.keys(): for name in self.fattributes:
self._attributes[name] = getattr(self, name) setattr(self, name, getattr(self, name))
self._squashed = True self._squashed = True
def copy(self): def copy(self):
@ -542,11 +423,8 @@ class FieldAttributeBase(metaclass=BaseMeta):
except RuntimeError as e: except RuntimeError as e:
raise AnsibleError("Exceeded maximum object depth. This may have been caused by excessive role recursion", orig_exc=e) raise AnsibleError("Exceeded maximum object depth. This may have been caused by excessive role recursion", orig_exc=e)
for name in self._valid_attrs.keys(): for name in self.fattributes:
if name in self._alias_attrs: setattr(new_me, name, shallowcopy(getattr(self, f'_{name}', Sentinel)))
continue
new_me._attributes[name] = shallowcopy(self._attributes[name])
new_me._attr_defaults[name] = shallowcopy(self._attr_defaults[name])
new_me._loader = self._loader new_me._loader = self._loader
new_me._variable_manager = self._variable_manager new_me._variable_manager = self._variable_manager
@ -621,8 +499,7 @@ class FieldAttributeBase(metaclass=BaseMeta):
# save the omit value for later checking # save the omit value for later checking
omit_value = templar.available_variables.get('omit') omit_value = templar.available_variables.get('omit')
for (name, attribute) in self._valid_attrs.items(): for (name, attribute) in self.fattributes.items():
if attribute.static: if attribute.static:
value = getattr(self, name) value = getattr(self, name)
@ -748,7 +625,7 @@ class FieldAttributeBase(metaclass=BaseMeta):
Dumps all attributes to a dictionary Dumps all attributes to a dictionary
''' '''
attrs = {} attrs = {}
for (name, attribute) in self._valid_attrs.items(): for (name, attribute) in self.fattributes.items():
attr = getattr(self, name) attr = getattr(self, name)
if attribute.isa == 'class' and hasattr(attr, 'serialize'): if attribute.isa == 'class' and hasattr(attr, 'serialize'):
attrs[name] = attr.serialize() attrs[name] = attr.serialize()
@ -761,8 +638,8 @@ class FieldAttributeBase(metaclass=BaseMeta):
Loads attributes from a dictionary Loads attributes from a dictionary
''' '''
for (attr, value) in attrs.items(): for (attr, value) in attrs.items():
if attr in self._valid_attrs: if attr in self.fattributes:
attribute = self._valid_attrs[attr] attribute = self.fattributes[attr]
if attribute.isa == 'class' and isinstance(value, dict): if attribute.isa == 'class' and isinstance(value, dict):
obj = attribute.class_type() obj = attribute.class_type()
obj.deserialize(value) obj.deserialize(value)
@ -806,12 +683,9 @@ class FieldAttributeBase(metaclass=BaseMeta):
if not isinstance(data, dict): if not isinstance(data, dict):
raise AnsibleAssertionError('data (%s) should be a dict but is a %s' % (data, type(data))) raise AnsibleAssertionError('data (%s) should be a dict but is a %s' % (data, type(data)))
for (name, attribute) in self._valid_attrs.items(): for (name, attribute) in self.fattributes.items():
if name in data: if name in data:
setattr(self, name, data[name]) setattr(self, name, data[name])
else:
if callable(attribute.default):
setattr(self, name, attribute.default())
else: else:
setattr(self, name, attribute.default) setattr(self, name, attribute.default)
@ -823,40 +697,40 @@ class FieldAttributeBase(metaclass=BaseMeta):
class Base(FieldAttributeBase): class Base(FieldAttributeBase):
_name = FieldAttribute(isa='string', default='', always_post_validate=True, inherit=False) name = NonInheritableFieldAttribute(isa='string', default='', always_post_validate=True)
# connection/transport # connection/transport
_connection = FieldAttribute(isa='string', default=context.cliargs_deferred_get('connection')) connection = ConnectionFieldAttribute(isa='string', default=context.cliargs_deferred_get('connection'))
_port = FieldAttribute(isa='int') port = FieldAttribute(isa='int')
_remote_user = FieldAttribute(isa='string', default=context.cliargs_deferred_get('remote_user')) remote_user = FieldAttribute(isa='string', default=context.cliargs_deferred_get('remote_user'))
# variables # variables
_vars = FieldAttribute(isa='dict', priority=100, inherit=False, static=True) vars = NonInheritableFieldAttribute(isa='dict', priority=100, static=True)
# module default params # module default params
_module_defaults = FieldAttribute(isa='list', extend=True, prepend=True) module_defaults = FieldAttribute(isa='list', extend=True, prepend=True)
# flags and misc. settings # flags and misc. settings
_environment = FieldAttribute(isa='list', extend=True, prepend=True) environment = FieldAttribute(isa='list', extend=True, prepend=True)
_no_log = FieldAttribute(isa='bool') no_log = FieldAttribute(isa='bool')
_run_once = FieldAttribute(isa='bool') run_once = FieldAttribute(isa='bool')
_ignore_errors = FieldAttribute(isa='bool') ignore_errors = FieldAttribute(isa='bool')
_ignore_unreachable = FieldAttribute(isa='bool') ignore_unreachable = FieldAttribute(isa='bool')
_check_mode = FieldAttribute(isa='bool', default=context.cliargs_deferred_get('check')) check_mode = FieldAttribute(isa='bool', default=context.cliargs_deferred_get('check'))
_diff = FieldAttribute(isa='bool', default=context.cliargs_deferred_get('diff')) diff = FieldAttribute(isa='bool', default=context.cliargs_deferred_get('diff'))
_any_errors_fatal = FieldAttribute(isa='bool', default=C.ANY_ERRORS_FATAL) any_errors_fatal = FieldAttribute(isa='bool', default=C.ANY_ERRORS_FATAL)
_throttle = FieldAttribute(isa='int', default=0) throttle = FieldAttribute(isa='int', default=0)
_timeout = FieldAttribute(isa='int', default=C.TASK_TIMEOUT) timeout = FieldAttribute(isa='int', default=C.TASK_TIMEOUT)
# explicitly invoke a debugger on tasks # explicitly invoke a debugger on tasks
_debugger = FieldAttribute(isa='string') debugger = FieldAttribute(isa='string')
# Privilege escalation # Privilege escalation
_become = FieldAttribute(isa='bool', default=context.cliargs_deferred_get('become')) become = FieldAttribute(isa='bool', default=context.cliargs_deferred_get('become'))
_become_method = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_method')) become_method = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_method'))
_become_user = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_user')) become_user = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_user'))
_become_flags = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_flags')) become_flags = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_flags'))
_become_exe = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_exe')) become_exe = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_exe'))
# used to hold sudo/su stuff # used to hold sudo/su stuff
DEPRECATED_ATTRIBUTES = [] # type: list[str] DEPRECATED_ATTRIBUTES = [] # type: list[str]

@ -21,7 +21,7 @@ __metaclass__ = type
import ansible.constants as C import ansible.constants as C
from ansible.errors import AnsibleParserError from ansible.errors import AnsibleParserError
from ansible.playbook.attribute import FieldAttribute from ansible.playbook.attribute import FieldAttribute, NonInheritableFieldAttribute
from ansible.playbook.base import Base from ansible.playbook.base import Base
from ansible.playbook.conditional import Conditional from ansible.playbook.conditional import Conditional
from ansible.playbook.collectionsearch import CollectionSearch from ansible.playbook.collectionsearch import CollectionSearch
@ -34,18 +34,18 @@ from ansible.utils.sentinel import Sentinel
class Block(Base, Conditional, CollectionSearch, Taggable): class Block(Base, Conditional, CollectionSearch, Taggable):
# main block fields containing the task lists # main block fields containing the task lists
_block = FieldAttribute(isa='list', default=list, inherit=False) block = NonInheritableFieldAttribute(isa='list', default=list)
_rescue = FieldAttribute(isa='list', default=list, inherit=False) rescue = NonInheritableFieldAttribute(isa='list', default=list)
_always = FieldAttribute(isa='list', default=list, inherit=False) always = NonInheritableFieldAttribute(isa='list', default=list)
# other fields for task compat # other fields for task compat
_notify = FieldAttribute(isa='list') notify = FieldAttribute(isa='list')
_delegate_to = FieldAttribute(isa='string') delegate_to = FieldAttribute(isa='string')
_delegate_facts = FieldAttribute(isa='bool') delegate_facts = FieldAttribute(isa='bool')
# for future consideration? this would be functionally # for future consideration? this would be functionally
# similar to the 'else' clause for exceptions # similar to the 'else' clause for exceptions
# _otherwise = FieldAttribute(isa='list') # otherwise = FieldAttribute(isa='list')
def __init__(self, play=None, parent_block=None, role=None, task_include=None, use_handlers=False, implicit=False): def __init__(self, play=None, parent_block=None, role=None, task_include=None, use_handlers=False, implicit=False):
self._play = play self._play = play
@ -230,7 +230,7 @@ class Block(Base, Conditional, CollectionSearch, Taggable):
''' '''
data = dict() data = dict()
for attr in self._valid_attrs: for attr in self.fattributes:
if attr not in ('block', 'rescue', 'always'): if attr not in ('block', 'rescue', 'always'):
data[attr] = getattr(self, attr) data[attr] = getattr(self, attr)
@ -256,7 +256,7 @@ class Block(Base, Conditional, CollectionSearch, Taggable):
# we don't want the full set of attributes (the task lists), as that # we don't want the full set of attributes (the task lists), as that
# would lead to a serialize/deserialize loop # would lead to a serialize/deserialize loop
for attr in self._valid_attrs: for attr in self.fattributes:
if attr in data and attr not in ('block', 'rescue', 'always'): if attr in data and attr not in ('block', 'rescue', 'always'):
setattr(self, attr, data.get(attr)) setattr(self, attr, data.get(attr))
@ -298,11 +298,10 @@ class Block(Base, Conditional, CollectionSearch, Taggable):
''' '''
Generic logic to get the attribute or parent attribute for a block value. Generic logic to get the attribute or parent attribute for a block value.
''' '''
extend = self.fattributes.get(attr).extend
extend = self._valid_attrs[attr].extend prepend = self.fattributes.get(attr).prepend
prepend = self._valid_attrs[attr].prepend
try: try:
value = self._attributes[attr] value = getattr(self, f'_{attr}', Sentinel)
# If parent is static, we can grab attrs from the parent # If parent is static, we can grab attrs from the parent
# otherwise, defer to the grandparent # otherwise, defer to the grandparent
if getattr(self._parent, 'statically_loaded', True): if getattr(self._parent, 'statically_loaded', True):
@ -316,7 +315,7 @@ class Block(Base, Conditional, CollectionSearch, Taggable):
if hasattr(_parent, '_get_parent_attribute'): if hasattr(_parent, '_get_parent_attribute'):
parent_value = _parent._get_parent_attribute(attr) parent_value = _parent._get_parent_attribute(attr)
else: else:
parent_value = _parent._attributes.get(attr, Sentinel) parent_value = getattr(_parent, f'_{attr}', Sentinel)
if extend: if extend:
value = self._extend_value(value, parent_value, prepend) value = self._extend_value(value, parent_value, prepend)
else: else:
@ -325,7 +324,7 @@ class Block(Base, Conditional, CollectionSearch, Taggable):
pass pass
if self._role and (value is Sentinel or extend): if self._role and (value is Sentinel or extend):
try: try:
parent_value = self._role._attributes.get(attr, Sentinel) parent_value = getattr(self._role, f'_{attr}', Sentinel)
if extend: if extend:
value = self._extend_value(value, parent_value, prepend) value = self._extend_value(value, parent_value, prepend)
else: else:
@ -335,7 +334,7 @@ class Block(Base, Conditional, CollectionSearch, Taggable):
if dep_chain and (value is Sentinel or extend): if dep_chain and (value is Sentinel or extend):
dep_chain.reverse() dep_chain.reverse()
for dep in dep_chain: for dep in dep_chain:
dep_value = dep._attributes.get(attr, Sentinel) dep_value = getattr(dep, f'_{attr}', Sentinel)
if extend: if extend:
value = self._extend_value(value, dep_value, prepend) value = self._extend_value(value, dep_value, prepend)
else: else:
@ -347,7 +346,7 @@ class Block(Base, Conditional, CollectionSearch, Taggable):
pass pass
if self._play and (value is Sentinel or extend): if self._play and (value is Sentinel or extend):
try: try:
play_value = self._play._attributes.get(attr, Sentinel) play_value = getattr(self._play, f'_{attr}', Sentinel)
if play_value is not Sentinel: if play_value is not Sentinel:
if extend: if extend:
value = self._extend_value(value, play_value, prepend) value = self._extend_value(value, play_value, prepend)

@ -36,13 +36,13 @@ def _ensure_default_collection(collection_list=None):
class CollectionSearch: class CollectionSearch:
# this needs to be populated before we can resolve tasks/roles/etc # this needs to be populated before we can resolve tasks/roles/etc
_collections = FieldAttribute(isa='list', listof=string_types, priority=100, default=_ensure_default_collection, collections = FieldAttribute(isa='list', listof=string_types, priority=100, default=_ensure_default_collection,
always_post_validate=True, static=True) always_post_validate=True, static=True)
def _load_collections(self, attr, ds): def _load_collections(self, attr, ds):
# We are always a mixin with Base, so we can validate this untemplated # We are always a mixin with Base, so we can validate this untemplated
# field early on to guarantee we are dealing with a list. # field early on to guarantee we are dealing with a list.
ds = self.get_validated_value('collections', self._collections, ds, None) ds = self.get_validated_value('collections', self.fattributes.get('collections'), ds, None)
# this will only be called if someone specified a value; call the shared value # this will only be called if someone specified a value; call the shared value
_ensure_default_collection(collection_list=ds) _ensure_default_collection(collection_list=ds)

@ -46,7 +46,7 @@ class Conditional:
to be run conditionally when a condition is met or skipped. to be run conditionally when a condition is met or skipped.
''' '''
_when = FieldAttribute(isa='list', default=list, extend=True, prepend=True) when = FieldAttribute(isa='list', default=list, extend=True, prepend=True)
def __init__(self, loader=None): def __init__(self, loader=None):
# when used directly, this class needs a loader, but we want to # when used directly, this class needs a loader, but we want to

@ -26,7 +26,7 @@ from ansible.module_utils.six import string_types
class Handler(Task): class Handler(Task):
_listen = FieldAttribute(isa='list', default=list, listof=string_types, static=True) listen = FieldAttribute(isa='list', default=list, listof=string_types, static=True)
def __init__(self, block=None, role=None, task_include=None): def __init__(self, block=None, role=None, task_include=None):
self.notified_hosts = [] self.notified_hosts = []

@ -25,12 +25,12 @@ from ansible.playbook.base import FieldAttributeBase
class LoopControl(FieldAttributeBase): class LoopControl(FieldAttributeBase):
_loop_var = FieldAttribute(isa='str', default='item') loop_var = FieldAttribute(isa='str', default='item')
_index_var = FieldAttribute(isa='str') index_var = FieldAttribute(isa='str')
_label = FieldAttribute(isa='str') label = FieldAttribute(isa='str')
_pause = FieldAttribute(isa='float', default=0) pause = FieldAttribute(isa='float', default=0)
_extended = FieldAttribute(isa='bool') extended = FieldAttribute(isa='bool')
_extended_allitems = FieldAttribute(isa='bool', default=True) extended_allitems = FieldAttribute(isa='bool', default=True)
def __init__(self): def __init__(self):
super(LoopControl, self).__init__() super(LoopControl, self).__init__()

@ -54,35 +54,35 @@ class Play(Base, Taggable, CollectionSearch):
""" """
# ================================================================================= # =================================================================================
_hosts = FieldAttribute(isa='list', required=True, listof=string_types, always_post_validate=True, priority=-1) hosts = FieldAttribute(isa='list', required=True, listof=string_types, always_post_validate=True, priority=-2)
# Facts # Facts
_gather_facts = FieldAttribute(isa='bool', default=None, always_post_validate=True) gather_facts = FieldAttribute(isa='bool', default=None, always_post_validate=True)
# defaults to be deprecated, should be 'None' in future # defaults to be deprecated, should be 'None' in future
_gather_subset = FieldAttribute(isa='list', default=(lambda: C.DEFAULT_GATHER_SUBSET), listof=string_types, always_post_validate=True) gather_subset = FieldAttribute(isa='list', default=(lambda: C.DEFAULT_GATHER_SUBSET), listof=string_types, always_post_validate=True)
_gather_timeout = FieldAttribute(isa='int', default=C.DEFAULT_GATHER_TIMEOUT, always_post_validate=True) gather_timeout = FieldAttribute(isa='int', default=C.DEFAULT_GATHER_TIMEOUT, always_post_validate=True)
_fact_path = FieldAttribute(isa='string', default=C.DEFAULT_FACT_PATH) fact_path = FieldAttribute(isa='string', default=C.DEFAULT_FACT_PATH)
# Variable Attributes # Variable Attributes
_vars_files = FieldAttribute(isa='list', default=list, priority=99) vars_files = FieldAttribute(isa='list', default=list, priority=99)
_vars_prompt = FieldAttribute(isa='list', default=list, always_post_validate=False) vars_prompt = FieldAttribute(isa='list', default=list, always_post_validate=False)
# Role Attributes # Role Attributes
_roles = FieldAttribute(isa='list', default=list, priority=90) roles = FieldAttribute(isa='list', default=list, priority=90)
# Block (Task) Lists Attributes # Block (Task) Lists Attributes
_handlers = FieldAttribute(isa='list', default=list) handlers = FieldAttribute(isa='list', default=list, priority=-1)
_pre_tasks = FieldAttribute(isa='list', default=list) pre_tasks = FieldAttribute(isa='list', default=list, priority=-1)
_post_tasks = FieldAttribute(isa='list', default=list) post_tasks = FieldAttribute(isa='list', default=list, priority=-1)
_tasks = FieldAttribute(isa='list', default=list) tasks = FieldAttribute(isa='list', default=list, priority=-1)
# Flag/Setting Attributes # Flag/Setting Attributes
_force_handlers = FieldAttribute(isa='bool', default=context.cliargs_deferred_get('force_handlers'), always_post_validate=True) force_handlers = FieldAttribute(isa='bool', default=context.cliargs_deferred_get('force_handlers'), always_post_validate=True)
_max_fail_percentage = FieldAttribute(isa='percent', always_post_validate=True) max_fail_percentage = FieldAttribute(isa='percent', always_post_validate=True)
_serial = FieldAttribute(isa='list', default=list, always_post_validate=True) serial = FieldAttribute(isa='list', default=list, always_post_validate=True)
_strategy = FieldAttribute(isa='string', default=C.DEFAULT_STRATEGY, always_post_validate=True) strategy = FieldAttribute(isa='string', default=C.DEFAULT_STRATEGY, always_post_validate=True)
_order = FieldAttribute(isa='string', always_post_validate=True) order = FieldAttribute(isa='string', always_post_validate=True)
# ================================================================================= # =================================================================================

@ -77,46 +77,46 @@ class PlayContext(Base):
''' '''
# base # base
_module_compression = FieldAttribute(isa='string', default=C.DEFAULT_MODULE_COMPRESSION) module_compression = FieldAttribute(isa='string', default=C.DEFAULT_MODULE_COMPRESSION)
_shell = FieldAttribute(isa='string') shell = FieldAttribute(isa='string')
_executable = FieldAttribute(isa='string', default=C.DEFAULT_EXECUTABLE) executable = FieldAttribute(isa='string', default=C.DEFAULT_EXECUTABLE)
# connection fields, some are inherited from Base: # connection fields, some are inherited from Base:
# (connection, port, remote_user, environment, no_log) # (connection, port, remote_user, environment, no_log)
_remote_addr = FieldAttribute(isa='string') remote_addr = FieldAttribute(isa='string')
_password = FieldAttribute(isa='string') password = FieldAttribute(isa='string')
_timeout = FieldAttribute(isa='int', default=C.DEFAULT_TIMEOUT) timeout = FieldAttribute(isa='int', default=C.DEFAULT_TIMEOUT)
_connection_user = FieldAttribute(isa='string') connection_user = FieldAttribute(isa='string')
_private_key_file = FieldAttribute(isa='string', default=C.DEFAULT_PRIVATE_KEY_FILE) private_key_file = FieldAttribute(isa='string', default=C.DEFAULT_PRIVATE_KEY_FILE)
_pipelining = FieldAttribute(isa='bool', default=C.ANSIBLE_PIPELINING) pipelining = FieldAttribute(isa='bool', default=C.ANSIBLE_PIPELINING)
# networking modules # networking modules
_network_os = FieldAttribute(isa='string') network_os = FieldAttribute(isa='string')
# docker FIXME: remove these # docker FIXME: remove these
_docker_extra_args = FieldAttribute(isa='string') docker_extra_args = FieldAttribute(isa='string')
# ??? # ???
_connection_lockfd = FieldAttribute(isa='int') connection_lockfd = FieldAttribute(isa='int')
# privilege escalation fields # privilege escalation fields
_become = FieldAttribute(isa='bool') become = FieldAttribute(isa='bool')
_become_method = FieldAttribute(isa='string') become_method = FieldAttribute(isa='string')
_become_user = FieldAttribute(isa='string') become_user = FieldAttribute(isa='string')
_become_pass = FieldAttribute(isa='string') become_pass = FieldAttribute(isa='string')
_become_exe = FieldAttribute(isa='string', default=C.DEFAULT_BECOME_EXE) become_exe = FieldAttribute(isa='string', default=C.DEFAULT_BECOME_EXE)
_become_flags = FieldAttribute(isa='string', default=C.DEFAULT_BECOME_FLAGS) become_flags = FieldAttribute(isa='string', default=C.DEFAULT_BECOME_FLAGS)
_prompt = FieldAttribute(isa='string') prompt = FieldAttribute(isa='string')
# general flags # general flags
_only_tags = FieldAttribute(isa='set', default=set) only_tags = FieldAttribute(isa='set', default=set)
_skip_tags = FieldAttribute(isa='set', default=set) skip_tags = FieldAttribute(isa='set', default=set)
_start_at_task = FieldAttribute(isa='string') start_at_task = FieldAttribute(isa='string')
_step = FieldAttribute(isa='bool', default=False) step = FieldAttribute(isa='bool', default=False)
# "PlayContext.force_handlers should not be used, the calling code should be using play itself instead" # "PlayContext.force_handlers should not be used, the calling code should be using play itself instead"
_force_handlers = FieldAttribute(isa='bool', default=False) force_handlers = FieldAttribute(isa='bool', default=False)
@property @property
def verbosity(self): def verbosity(self):
@ -353,21 +353,3 @@ class PlayContext(Base):
variables[var_opt] = var_val variables[var_opt] = var_val
except AttributeError: except AttributeError:
continue continue
def _get_attr_connection(self):
''' connections are special, this takes care of responding correctly '''
conn_type = None
if self._attributes['connection'] == 'smart':
conn_type = 'ssh'
# see if SSH can support ControlPersist if not use paramiko
if not check_for_controlpersist('ssh') and paramiko is not None:
conn_type = "paramiko"
# if someone did `connection: persistent`, default it to using a persistent paramiko connection to avoid problems
elif self._attributes['connection'] == 'persistent' and paramiko is not None:
conn_type = 'paramiko'
if conn_type:
self.connection = conn_type
return self._attributes['connection']

@ -41,8 +41,8 @@ display = Display()
class PlaybookInclude(Base, Conditional, Taggable): class PlaybookInclude(Base, Conditional, Taggable):
_import_playbook = FieldAttribute(isa='string') import_playbook = FieldAttribute(isa='string')
_vars = FieldAttribute(isa='dict', default=dict) vars_val = FieldAttribute(isa='dict', default=dict, alias='vars')
@staticmethod @staticmethod
def load(data, basedir, variable_manager=None, loader=None): def load(data, basedir, variable_manager=None, loader=None):
@ -120,7 +120,7 @@ class PlaybookInclude(Base, Conditional, Taggable):
# those attached to each block (if any) # those attached to each block (if any)
if new_obj.when: if new_obj.when:
for task_block in (entry.pre_tasks + entry.roles + entry.tasks + entry.post_tasks): for task_block in (entry.pre_tasks + entry.roles + entry.tasks + entry.post_tasks):
task_block._attributes['when'] = new_obj.when[:] + task_block.when[:] task_block._when = new_obj.when[:] + task_block.when[:]
return pb return pb

@ -37,6 +37,7 @@ from ansible.playbook.taggable import Taggable
from ansible.plugins.loader import add_all_plugin_dirs from ansible.plugins.loader import add_all_plugin_dirs
from ansible.utils.collection_loader import AnsibleCollectionConfig from ansible.utils.collection_loader import AnsibleCollectionConfig
from ansible.utils.path import is_subpath from ansible.utils.path import is_subpath
from ansible.utils.sentinel import Sentinel
from ansible.utils.vars import combine_vars from ansible.utils.vars import combine_vars
__all__ = ['Role', 'hash_params'] __all__ = ['Role', 'hash_params']
@ -97,8 +98,8 @@ def hash_params(params):
class Role(Base, Conditional, Taggable, CollectionSearch): class Role(Base, Conditional, Taggable, CollectionSearch):
_delegate_to = FieldAttribute(isa='string') delegate_to = FieldAttribute(isa='string')
_delegate_facts = FieldAttribute(isa='bool') delegate_facts = FieldAttribute(isa='bool')
def __init__(self, play=None, from_files=None, from_include=False, validate=True): def __init__(self, play=None, from_files=None, from_include=False, validate=True):
self._role_name = None self._role_name = None
@ -198,15 +199,19 @@ class Role(Base, Conditional, Taggable, CollectionSearch):
self.add_parent(parent_role) self.add_parent(parent_role)
# copy over all field attributes from the RoleInclude # copy over all field attributes from the RoleInclude
# update self._attributes directly, to avoid squashing # update self._attr directly, to avoid squashing
for (attr_name, dump) in self._valid_attrs.items(): for attr_name in self.fattributes:
if attr_name in ('when', 'tags'): if attr_name in ('when', 'tags'):
self._attributes[attr_name] = self._extend_value( setattr(
self._attributes[attr_name], self,
role_include._attributes[attr_name], f'_{attr_name}',
self._extend_value(
getattr(self, f'_{attr_name}', Sentinel),
getattr(role_include, f'_{attr_name}', Sentinel),
)
) )
else: else:
self._attributes[attr_name] = role_include._attributes[attr_name] setattr(self, f'_{attr_name}', getattr(role_include, f'_{attr_name}', Sentinel))
# vars and default vars are regular dictionaries # vars and default vars are regular dictionaries
self._role_vars = self._load_role_yaml('vars', main=self._from_files.get('vars'), allow_dir=True) self._role_vars = self._load_role_yaml('vars', main=self._from_files.get('vars'), allow_dir=True)

@ -43,7 +43,7 @@ display = Display()
class RoleDefinition(Base, Conditional, Taggable, CollectionSearch): class RoleDefinition(Base, Conditional, Taggable, CollectionSearch):
_role = FieldAttribute(isa='string') role = FieldAttribute(isa='string')
def __init__(self, play=None, role_basedir=None, variable_manager=None, loader=None, collection_list=None): def __init__(self, play=None, role_basedir=None, variable_manager=None, loader=None, collection_list=None):
@ -210,7 +210,7 @@ class RoleDefinition(Base, Conditional, Taggable, CollectionSearch):
role_def = dict() role_def = dict()
role_params = dict() role_params = dict()
base_attribute_names = frozenset(self._valid_attrs.keys()) base_attribute_names = frozenset(self.fattributes)
for (key, value) in ds.items(): for (key, value) in ds.items():
# use the list of FieldAttribute values to determine what is and is not # use the list of FieldAttribute values to determine what is and is not
# an extra parameter for this role (or sub-class of this role) # an extra parameter for this role (or sub-class of this role)

@ -37,8 +37,8 @@ class RoleInclude(RoleDefinition):
is included for execution in a play. is included for execution in a play.
""" """
_delegate_to = FieldAttribute(isa='string') delegate_to = FieldAttribute(isa='string')
_delegate_facts = FieldAttribute(isa='bool', default=False) delegate_facts = FieldAttribute(isa='bool', default=False)
def __init__(self, play=None, role_basedir=None, variable_manager=None, loader=None, collection_list=None): def __init__(self, play=None, role_basedir=None, variable_manager=None, loader=None, collection_list=None):
super(RoleInclude, self).__init__(play=play, role_basedir=role_basedir, variable_manager=variable_manager, super(RoleInclude, self).__init__(play=play, role_basedir=role_basedir, variable_manager=variable_manager,

@ -39,10 +39,10 @@ class RoleMetadata(Base, CollectionSearch):
within each Role (meta/main.yml). within each Role (meta/main.yml).
''' '''
_allow_duplicates = FieldAttribute(isa='bool', default=False) allow_duplicates = FieldAttribute(isa='bool', default=False)
_dependencies = FieldAttribute(isa='list', default=list) dependencies = FieldAttribute(isa='list', default=list)
_galaxy_info = FieldAttribute(isa='GalaxyInfo') galaxy_info = FieldAttribute(isa='GalaxyInfo')
_argument_specs = FieldAttribute(isa='dict', default=dict) argument_specs = FieldAttribute(isa='dict', default=dict)
def __init__(self, owner=None): def __init__(self, owner=None):
self._owner = owner self._owner = owner

@ -52,9 +52,9 @@ class IncludeRole(TaskInclude):
# ATTRIBUTES # ATTRIBUTES
# private as this is a 'module options' vs a task property # private as this is a 'module options' vs a task property
_allow_duplicates = FieldAttribute(isa='bool', default=True, private=True) allow_duplicates = FieldAttribute(isa='bool', default=True, private=True)
_public = FieldAttribute(isa='bool', default=False, private=True) public = FieldAttribute(isa='bool', default=False, private=True)
_rolespec_validate = FieldAttribute(isa='bool', default=True) rolespec_validate = FieldAttribute(isa='bool', default=True)
def __init__(self, block=None, role=None, task_include=None): def __init__(self, block=None, role=None, task_include=None):

@ -28,7 +28,7 @@ from ansible.template import Templar
class Taggable: class Taggable:
untagged = frozenset(['untagged']) untagged = frozenset(['untagged'])
_tags = FieldAttribute(isa='list', default=list, listof=(string_types, int), extend=True) tags = FieldAttribute(isa='list', default=list, listof=(string_types, int), extend=True)
def _load_tags(self, attr, ds): def _load_tags(self, attr, ds):
if isinstance(ds, list): if isinstance(ds, list):

@ -26,7 +26,7 @@ from ansible.module_utils.six import string_types
from ansible.parsing.mod_args import ModuleArgsParser from ansible.parsing.mod_args import ModuleArgsParser
from ansible.parsing.yaml.objects import AnsibleBaseYAMLObject, AnsibleMapping from ansible.parsing.yaml.objects import AnsibleBaseYAMLObject, AnsibleMapping
from ansible.plugins.loader import lookup_loader from ansible.plugins.loader import lookup_loader
from ansible.playbook.attribute import FieldAttribute from ansible.playbook.attribute import FieldAttribute, NonInheritableFieldAttribute
from ansible.playbook.base import Base from ansible.playbook.base import Base
from ansible.playbook.block import Block from ansible.playbook.block import Block
from ansible.playbook.collectionsearch import CollectionSearch from ansible.playbook.collectionsearch import CollectionSearch
@ -63,28 +63,28 @@ class Task(Base, Conditional, Taggable, CollectionSearch):
# might be possible to define others # might be possible to define others
# NOTE: ONLY set defaults on task attributes that are not inheritable, # NOTE: ONLY set defaults on task attributes that are not inheritable,
# inheritance is only triggered if the 'current value' is None, # inheritance is only triggered if the 'current value' is Sentinel,
# default can be set at play/top level object and inheritance will take it's course. # default can be set at play/top level object and inheritance will take it's course.
_args = FieldAttribute(isa='dict', default=dict) args = FieldAttribute(isa='dict', default=dict)
_action = FieldAttribute(isa='string') action = FieldAttribute(isa='string')
_async_val = FieldAttribute(isa='int', default=0, alias='async') async_val = FieldAttribute(isa='int', default=0, alias='async')
_changed_when = FieldAttribute(isa='list', default=list) changed_when = FieldAttribute(isa='list', default=list)
_delay = FieldAttribute(isa='int', default=5) delay = FieldAttribute(isa='int', default=5)
_delegate_to = FieldAttribute(isa='string') delegate_to = FieldAttribute(isa='string')
_delegate_facts = FieldAttribute(isa='bool') delegate_facts = FieldAttribute(isa='bool')
_failed_when = FieldAttribute(isa='list', default=list) failed_when = FieldAttribute(isa='list', default=list)
_loop = FieldAttribute() loop = FieldAttribute()
_loop_control = FieldAttribute(isa='class', class_type=LoopControl, inherit=False) loop_control = NonInheritableFieldAttribute(isa='class', class_type=LoopControl)
_notify = FieldAttribute(isa='list') notify = FieldAttribute(isa='list')
_poll = FieldAttribute(isa='int', default=C.DEFAULT_POLL_INTERVAL) poll = FieldAttribute(isa='int', default=C.DEFAULT_POLL_INTERVAL)
_register = FieldAttribute(isa='string', static=True) register = FieldAttribute(isa='string', static=True)
_retries = FieldAttribute(isa='int', default=3) retries = FieldAttribute(isa='int', default=3)
_until = FieldAttribute(isa='list', default=list) until = FieldAttribute(isa='list', default=list)
# deprecated, used to be loop and loop_args but loop has been repurposed # deprecated, used to be loop and loop_args but loop has been repurposed
_loop_with = FieldAttribute(isa='string', private=True, inherit=False) loop_with = NonInheritableFieldAttribute(isa='string', private=True)
def __init__(self, block=None, role=None, task_include=None): def __init__(self, block=None, role=None, task_include=None):
''' constructors a task, without the Task.load classmethod, it will be pretty blank ''' ''' constructors a task, without the Task.load classmethod, it will be pretty blank '''
@ -182,7 +182,7 @@ class Task(Base, Conditional, Taggable, CollectionSearch):
else: else:
# Validate this untemplated field early on to guarantee we are dealing with a list. # Validate this untemplated field early on to guarantee we are dealing with a list.
# This is also done in CollectionSearch._load_collections() but this runs before that call. # This is also done in CollectionSearch._load_collections() but this runs before that call.
collections_list = self.get_validated_value('collections', self._collections, collections_list, None) collections_list = self.get_validated_value('collections', self.fattributes.get('collections'), collections_list, None)
if default_collection and not self._role: # FIXME: and not a collections role if default_collection and not self._role: # FIXME: and not a collections role
if collections_list: if collections_list:
@ -460,11 +460,10 @@ class Task(Base, Conditional, Taggable, CollectionSearch):
''' '''
Generic logic to get the attribute or parent attribute for a task value. Generic logic to get the attribute or parent attribute for a task value.
''' '''
extend = self.fattributes.get(attr).extend
extend = self._valid_attrs[attr].extend prepend = self.fattributes.get(attr).prepend
prepend = self._valid_attrs[attr].prepend
try: try:
value = self._attributes[attr] value = getattr(self, f'_{attr}', Sentinel)
# If parent is static, we can grab attrs from the parent # If parent is static, we can grab attrs from the parent
# otherwise, defer to the grandparent # otherwise, defer to the grandparent
if getattr(self._parent, 'statically_loaded', True): if getattr(self._parent, 'statically_loaded', True):
@ -478,7 +477,7 @@ class Task(Base, Conditional, Taggable, CollectionSearch):
if attr != 'vars' and hasattr(_parent, '_get_parent_attribute'): if attr != 'vars' and hasattr(_parent, '_get_parent_attribute'):
parent_value = _parent._get_parent_attribute(attr) parent_value = _parent._get_parent_attribute(attr)
else: else:
parent_value = _parent._attributes.get(attr, Sentinel) parent_value = getattr(_parent, f'_{attr}', Sentinel)
if extend: if extend:
value = self._extend_value(value, parent_value, prepend) value = self._extend_value(value, parent_value, prepend)

@ -21,7 +21,6 @@ __metaclass__ = type
import ansible.constants as C import ansible.constants as C
from ansible.errors import AnsibleParserError from ansible.errors import AnsibleParserError
from ansible.playbook.attribute import FieldAttribute
from ansible.playbook.block import Block from ansible.playbook.block import Block
from ansible.playbook.task import Task from ansible.playbook.task import Task
from ansible.utils.display import Display from ansible.utils.display import Display

@ -677,7 +677,7 @@ class StrategyBase:
continue continue
listeners = listening_handler.get_validated_value( listeners = listening_handler.get_validated_value(
'listen', listening_handler._valid_attrs['listen'], listeners, handler_templar 'listen', listening_handler.fattributes.get('listen'), listeners, handler_templar
) )
if handler_name not in listeners: if handler_name not in listeners:
continue continue

@ -39,14 +39,12 @@ def get_reserved_names(include_private=True):
class_list = [Play, Role, Block, Task] class_list = [Play, Role, Block, Task]
for aclass in class_list: for aclass in class_list:
aobj = aclass()
# build ordered list to loop over and dict with attributes # build ordered list to loop over and dict with attributes
for attribute in aobj.__dict__['_attributes']: for name, attr in aclass.fattributes.items():
if 'private' in attribute: if attr.private:
private.add(attribute) private.add(name)
else: else:
public.add(attribute) public.add(name)
# local_action is implicit with action # local_action is implicit with action
if 'action' in public: if 'action' in public:

@ -23,10 +23,11 @@ from units.compat import unittest
from ansible.errors import AnsibleParserError from ansible.errors import AnsibleParserError
from ansible.module_utils.six import string_types from ansible.module_utils.six import string_types
from ansible.playbook.attribute import FieldAttribute from ansible.playbook.attribute import FieldAttribute, NonInheritableFieldAttribute
from ansible.template import Templar from ansible.template import Templar
from ansible.playbook import base from ansible.playbook import base
from ansible.utils.unsafe_proxy import AnsibleUnsafeBytes, AnsibleUnsafeText from ansible.utils.unsafe_proxy import AnsibleUnsafeBytes, AnsibleUnsafeText
from ansible.utils.sentinel import Sentinel
from units.mock.loader import DictDataLoader from units.mock.loader import DictDataLoader
@ -77,7 +78,7 @@ class TestBase(unittest.TestCase):
def _assert_copy(self, orig, copy): def _assert_copy(self, orig, copy):
self.assertIsInstance(copy, self.ClassUnderTest) self.assertIsInstance(copy, self.ClassUnderTest)
self.assertIsInstance(copy, base.Base) self.assertIsInstance(copy, base.Base)
self.assertEqual(len(orig._valid_attrs), len(copy._valid_attrs)) self.assertEqual(len(orig.fattributes), len(copy.fattributes))
sentinel = 'Empty DS' sentinel = 'Empty DS'
self.assertEqual(getattr(orig, '_ds', sentinel), getattr(copy, '_ds', sentinel)) self.assertEqual(getattr(orig, '_ds', sentinel), getattr(copy, '_ds', sentinel))
@ -107,8 +108,8 @@ class TestBase(unittest.TestCase):
d = self.ClassUnderTest() d = self.ClassUnderTest()
d.deserialize(data) d.deserialize(data)
self.assertIn('run_once', d._attributes) self.assertIn('_run_once', d.__dict__)
self.assertIn('check_mode', d._attributes) self.assertIn('_check_mode', d.__dict__)
data = {'no_log': False, data = {'no_log': False,
'remote_user': None, 'remote_user': None,
@ -122,9 +123,9 @@ class TestBase(unittest.TestCase):
d = self.ClassUnderTest() d = self.ClassUnderTest()
d.deserialize(data) d.deserialize(data)
self.assertNotIn('a_sentinel_with_an_unlikely_name', d._attributes) self.assertNotIn('_a_sentinel_with_an_unlikely_name', d.__dict__)
self.assertIn('run_once', d._attributes) self.assertIn('_run_once', d.__dict__)
self.assertIn('check_mode', d._attributes) self.assertIn('_check_mode', d.__dict__)
def test_serialize_then_deserialize(self): def test_serialize_then_deserialize(self):
ds = {'environment': [], ds = {'environment': [],
@ -165,7 +166,7 @@ class TestBase(unittest.TestCase):
# environment is supposed to be a list. This # environment is supposed to be a list. This
# seems like it shouldn't work? # seems like it shouldn't work?
ret = self.b.load_data(ds) ret = self.b.load_data(ds)
self.assertEqual(True, ret._attributes['environment']) self.assertEqual(True, ret._environment)
def test_post_validate(self): def test_post_validate(self):
ds = {'environment': [], ds = {'environment': [],
@ -312,7 +313,7 @@ class ExampleException(Exception):
# naming fails me... # naming fails me...
class ExampleParentBaseSubClass(base.Base): class ExampleParentBaseSubClass(base.Base):
_test_attr_parent_string = FieldAttribute(isa='string', default='A string attr for a class that may be a parent for testing') test_attr_parent_string = FieldAttribute(isa='string', default='A string attr for a class that may be a parent for testing')
def __init__(self): def __init__(self):
@ -324,8 +325,7 @@ class ExampleParentBaseSubClass(base.Base):
class ExampleSubClass(base.Base): class ExampleSubClass(base.Base):
_test_attr_blip = FieldAttribute(isa='string', default='example sub class test_attr_blip', test_attr_blip = NonInheritableFieldAttribute(isa='string', default='example sub class test_attr_blip',
inherit=False,
always_post_validate=True) always_post_validate=True)
def __init__(self): def __init__(self):
@ -339,31 +339,31 @@ class ExampleSubClass(base.Base):
class BaseSubClass(base.Base): class BaseSubClass(base.Base):
_name = FieldAttribute(isa='string', default='', always_post_validate=True) name = FieldAttribute(isa='string', default='', always_post_validate=True)
_test_attr_bool = FieldAttribute(isa='bool', always_post_validate=True) test_attr_bool = FieldAttribute(isa='bool', always_post_validate=True)
_test_attr_int = FieldAttribute(isa='int', always_post_validate=True) test_attr_int = FieldAttribute(isa='int', always_post_validate=True)
_test_attr_float = FieldAttribute(isa='float', default=3.14159, always_post_validate=True) test_attr_float = FieldAttribute(isa='float', default=3.14159, always_post_validate=True)
_test_attr_list = FieldAttribute(isa='list', listof=string_types, always_post_validate=True) test_attr_list = FieldAttribute(isa='list', listof=string_types, always_post_validate=True)
_test_attr_list_no_listof = FieldAttribute(isa='list', always_post_validate=True) test_attr_list_no_listof = FieldAttribute(isa='list', always_post_validate=True)
_test_attr_list_required = FieldAttribute(isa='list', listof=string_types, required=True, test_attr_list_required = FieldAttribute(isa='list', listof=string_types, required=True,
default=list, always_post_validate=True) default=list, always_post_validate=True)
_test_attr_string = FieldAttribute(isa='string', default='the_test_attr_string_default_value') test_attr_string = FieldAttribute(isa='string', default='the_test_attr_string_default_value')
_test_attr_string_required = FieldAttribute(isa='string', required=True, test_attr_string_required = FieldAttribute(isa='string', required=True,
default='the_test_attr_string_default_value') default='the_test_attr_string_default_value')
_test_attr_percent = FieldAttribute(isa='percent', always_post_validate=True) test_attr_percent = FieldAttribute(isa='percent', always_post_validate=True)
_test_attr_set = FieldAttribute(isa='set', default=set, always_post_validate=True) test_attr_set = FieldAttribute(isa='set', default=set, always_post_validate=True)
_test_attr_dict = FieldAttribute(isa='dict', default=lambda: {'a_key': 'a_value'}, always_post_validate=True) test_attr_dict = FieldAttribute(isa='dict', default=lambda: {'a_key': 'a_value'}, always_post_validate=True)
_test_attr_class = FieldAttribute(isa='class', class_type=ExampleSubClass) test_attr_class = FieldAttribute(isa='class', class_type=ExampleSubClass)
_test_attr_class_post_validate = FieldAttribute(isa='class', class_type=ExampleSubClass, test_attr_class_post_validate = FieldAttribute(isa='class', class_type=ExampleSubClass,
always_post_validate=True) always_post_validate=True)
_test_attr_unknown_isa = FieldAttribute(isa='not_a_real_isa', always_post_validate=True) test_attr_unknown_isa = FieldAttribute(isa='not_a_real_isa', always_post_validate=True)
_test_attr_example = FieldAttribute(isa='string', default='the_default', test_attr_example = FieldAttribute(isa='string', default='the_default',
always_post_validate=True) always_post_validate=True)
_test_attr_none = FieldAttribute(isa='string', always_post_validate=True) test_attr_none = FieldAttribute(isa='string', always_post_validate=True)
_test_attr_preprocess = FieldAttribute(isa='string', default='the default for preprocess') test_attr_preprocess = FieldAttribute(isa='string', default='the default for preprocess')
_test_attr_method = FieldAttribute(isa='string', default='some attr with a getter', test_attr_method = FieldAttribute(isa='string', default='some attr with a getter',
always_post_validate=True) always_post_validate=True)
_test_attr_method_missing = FieldAttribute(isa='string', default='some attr with a missing getter', test_attr_method_missing = FieldAttribute(isa='string', default='some attr with a missing getter',
always_post_validate=True) always_post_validate=True)
def _get_attr_test_attr_method(self): def _get_attr_test_attr_method(self):
@ -371,7 +371,7 @@ class BaseSubClass(base.Base):
def _validate_test_attr_example(self, attr, name, value): def _validate_test_attr_example(self, attr, name, value):
if not isinstance(value, str): if not isinstance(value, str):
raise ExampleException('_test_attr_example is not a string: %s type=%s' % (value, type(value))) raise ExampleException('test_attr_example is not a string: %s type=%s' % (value, type(value)))
def _post_validate_test_attr_example(self, attr, value, templar): def _post_validate_test_attr_example(self, attr, value, templar):
after_template_value = templar.template(value) after_template_value = templar.template(value)
@ -380,21 +380,6 @@ class BaseSubClass(base.Base):
def _post_validate_test_attr_none(self, attr, value, templar): def _post_validate_test_attr_none(self, attr, value, templar):
return None return None
def _get_parent_attribute(self, attr, extend=False, prepend=False):
value = None
try:
value = self._attributes[attr]
if self._parent and (value is None or extend):
parent_value = getattr(self._parent, attr, None)
if extend:
value = self._extend_value(value, parent_value, prepend)
else:
value = parent_value
except KeyError:
pass
return value
# terrible name, but it is a TestBase subclass for testing subclasses of Base # terrible name, but it is a TestBase subclass for testing subclasses of Base
class TestBaseSubClass(TestBase): class TestBaseSubClass(TestBase):
@ -420,7 +405,7 @@ class TestBaseSubClass(TestBase):
ds = {'test_attr_int': MOST_RANDOM_NUMBER} ds = {'test_attr_int': MOST_RANDOM_NUMBER}
bsc = self._base_validate(ds) bsc = self._base_validate(ds)
del bsc.test_attr_int del bsc.test_attr_int
self.assertNotIn('test_attr_int', bsc._attributes) self.assertNotIn('_test_attr_int', bsc.__dict__)
def test_attr_float(self): def test_attr_float(self):
roughly_pi = 4.0 roughly_pi = 4.0
@ -569,18 +554,18 @@ class TestBaseSubClass(TestBase):
string_list = ['foo', 'bar'] string_list = ['foo', 'bar']
ds = {'test_attr_list': string_list} ds = {'test_attr_list': string_list}
bsc = self._base_validate(ds) bsc = self._base_validate(ds)
self.assertEqual(string_list, bsc._attributes['test_attr_list']) self.assertEqual(string_list, bsc._test_attr_list)
def test_attr_list_none(self): def test_attr_list_none(self):
ds = {'test_attr_list': None} ds = {'test_attr_list': None}
bsc = self._base_validate(ds) bsc = self._base_validate(ds)
self.assertEqual(None, bsc._attributes['test_attr_list']) self.assertEqual(None, bsc._test_attr_list)
def test_attr_list_no_listof(self): def test_attr_list_no_listof(self):
test_list = ['foo', 'bar', 123] test_list = ['foo', 'bar', 123]
ds = {'test_attr_list_no_listof': test_list} ds = {'test_attr_list_no_listof': test_list}
bsc = self._base_validate(ds) bsc = self._base_validate(ds)
self.assertEqual(test_list, bsc._attributes['test_attr_list_no_listof']) self.assertEqual(test_list, bsc._test_attr_list_no_listof)
def test_attr_list_required(self): def test_attr_list_required(self):
string_list = ['foo', 'bar'] string_list = ['foo', 'bar']
@ -590,7 +575,7 @@ class TestBaseSubClass(TestBase):
fake_loader = DictDataLoader({}) fake_loader = DictDataLoader({})
templar = Templar(loader=fake_loader) templar = Templar(loader=fake_loader)
bsc.post_validate(templar) bsc.post_validate(templar)
self.assertEqual(string_list, bsc._attributes['test_attr_list_required']) self.assertEqual(string_list, bsc._test_attr_list_required)
def test_attr_list_required_empty_string(self): def test_attr_list_required_empty_string(self):
string_list = [""] string_list = [""]

@ -44,7 +44,7 @@ class MixinForMocks(object):
self.mock_play = MagicMock(name='MockPlay') self.mock_play = MagicMock(name='MockPlay')
self.mock_play._attributes = [] self.mock_play._attributes = []
self.mock_play.collections = None self.mock_play._collections = None
self.mock_iterator = MagicMock(name='MockIterator') self.mock_iterator = MagicMock(name='MockIterator')
self.mock_iterator._play = self.mock_play self.mock_iterator._play = self.mock_play

Loading…
Cancel
Save