`FieldAttribute`s as descriptors (#73908)

pull/76306/merge
Martin Krizek 2 years ago committed by GitHub
parent 4c9385dab7
commit 43153c5831
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -50,7 +50,7 @@ def extract_keywords(keyword_definitions):
# Maintain order of the actual class names for our output
# Build up a mapping of playbook classes to the attributes that they hold
pb_keywords[pb_class_name] = {k: v for (k, v) in playbook_class._valid_attrs.items()
pb_keywords[pb_class_name] = {k: v for (k, v) in playbook_class.fattributes.items()
# Filter private attributes as they're not usable in playbooks
if not v.private}
@ -60,7 +60,7 @@ def extract_keywords(keyword_definitions):
pb_keywords[pb_class_name][keyword] = keyword_definitions[keyword]
else:
# check if there is an alias, otherwise undocumented
alias = getattr(getattr(playbook_class, '_%s' % keyword), 'alias', None)
alias = getattr(playbook_class.fattributes.get(keyword), 'alias', None)
if alias and alias in keyword_definitions:
pb_keywords[pb_class_name][alias] = keyword_definitions[alias]
del pb_keywords[pb_class_name][keyword]

@ -590,13 +590,13 @@ class DocCLI(CLI, RoleMixin):
loaded_class = importlib.import_module(obj_class)
PB_LOADED[pobj] = getattr(loaded_class, pobj, None)
if keyword in PB_LOADED[pobj]._valid_attrs:
if keyword in PB_LOADED[pobj].fattributes:
kdata['applies_to'].append(pobj)
# we should only need these once
if 'type' not in kdata:
fa = getattr(PB_LOADED[pobj], '_%s' % keyword)
fa = PB_LOADED[pobj].fattributes.get(keyword)
if getattr(fa, 'private'):
kdata = {}
raise KeyError

@ -114,8 +114,8 @@ class ModuleArgsParser:
from ansible.playbook.task import Task
from ansible.playbook.handler import Handler
# store the valid Task/Handler attrs for quick access
self._task_attrs = set(Task._valid_attrs.keys())
self._task_attrs.update(set(Handler._valid_attrs.keys()))
self._task_attrs = set(Task.fattributes)
self._task_attrs.update(set(Handler.fattributes))
# HACK: why are these not FieldAttributes on task with a post-validate to check usage?
self._task_attrs.update(['local_action', 'static'])
self._task_attrs = frozenset(self._task_attrs)

@ -21,6 +21,7 @@ __metaclass__ = type
from copy import copy, deepcopy
from ansible.utils.sentinel import Sentinel
_CONTAINERS = frozenset(('list', 'dict', 'set'))
@ -37,10 +38,7 @@ class Attribute:
priority=0,
class_type=None,
always_post_validate=False,
inherit=True,
alias=None,
extend=False,
prepend=False,
static=False,
):
@ -70,9 +68,6 @@ class Attribute:
the field will be an instance of that class.
:kwarg always_post_validate: Controls whether a field should be post
validated or not (default: False).
:kwarg inherit: A boolean value, which controls whether the object
containing this field should attempt to inherit the value from its
parent object if the local value is None.
:kwarg alias: An alias to use for the attribute name, for situations where
the attribute name may conflict with a Python reserved word.
"""
@ -85,15 +80,15 @@ class Attribute:
self.priority = priority
self.class_type = class_type
self.always_post_validate = always_post_validate
self.inherit = inherit
self.alias = alias
self.extend = extend
self.prepend = prepend
self.static = static
if default is not None and self.isa in _CONTAINERS and not callable(default):
raise TypeError('defaults for FieldAttribute may not be mutable, please provide a callable instead')
def __set_name__(self, owner, name):
self.name = name
def __eq__(self, other):
return other.priority == self.priority
@ -114,6 +109,94 @@ class Attribute:
def __ge__(self, other):
return other.priority >= self.priority
def __get__(self, obj, obj_type=None):
method = f'_get_attr_{self.name}'
if hasattr(obj, method):
# NOTE this appears to be not used in the codebase,
# _get_attr_connection has been replaced by ConnectionFieldAttribute.
# Leaving it here for test_attr_method from
# test/units/playbook/test_base.py to pass and for backwards compat.
if getattr(obj, '_squashed', False):
value = getattr(obj, f'_{self.name}', Sentinel)
else:
value = getattr(obj, method)()
else:
value = getattr(obj, f'_{self.name}', Sentinel)
if value is Sentinel:
value = self.default
if callable(value):
value = value()
setattr(obj, f'_{self.name}', value)
return value
def __set__(self, obj, value):
setattr(obj, f'_{self.name}', value)
if self.alias is not None:
setattr(obj, f'_{self.alias}', value)
# NOTE this appears to be not needed in the codebase,
# leaving it here for test_attr_int_del from
# test/units/playbook/test_base.py to pass.
def __delete__(self, obj):
delattr(obj, f'_{self.name}')
class NonInheritableFieldAttribute(Attribute):
...
class FieldAttribute(Attribute):
pass
def __init__(self, extend=False, prepend=False, **kwargs):
super().__init__(**kwargs)
self.extend = extend
self.prepend = prepend
def __get__(self, obj, obj_type=None):
if getattr(obj, '_squashed', False) or getattr(obj, '_finalized', False):
value = getattr(obj, f'_{self.name}', Sentinel)
else:
try:
value = obj._get_parent_attribute(self.name)
except AttributeError:
method = f'_get_attr_{self.name}'
if hasattr(obj, method):
# NOTE this appears to be not needed in the codebase,
# _get_attr_connection has been replaced by ConnectionFieldAttribute.
# Leaving it here for test_attr_method from
# test/units/playbook/test_base.py to pass and for backwards compat.
if getattr(obj, '_squashed', False):
value = getattr(obj, f'_{self.name}', Sentinel)
else:
value = getattr(obj, method)()
else:
value = getattr(obj, f'_{self.name}', Sentinel)
if value is Sentinel:
value = self.default
if callable(value):
value = value()
setattr(obj, f'_{self.name}', value)
return value
class ConnectionFieldAttribute(FieldAttribute):
def __get__(self, obj, obj_type=None):
from ansible.module_utils.compat.paramiko import paramiko
from ansible.utils.ssh_functions import check_for_controlpersist
value = super().__get__(obj, obj_type)
if value == 'smart':
value = 'ssh'
# see if SSH can support ControlPersist if not use paramiko
if not check_for_controlpersist('ssh') and paramiko is not None:
value = "paramiko"
# if someone did `connection: persistent`, default it to using a persistent paramiko connection to avoid problems
elif value == 'persistent' and paramiko is not None:
value = 'paramiko'
return value

@ -10,7 +10,6 @@ import operator
import os
from copy import copy as shallowcopy
from functools import partial
from jinja2.exceptions import UndefinedError
@ -21,7 +20,7 @@ from ansible.module_utils.six import string_types
from ansible.module_utils.parsing.convert_bool import boolean
from ansible.module_utils._text import to_text, to_native
from ansible.parsing.dataloader import DataLoader
from ansible.playbook.attribute import Attribute, FieldAttribute
from ansible.playbook.attribute import Attribute, FieldAttribute, ConnectionFieldAttribute, NonInheritableFieldAttribute
from ansible.plugins.loader import module_loader, action_loader
from ansible.utils.collection_loader._collection_finder import _get_collection_metadata, AnsibleCollectionRef
from ansible.utils.display import Display
@ -31,54 +30,6 @@ from ansible.utils.vars import combine_vars, isidentifier, get_unique_id
display = Display()
def _generic_g(prop_name, self):
try:
value = self._attributes[prop_name]
except KeyError:
raise AttributeError("'%s' does not have the keyword '%s'" % (self.__class__.__name__, prop_name))
if value is Sentinel:
value = self._attr_defaults[prop_name]
return value
def _generic_g_method(prop_name, self):
try:
if self._squashed:
return self._attributes[prop_name]
method = "_get_attr_%s" % prop_name
return getattr(self, method)()
except KeyError:
raise AttributeError("'%s' does not support the keyword '%s'" % (self.__class__.__name__, prop_name))
def _generic_g_parent(prop_name, self):
try:
if self._squashed or self._finalized:
value = self._attributes[prop_name]
else:
try:
value = self._get_parent_attribute(prop_name)
except AttributeError:
value = self._attributes[prop_name]
except KeyError:
raise AttributeError("'%s' nor it's parents support the keyword '%s'" % (self.__class__.__name__, prop_name))
if value is Sentinel:
value = self._attr_defaults[prop_name]
return value
def _generic_s(prop_name, self, value):
self._attributes[prop_name] = value
def _generic_d(prop_name, self):
del self._attributes[prop_name]
def _validate_action_group_metadata(action, found_group_metadata, fq_group_name):
valid_metadata = {
'extend_group': {
@ -118,83 +69,30 @@ def _validate_action_group_metadata(action, found_group_metadata, fq_group_name)
display.warning(" ".join(metadata_warnings))
class BaseMeta(type):
"""
Metaclass for the Base object, which is used to construct the class
attributes based on the FieldAttributes available.
"""
def __new__(cls, name, parents, dct):
def _create_attrs(src_dict, dst_dict):
'''
Helper method which creates the attributes based on those in the
source dictionary of attributes. This also populates the other
attributes used to keep track of these attributes and via the
getter/setter/deleter methods.
'''
keys = list(src_dict.keys())
for attr_name in keys:
value = src_dict[attr_name]
if isinstance(value, Attribute):
if attr_name.startswith('_'):
attr_name = attr_name[1:]
# here we selectively assign the getter based on a few
# things, such as whether we have a _get_attr_<name>
# method, or if the attribute is marked as not inheriting
# its value from a parent object
method = "_get_attr_%s" % attr_name
try:
if method in src_dict or method in dst_dict:
getter = partial(_generic_g_method, attr_name)
elif ('_get_parent_attribute' in dst_dict or '_get_parent_attribute' in src_dict) and value.inherit:
getter = partial(_generic_g_parent, attr_name)
else:
getter = partial(_generic_g, attr_name)
except AttributeError as e:
raise AnsibleParserError("Invalid playbook definition: %s" % to_native(e), orig_exc=e)
setter = partial(_generic_s, attr_name)
deleter = partial(_generic_d, attr_name)
dst_dict[attr_name] = property(getter, setter, deleter)
dst_dict['_valid_attrs'][attr_name] = value
dst_dict['_attributes'][attr_name] = Sentinel
dst_dict['_attr_defaults'][attr_name] = value.default
if value.alias is not None:
dst_dict[value.alias] = property(getter, setter, deleter)
dst_dict['_valid_attrs'][value.alias] = value
dst_dict['_alias_attrs'][value.alias] = attr_name
def _process_parents(parents, dst_dict):
'''
Helper method which creates attributes from all parent objects
recursively on through grandparent objects
'''
for parent in parents:
if hasattr(parent, '__dict__'):
_create_attrs(parent.__dict__, dst_dict)
new_dst_dict = parent.__dict__.copy()
new_dst_dict.update(dst_dict)
_process_parents(parent.__bases__, new_dst_dict)
# create some additional class attributes
dct['_attributes'] = {}
dct['_attr_defaults'] = {}
dct['_valid_attrs'] = {}
dct['_alias_attrs'] = {}
# now create the attributes based on the FieldAttributes
# available, including from parent (and grandparent) objects
_create_attrs(dct, dct)
_process_parents(parents, dct)
return super(BaseMeta, cls).__new__(cls, name, parents, dct)
class FieldAttributeBase(metaclass=BaseMeta):
# FIXME use @property and @classmethod together which is possible since Python 3.9
class _FABMeta(type):
@property
def fattributes(cls):
# FIXME is this worth caching?
fattributes = {}
for class_obj in reversed(cls.__mro__):
for name, attr in list(class_obj.__dict__.items()):
if not isinstance(attr, Attribute):
continue
fattributes[name] = attr
if attr.alias:
setattr(class_obj, attr.alias, attr)
fattributes[attr.alias] = attr
return fattributes
class FieldAttributeBase(metaclass=_FABMeta):
# FIXME use @property and @classmethod together which is possible since Python 3.9
@property
def fattributes(self):
return self.__class__.fattributes
def __init__(self):
@ -211,17 +109,7 @@ class FieldAttributeBase(metaclass=BaseMeta):
# every object gets a random uuid:
self._uuid = get_unique_id()
# we create a copy of the attributes here due to the fact that
# it was initialized as a class param in the meta class, so we
# need a unique object here (all members contained within are
# unique already).
self._attributes = self.__class__._attributes.copy()
self._attr_defaults = self.__class__._attr_defaults.copy()
for key, value in self._attr_defaults.items():
if callable(value):
self._attr_defaults[key] = value()
# and init vars, avoid using defaults in field declaration as it lives across plays
# init vars, avoid using defaults in field declaration as it lives across plays
self.vars = dict()
@property
@ -273,17 +161,14 @@ class FieldAttributeBase(metaclass=BaseMeta):
# Walk all attributes in the class. We sort them based on their priority
# so that certain fields can be loaded before others, if they are dependent.
for name, attr in sorted(self._valid_attrs.items(), key=operator.itemgetter(1)):
for name, attr in sorted(self.fattributes.items(), key=operator.itemgetter(1)):
# copy the value over unless a _load_field method is defined
target_name = name
if name in self._alias_attrs:
target_name = self._alias_attrs[name]
if name in ds:
method = getattr(self, '_load_%s' % name, None)
if method:
self._attributes[target_name] = method(name, ds[name])
setattr(self, name, method(name, ds[name]))
else:
self._attributes[target_name] = ds[name]
setattr(self, name, ds[name])
# run early, non-critical validation
self.validate()
@ -316,7 +201,7 @@ class FieldAttributeBase(metaclass=BaseMeta):
not map to attributes for this object.
'''
valid_attrs = frozenset(self._valid_attrs.keys())
valid_attrs = frozenset(self.fattributes)
for key in ds:
if key not in valid_attrs:
raise AnsibleParserError("'%s' is not a valid attribute for a %s" % (key, self.__class__.__name__), obj=ds)
@ -327,18 +212,14 @@ class FieldAttributeBase(metaclass=BaseMeta):
if not self._validated:
# walk all fields in the object
for (name, attribute) in self._valid_attrs.items():
if name in self._alias_attrs:
name = self._alias_attrs[name]
for (name, attribute) in self.fattributes.items():
# run validator only if present
method = getattr(self, '_validate_%s' % name, None)
if method:
method(attribute, name, getattr(self, name))
else:
# and make sure the attribute is of the type it should be
value = self._attributes[name]
value = getattr(self, name)
if value is not None:
if attribute.isa == 'string' and isinstance(value, (list, dict)):
raise AnsibleParserError(
@ -528,8 +409,8 @@ class FieldAttributeBase(metaclass=BaseMeta):
parent attributes.
'''
if not self._squashed:
for name in self._valid_attrs.keys():
self._attributes[name] = getattr(self, name)
for name in self.fattributes:
setattr(self, name, getattr(self, name))
self._squashed = True
def copy(self):
@ -542,11 +423,8 @@ class FieldAttributeBase(metaclass=BaseMeta):
except RuntimeError as e:
raise AnsibleError("Exceeded maximum object depth. This may have been caused by excessive role recursion", orig_exc=e)
for name in self._valid_attrs.keys():
if name in self._alias_attrs:
continue
new_me._attributes[name] = shallowcopy(self._attributes[name])
new_me._attr_defaults[name] = shallowcopy(self._attr_defaults[name])
for name in self.fattributes:
setattr(new_me, name, shallowcopy(getattr(self, f'_{name}', Sentinel)))
new_me._loader = self._loader
new_me._variable_manager = self._variable_manager
@ -621,8 +499,7 @@ class FieldAttributeBase(metaclass=BaseMeta):
# save the omit value for later checking
omit_value = templar.available_variables.get('omit')
for (name, attribute) in self._valid_attrs.items():
for (name, attribute) in self.fattributes.items():
if attribute.static:
value = getattr(self, name)
@ -748,7 +625,7 @@ class FieldAttributeBase(metaclass=BaseMeta):
Dumps all attributes to a dictionary
'''
attrs = {}
for (name, attribute) in self._valid_attrs.items():
for (name, attribute) in self.fattributes.items():
attr = getattr(self, name)
if attribute.isa == 'class' and hasattr(attr, 'serialize'):
attrs[name] = attr.serialize()
@ -761,8 +638,8 @@ class FieldAttributeBase(metaclass=BaseMeta):
Loads attributes from a dictionary
'''
for (attr, value) in attrs.items():
if attr in self._valid_attrs:
attribute = self._valid_attrs[attr]
if attr in self.fattributes:
attribute = self.fattributes[attr]
if attribute.isa == 'class' and isinstance(value, dict):
obj = attribute.class_type()
obj.deserialize(value)
@ -806,14 +683,11 @@ class FieldAttributeBase(metaclass=BaseMeta):
if not isinstance(data, dict):
raise AnsibleAssertionError('data (%s) should be a dict but is a %s' % (data, type(data)))
for (name, attribute) in self._valid_attrs.items():
for (name, attribute) in self.fattributes.items():
if name in data:
setattr(self, name, data[name])
else:
if callable(attribute.default):
setattr(self, name, attribute.default())
else:
setattr(self, name, attribute.default)
setattr(self, name, attribute.default)
# restore the UUID field
setattr(self, '_uuid', data.get('uuid'))
@ -823,40 +697,40 @@ class FieldAttributeBase(metaclass=BaseMeta):
class Base(FieldAttributeBase):
_name = FieldAttribute(isa='string', default='', always_post_validate=True, inherit=False)
name = NonInheritableFieldAttribute(isa='string', default='', always_post_validate=True)
# connection/transport
_connection = FieldAttribute(isa='string', default=context.cliargs_deferred_get('connection'))
_port = FieldAttribute(isa='int')
_remote_user = FieldAttribute(isa='string', default=context.cliargs_deferred_get('remote_user'))
connection = ConnectionFieldAttribute(isa='string', default=context.cliargs_deferred_get('connection'))
port = FieldAttribute(isa='int')
remote_user = FieldAttribute(isa='string', default=context.cliargs_deferred_get('remote_user'))
# variables
_vars = FieldAttribute(isa='dict', priority=100, inherit=False, static=True)
vars = NonInheritableFieldAttribute(isa='dict', priority=100, static=True)
# module default params
_module_defaults = FieldAttribute(isa='list', extend=True, prepend=True)
module_defaults = FieldAttribute(isa='list', extend=True, prepend=True)
# flags and misc. settings
_environment = FieldAttribute(isa='list', extend=True, prepend=True)
_no_log = FieldAttribute(isa='bool')
_run_once = FieldAttribute(isa='bool')
_ignore_errors = FieldAttribute(isa='bool')
_ignore_unreachable = FieldAttribute(isa='bool')
_check_mode = FieldAttribute(isa='bool', default=context.cliargs_deferred_get('check'))
_diff = FieldAttribute(isa='bool', default=context.cliargs_deferred_get('diff'))
_any_errors_fatal = FieldAttribute(isa='bool', default=C.ANY_ERRORS_FATAL)
_throttle = FieldAttribute(isa='int', default=0)
_timeout = FieldAttribute(isa='int', default=C.TASK_TIMEOUT)
environment = FieldAttribute(isa='list', extend=True, prepend=True)
no_log = FieldAttribute(isa='bool')
run_once = FieldAttribute(isa='bool')
ignore_errors = FieldAttribute(isa='bool')
ignore_unreachable = FieldAttribute(isa='bool')
check_mode = FieldAttribute(isa='bool', default=context.cliargs_deferred_get('check'))
diff = FieldAttribute(isa='bool', default=context.cliargs_deferred_get('diff'))
any_errors_fatal = FieldAttribute(isa='bool', default=C.ANY_ERRORS_FATAL)
throttle = FieldAttribute(isa='int', default=0)
timeout = FieldAttribute(isa='int', default=C.TASK_TIMEOUT)
# explicitly invoke a debugger on tasks
_debugger = FieldAttribute(isa='string')
debugger = FieldAttribute(isa='string')
# Privilege escalation
_become = FieldAttribute(isa='bool', default=context.cliargs_deferred_get('become'))
_become_method = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_method'))
_become_user = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_user'))
_become_flags = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_flags'))
_become_exe = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_exe'))
become = FieldAttribute(isa='bool', default=context.cliargs_deferred_get('become'))
become_method = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_method'))
become_user = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_user'))
become_flags = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_flags'))
become_exe = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_exe'))
# used to hold sudo/su stuff
DEPRECATED_ATTRIBUTES = [] # type: list[str]

@ -21,7 +21,7 @@ __metaclass__ = type
import ansible.constants as C
from ansible.errors import AnsibleParserError
from ansible.playbook.attribute import FieldAttribute
from ansible.playbook.attribute import FieldAttribute, NonInheritableFieldAttribute
from ansible.playbook.base import Base
from ansible.playbook.conditional import Conditional
from ansible.playbook.collectionsearch import CollectionSearch
@ -34,18 +34,18 @@ from ansible.utils.sentinel import Sentinel
class Block(Base, Conditional, CollectionSearch, Taggable):
# main block fields containing the task lists
_block = FieldAttribute(isa='list', default=list, inherit=False)
_rescue = FieldAttribute(isa='list', default=list, inherit=False)
_always = FieldAttribute(isa='list', default=list, inherit=False)
block = NonInheritableFieldAttribute(isa='list', default=list)
rescue = NonInheritableFieldAttribute(isa='list', default=list)
always = NonInheritableFieldAttribute(isa='list', default=list)
# other fields for task compat
_notify = FieldAttribute(isa='list')
_delegate_to = FieldAttribute(isa='string')
_delegate_facts = FieldAttribute(isa='bool')
notify = FieldAttribute(isa='list')
delegate_to = FieldAttribute(isa='string')
delegate_facts = FieldAttribute(isa='bool')
# for future consideration? this would be functionally
# similar to the 'else' clause for exceptions
# _otherwise = FieldAttribute(isa='list')
# otherwise = FieldAttribute(isa='list')
def __init__(self, play=None, parent_block=None, role=None, task_include=None, use_handlers=False, implicit=False):
self._play = play
@ -230,7 +230,7 @@ class Block(Base, Conditional, CollectionSearch, Taggable):
'''
data = dict()
for attr in self._valid_attrs:
for attr in self.fattributes:
if attr not in ('block', 'rescue', 'always'):
data[attr] = getattr(self, attr)
@ -256,7 +256,7 @@ class Block(Base, Conditional, CollectionSearch, Taggable):
# we don't want the full set of attributes (the task lists), as that
# would lead to a serialize/deserialize loop
for attr in self._valid_attrs:
for attr in self.fattributes:
if attr in data and attr not in ('block', 'rescue', 'always'):
setattr(self, attr, data.get(attr))
@ -298,11 +298,10 @@ class Block(Base, Conditional, CollectionSearch, Taggable):
'''
Generic logic to get the attribute or parent attribute for a block value.
'''
extend = self._valid_attrs[attr].extend
prepend = self._valid_attrs[attr].prepend
extend = self.fattributes.get(attr).extend
prepend = self.fattributes.get(attr).prepend
try:
value = self._attributes[attr]
value = getattr(self, f'_{attr}', Sentinel)
# If parent is static, we can grab attrs from the parent
# otherwise, defer to the grandparent
if getattr(self._parent, 'statically_loaded', True):
@ -316,7 +315,7 @@ class Block(Base, Conditional, CollectionSearch, Taggable):
if hasattr(_parent, '_get_parent_attribute'):
parent_value = _parent._get_parent_attribute(attr)
else:
parent_value = _parent._attributes.get(attr, Sentinel)
parent_value = getattr(_parent, f'_{attr}', Sentinel)
if extend:
value = self._extend_value(value, parent_value, prepend)
else:
@ -325,7 +324,7 @@ class Block(Base, Conditional, CollectionSearch, Taggable):
pass
if self._role and (value is Sentinel or extend):
try:
parent_value = self._role._attributes.get(attr, Sentinel)
parent_value = getattr(self._role, f'_{attr}', Sentinel)
if extend:
value = self._extend_value(value, parent_value, prepend)
else:
@ -335,7 +334,7 @@ class Block(Base, Conditional, CollectionSearch, Taggable):
if dep_chain and (value is Sentinel or extend):
dep_chain.reverse()
for dep in dep_chain:
dep_value = dep._attributes.get(attr, Sentinel)
dep_value = getattr(dep, f'_{attr}', Sentinel)
if extend:
value = self._extend_value(value, dep_value, prepend)
else:
@ -347,7 +346,7 @@ class Block(Base, Conditional, CollectionSearch, Taggable):
pass
if self._play and (value is Sentinel or extend):
try:
play_value = self._play._attributes.get(attr, Sentinel)
play_value = getattr(self._play, f'_{attr}', Sentinel)
if play_value is not Sentinel:
if extend:
value = self._extend_value(value, play_value, prepend)

@ -36,13 +36,13 @@ def _ensure_default_collection(collection_list=None):
class CollectionSearch:
# this needs to be populated before we can resolve tasks/roles/etc
_collections = FieldAttribute(isa='list', listof=string_types, priority=100, default=_ensure_default_collection,
always_post_validate=True, static=True)
collections = FieldAttribute(isa='list', listof=string_types, priority=100, default=_ensure_default_collection,
always_post_validate=True, static=True)
def _load_collections(self, attr, ds):
# We are always a mixin with Base, so we can validate this untemplated
# field early on to guarantee we are dealing with a list.
ds = self.get_validated_value('collections', self._collections, ds, None)
ds = self.get_validated_value('collections', self.fattributes.get('collections'), ds, None)
# this will only be called if someone specified a value; call the shared value
_ensure_default_collection(collection_list=ds)

@ -46,7 +46,7 @@ class Conditional:
to be run conditionally when a condition is met or skipped.
'''
_when = FieldAttribute(isa='list', default=list, extend=True, prepend=True)
when = FieldAttribute(isa='list', default=list, extend=True, prepend=True)
def __init__(self, loader=None):
# when used directly, this class needs a loader, but we want to

@ -26,7 +26,7 @@ from ansible.module_utils.six import string_types
class Handler(Task):
_listen = FieldAttribute(isa='list', default=list, listof=string_types, static=True)
listen = FieldAttribute(isa='list', default=list, listof=string_types, static=True)
def __init__(self, block=None, role=None, task_include=None):
self.notified_hosts = []

@ -25,12 +25,12 @@ from ansible.playbook.base import FieldAttributeBase
class LoopControl(FieldAttributeBase):
_loop_var = FieldAttribute(isa='str', default='item')
_index_var = FieldAttribute(isa='str')
_label = FieldAttribute(isa='str')
_pause = FieldAttribute(isa='float', default=0)
_extended = FieldAttribute(isa='bool')
_extended_allitems = FieldAttribute(isa='bool', default=True)
loop_var = FieldAttribute(isa='str', default='item')
index_var = FieldAttribute(isa='str')
label = FieldAttribute(isa='str')
pause = FieldAttribute(isa='float', default=0)
extended = FieldAttribute(isa='bool')
extended_allitems = FieldAttribute(isa='bool', default=True)
def __init__(self):
super(LoopControl, self).__init__()

@ -54,35 +54,35 @@ class Play(Base, Taggable, CollectionSearch):
"""
# =================================================================================
_hosts = FieldAttribute(isa='list', required=True, listof=string_types, always_post_validate=True, priority=-1)
hosts = FieldAttribute(isa='list', required=True, listof=string_types, always_post_validate=True, priority=-2)
# Facts
_gather_facts = FieldAttribute(isa='bool', default=None, always_post_validate=True)
gather_facts = FieldAttribute(isa='bool', default=None, always_post_validate=True)
# defaults to be deprecated, should be 'None' in future
_gather_subset = FieldAttribute(isa='list', default=(lambda: C.DEFAULT_GATHER_SUBSET), listof=string_types, always_post_validate=True)
_gather_timeout = FieldAttribute(isa='int', default=C.DEFAULT_GATHER_TIMEOUT, always_post_validate=True)
_fact_path = FieldAttribute(isa='string', default=C.DEFAULT_FACT_PATH)
gather_subset = FieldAttribute(isa='list', default=(lambda: C.DEFAULT_GATHER_SUBSET), listof=string_types, always_post_validate=True)
gather_timeout = FieldAttribute(isa='int', default=C.DEFAULT_GATHER_TIMEOUT, always_post_validate=True)
fact_path = FieldAttribute(isa='string', default=C.DEFAULT_FACT_PATH)
# Variable Attributes
_vars_files = FieldAttribute(isa='list', default=list, priority=99)
_vars_prompt = FieldAttribute(isa='list', default=list, always_post_validate=False)
vars_files = FieldAttribute(isa='list', default=list, priority=99)
vars_prompt = FieldAttribute(isa='list', default=list, always_post_validate=False)
# Role Attributes
_roles = FieldAttribute(isa='list', default=list, priority=90)
roles = FieldAttribute(isa='list', default=list, priority=90)
# Block (Task) Lists Attributes
_handlers = FieldAttribute(isa='list', default=list)
_pre_tasks = FieldAttribute(isa='list', default=list)
_post_tasks = FieldAttribute(isa='list', default=list)
_tasks = FieldAttribute(isa='list', default=list)
handlers = FieldAttribute(isa='list', default=list, priority=-1)
pre_tasks = FieldAttribute(isa='list', default=list, priority=-1)
post_tasks = FieldAttribute(isa='list', default=list, priority=-1)
tasks = FieldAttribute(isa='list', default=list, priority=-1)
# Flag/Setting Attributes
_force_handlers = FieldAttribute(isa='bool', default=context.cliargs_deferred_get('force_handlers'), always_post_validate=True)
_max_fail_percentage = FieldAttribute(isa='percent', always_post_validate=True)
_serial = FieldAttribute(isa='list', default=list, always_post_validate=True)
_strategy = FieldAttribute(isa='string', default=C.DEFAULT_STRATEGY, always_post_validate=True)
_order = FieldAttribute(isa='string', always_post_validate=True)
force_handlers = FieldAttribute(isa='bool', default=context.cliargs_deferred_get('force_handlers'), always_post_validate=True)
max_fail_percentage = FieldAttribute(isa='percent', always_post_validate=True)
serial = FieldAttribute(isa='list', default=list, always_post_validate=True)
strategy = FieldAttribute(isa='string', default=C.DEFAULT_STRATEGY, always_post_validate=True)
order = FieldAttribute(isa='string', always_post_validate=True)
# =================================================================================

@ -77,46 +77,46 @@ class PlayContext(Base):
'''
# base
_module_compression = FieldAttribute(isa='string', default=C.DEFAULT_MODULE_COMPRESSION)
_shell = FieldAttribute(isa='string')
_executable = FieldAttribute(isa='string', default=C.DEFAULT_EXECUTABLE)
module_compression = FieldAttribute(isa='string', default=C.DEFAULT_MODULE_COMPRESSION)
shell = FieldAttribute(isa='string')
executable = FieldAttribute(isa='string', default=C.DEFAULT_EXECUTABLE)
# connection fields, some are inherited from Base:
# (connection, port, remote_user, environment, no_log)
_remote_addr = FieldAttribute(isa='string')
_password = FieldAttribute(isa='string')
_timeout = FieldAttribute(isa='int', default=C.DEFAULT_TIMEOUT)
_connection_user = FieldAttribute(isa='string')
_private_key_file = FieldAttribute(isa='string', default=C.DEFAULT_PRIVATE_KEY_FILE)
_pipelining = FieldAttribute(isa='bool', default=C.ANSIBLE_PIPELINING)
remote_addr = FieldAttribute(isa='string')
password = FieldAttribute(isa='string')
timeout = FieldAttribute(isa='int', default=C.DEFAULT_TIMEOUT)
connection_user = FieldAttribute(isa='string')
private_key_file = FieldAttribute(isa='string', default=C.DEFAULT_PRIVATE_KEY_FILE)
pipelining = FieldAttribute(isa='bool', default=C.ANSIBLE_PIPELINING)
# networking modules
_network_os = FieldAttribute(isa='string')
network_os = FieldAttribute(isa='string')
# docker FIXME: remove these
_docker_extra_args = FieldAttribute(isa='string')
docker_extra_args = FieldAttribute(isa='string')
# ???
_connection_lockfd = FieldAttribute(isa='int')
connection_lockfd = FieldAttribute(isa='int')
# privilege escalation fields
_become = FieldAttribute(isa='bool')
_become_method = FieldAttribute(isa='string')
_become_user = FieldAttribute(isa='string')
_become_pass = FieldAttribute(isa='string')
_become_exe = FieldAttribute(isa='string', default=C.DEFAULT_BECOME_EXE)
_become_flags = FieldAttribute(isa='string', default=C.DEFAULT_BECOME_FLAGS)
_prompt = FieldAttribute(isa='string')
become = FieldAttribute(isa='bool')
become_method = FieldAttribute(isa='string')
become_user = FieldAttribute(isa='string')
become_pass = FieldAttribute(isa='string')
become_exe = FieldAttribute(isa='string', default=C.DEFAULT_BECOME_EXE)
become_flags = FieldAttribute(isa='string', default=C.DEFAULT_BECOME_FLAGS)
prompt = FieldAttribute(isa='string')
# general flags
_only_tags = FieldAttribute(isa='set', default=set)
_skip_tags = FieldAttribute(isa='set', default=set)
only_tags = FieldAttribute(isa='set', default=set)
skip_tags = FieldAttribute(isa='set', default=set)
_start_at_task = FieldAttribute(isa='string')
_step = FieldAttribute(isa='bool', default=False)
start_at_task = FieldAttribute(isa='string')
step = FieldAttribute(isa='bool', default=False)
# "PlayContext.force_handlers should not be used, the calling code should be using play itself instead"
_force_handlers = FieldAttribute(isa='bool', default=False)
force_handlers = FieldAttribute(isa='bool', default=False)
@property
def verbosity(self):
@ -353,21 +353,3 @@ class PlayContext(Base):
variables[var_opt] = var_val
except AttributeError:
continue
def _get_attr_connection(self):
''' connections are special, this takes care of responding correctly '''
conn_type = None
if self._attributes['connection'] == 'smart':
conn_type = 'ssh'
# see if SSH can support ControlPersist if not use paramiko
if not check_for_controlpersist('ssh') and paramiko is not None:
conn_type = "paramiko"
# if someone did `connection: persistent`, default it to using a persistent paramiko connection to avoid problems
elif self._attributes['connection'] == 'persistent' and paramiko is not None:
conn_type = 'paramiko'
if conn_type:
self.connection = conn_type
return self._attributes['connection']

@ -41,8 +41,8 @@ display = Display()
class PlaybookInclude(Base, Conditional, Taggable):
_import_playbook = FieldAttribute(isa='string')
_vars = FieldAttribute(isa='dict', default=dict)
import_playbook = FieldAttribute(isa='string')
vars_val = FieldAttribute(isa='dict', default=dict, alias='vars')
@staticmethod
def load(data, basedir, variable_manager=None, loader=None):
@ -120,7 +120,7 @@ class PlaybookInclude(Base, Conditional, Taggable):
# those attached to each block (if any)
if new_obj.when:
for task_block in (entry.pre_tasks + entry.roles + entry.tasks + entry.post_tasks):
task_block._attributes['when'] = new_obj.when[:] + task_block.when[:]
task_block._when = new_obj.when[:] + task_block.when[:]
return pb

@ -37,6 +37,7 @@ from ansible.playbook.taggable import Taggable
from ansible.plugins.loader import add_all_plugin_dirs
from ansible.utils.collection_loader import AnsibleCollectionConfig
from ansible.utils.path import is_subpath
from ansible.utils.sentinel import Sentinel
from ansible.utils.vars import combine_vars
__all__ = ['Role', 'hash_params']
@ -97,8 +98,8 @@ def hash_params(params):
class Role(Base, Conditional, Taggable, CollectionSearch):
_delegate_to = FieldAttribute(isa='string')
_delegate_facts = FieldAttribute(isa='bool')
delegate_to = FieldAttribute(isa='string')
delegate_facts = FieldAttribute(isa='bool')
def __init__(self, play=None, from_files=None, from_include=False, validate=True):
self._role_name = None
@ -198,15 +199,19 @@ class Role(Base, Conditional, Taggable, CollectionSearch):
self.add_parent(parent_role)
# copy over all field attributes from the RoleInclude
# update self._attributes directly, to avoid squashing
for (attr_name, dump) in self._valid_attrs.items():
# update self._attr directly, to avoid squashing
for attr_name in self.fattributes:
if attr_name in ('when', 'tags'):
self._attributes[attr_name] = self._extend_value(
self._attributes[attr_name],
role_include._attributes[attr_name],
setattr(
self,
f'_{attr_name}',
self._extend_value(
getattr(self, f'_{attr_name}', Sentinel),
getattr(role_include, f'_{attr_name}', Sentinel),
)
)
else:
self._attributes[attr_name] = role_include._attributes[attr_name]
setattr(self, f'_{attr_name}', getattr(role_include, f'_{attr_name}', Sentinel))
# vars and default vars are regular dictionaries
self._role_vars = self._load_role_yaml('vars', main=self._from_files.get('vars'), allow_dir=True)

@ -43,7 +43,7 @@ display = Display()
class RoleDefinition(Base, Conditional, Taggable, CollectionSearch):
_role = FieldAttribute(isa='string')
role = FieldAttribute(isa='string')
def __init__(self, play=None, role_basedir=None, variable_manager=None, loader=None, collection_list=None):
@ -210,7 +210,7 @@ class RoleDefinition(Base, Conditional, Taggable, CollectionSearch):
role_def = dict()
role_params = dict()
base_attribute_names = frozenset(self._valid_attrs.keys())
base_attribute_names = frozenset(self.fattributes)
for (key, value) in ds.items():
# use the list of FieldAttribute values to determine what is and is not
# an extra parameter for this role (or sub-class of this role)

@ -37,8 +37,8 @@ class RoleInclude(RoleDefinition):
is included for execution in a play.
"""
_delegate_to = FieldAttribute(isa='string')
_delegate_facts = FieldAttribute(isa='bool', default=False)
delegate_to = FieldAttribute(isa='string')
delegate_facts = FieldAttribute(isa='bool', default=False)
def __init__(self, play=None, role_basedir=None, variable_manager=None, loader=None, collection_list=None):
super(RoleInclude, self).__init__(play=play, role_basedir=role_basedir, variable_manager=variable_manager,

@ -39,10 +39,10 @@ class RoleMetadata(Base, CollectionSearch):
within each Role (meta/main.yml).
'''
_allow_duplicates = FieldAttribute(isa='bool', default=False)
_dependencies = FieldAttribute(isa='list', default=list)
_galaxy_info = FieldAttribute(isa='GalaxyInfo')
_argument_specs = FieldAttribute(isa='dict', default=dict)
allow_duplicates = FieldAttribute(isa='bool', default=False)
dependencies = FieldAttribute(isa='list', default=list)
galaxy_info = FieldAttribute(isa='GalaxyInfo')
argument_specs = FieldAttribute(isa='dict', default=dict)
def __init__(self, owner=None):
self._owner = owner

@ -52,9 +52,9 @@ class IncludeRole(TaskInclude):
# ATTRIBUTES
# private as this is a 'module options' vs a task property
_allow_duplicates = FieldAttribute(isa='bool', default=True, private=True)
_public = FieldAttribute(isa='bool', default=False, private=True)
_rolespec_validate = FieldAttribute(isa='bool', default=True)
allow_duplicates = FieldAttribute(isa='bool', default=True, private=True)
public = FieldAttribute(isa='bool', default=False, private=True)
rolespec_validate = FieldAttribute(isa='bool', default=True)
def __init__(self, block=None, role=None, task_include=None):

@ -28,7 +28,7 @@ from ansible.template import Templar
class Taggable:
untagged = frozenset(['untagged'])
_tags = FieldAttribute(isa='list', default=list, listof=(string_types, int), extend=True)
tags = FieldAttribute(isa='list', default=list, listof=(string_types, int), extend=True)
def _load_tags(self, attr, ds):
if isinstance(ds, list):

@ -26,7 +26,7 @@ from ansible.module_utils.six import string_types
from ansible.parsing.mod_args import ModuleArgsParser
from ansible.parsing.yaml.objects import AnsibleBaseYAMLObject, AnsibleMapping
from ansible.plugins.loader import lookup_loader
from ansible.playbook.attribute import FieldAttribute
from ansible.playbook.attribute import FieldAttribute, NonInheritableFieldAttribute
from ansible.playbook.base import Base
from ansible.playbook.block import Block
from ansible.playbook.collectionsearch import CollectionSearch
@ -63,28 +63,28 @@ class Task(Base, Conditional, Taggable, CollectionSearch):
# might be possible to define others
# NOTE: ONLY set defaults on task attributes that are not inheritable,
# inheritance is only triggered if the 'current value' is None,
# inheritance is only triggered if the 'current value' is Sentinel,
# default can be set at play/top level object and inheritance will take it's course.
_args = FieldAttribute(isa='dict', default=dict)
_action = FieldAttribute(isa='string')
_async_val = FieldAttribute(isa='int', default=0, alias='async')
_changed_when = FieldAttribute(isa='list', default=list)
_delay = FieldAttribute(isa='int', default=5)
_delegate_to = FieldAttribute(isa='string')
_delegate_facts = FieldAttribute(isa='bool')
_failed_when = FieldAttribute(isa='list', default=list)
_loop = FieldAttribute()
_loop_control = FieldAttribute(isa='class', class_type=LoopControl, inherit=False)
_notify = FieldAttribute(isa='list')
_poll = FieldAttribute(isa='int', default=C.DEFAULT_POLL_INTERVAL)
_register = FieldAttribute(isa='string', static=True)
_retries = FieldAttribute(isa='int', default=3)
_until = FieldAttribute(isa='list', default=list)
args = FieldAttribute(isa='dict', default=dict)
action = FieldAttribute(isa='string')
async_val = FieldAttribute(isa='int', default=0, alias='async')
changed_when = FieldAttribute(isa='list', default=list)
delay = FieldAttribute(isa='int', default=5)
delegate_to = FieldAttribute(isa='string')
delegate_facts = FieldAttribute(isa='bool')
failed_when = FieldAttribute(isa='list', default=list)
loop = FieldAttribute()
loop_control = NonInheritableFieldAttribute(isa='class', class_type=LoopControl)
notify = FieldAttribute(isa='list')
poll = FieldAttribute(isa='int', default=C.DEFAULT_POLL_INTERVAL)
register = FieldAttribute(isa='string', static=True)
retries = FieldAttribute(isa='int', default=3)
until = FieldAttribute(isa='list', default=list)
# deprecated, used to be loop and loop_args but loop has been repurposed
_loop_with = FieldAttribute(isa='string', private=True, inherit=False)
loop_with = NonInheritableFieldAttribute(isa='string', private=True)
def __init__(self, block=None, role=None, task_include=None):
''' constructors a task, without the Task.load classmethod, it will be pretty blank '''
@ -182,7 +182,7 @@ class Task(Base, Conditional, Taggable, CollectionSearch):
else:
# Validate this untemplated field early on to guarantee we are dealing with a list.
# This is also done in CollectionSearch._load_collections() but this runs before that call.
collections_list = self.get_validated_value('collections', self._collections, collections_list, None)
collections_list = self.get_validated_value('collections', self.fattributes.get('collections'), collections_list, None)
if default_collection and not self._role: # FIXME: and not a collections role
if collections_list:
@ -460,11 +460,10 @@ class Task(Base, Conditional, Taggable, CollectionSearch):
'''
Generic logic to get the attribute or parent attribute for a task value.
'''
extend = self._valid_attrs[attr].extend
prepend = self._valid_attrs[attr].prepend
extend = self.fattributes.get(attr).extend
prepend = self.fattributes.get(attr).prepend
try:
value = self._attributes[attr]
value = getattr(self, f'_{attr}', Sentinel)
# If parent is static, we can grab attrs from the parent
# otherwise, defer to the grandparent
if getattr(self._parent, 'statically_loaded', True):
@ -478,7 +477,7 @@ class Task(Base, Conditional, Taggable, CollectionSearch):
if attr != 'vars' and hasattr(_parent, '_get_parent_attribute'):
parent_value = _parent._get_parent_attribute(attr)
else:
parent_value = _parent._attributes.get(attr, Sentinel)
parent_value = getattr(_parent, f'_{attr}', Sentinel)
if extend:
value = self._extend_value(value, parent_value, prepend)

@ -21,7 +21,6 @@ __metaclass__ = type
import ansible.constants as C
from ansible.errors import AnsibleParserError
from ansible.playbook.attribute import FieldAttribute
from ansible.playbook.block import Block
from ansible.playbook.task import Task
from ansible.utils.display import Display

@ -677,7 +677,7 @@ class StrategyBase:
continue
listeners = listening_handler.get_validated_value(
'listen', listening_handler._valid_attrs['listen'], listeners, handler_templar
'listen', listening_handler.fattributes.get('listen'), listeners, handler_templar
)
if handler_name not in listeners:
continue

@ -39,14 +39,12 @@ def get_reserved_names(include_private=True):
class_list = [Play, Role, Block, Task]
for aclass in class_list:
aobj = aclass()
# build ordered list to loop over and dict with attributes
for attribute in aobj.__dict__['_attributes']:
if 'private' in attribute:
private.add(attribute)
for name, attr in aclass.fattributes.items():
if attr.private:
private.add(name)
else:
public.add(attribute)
public.add(name)
# local_action is implicit with action
if 'action' in public:

@ -23,10 +23,11 @@ from units.compat import unittest
from ansible.errors import AnsibleParserError
from ansible.module_utils.six import string_types
from ansible.playbook.attribute import FieldAttribute
from ansible.playbook.attribute import FieldAttribute, NonInheritableFieldAttribute
from ansible.template import Templar
from ansible.playbook import base
from ansible.utils.unsafe_proxy import AnsibleUnsafeBytes, AnsibleUnsafeText
from ansible.utils.sentinel import Sentinel
from units.mock.loader import DictDataLoader
@ -77,7 +78,7 @@ class TestBase(unittest.TestCase):
def _assert_copy(self, orig, copy):
self.assertIsInstance(copy, self.ClassUnderTest)
self.assertIsInstance(copy, base.Base)
self.assertEqual(len(orig._valid_attrs), len(copy._valid_attrs))
self.assertEqual(len(orig.fattributes), len(copy.fattributes))
sentinel = 'Empty DS'
self.assertEqual(getattr(orig, '_ds', sentinel), getattr(copy, '_ds', sentinel))
@ -107,8 +108,8 @@ class TestBase(unittest.TestCase):
d = self.ClassUnderTest()
d.deserialize(data)
self.assertIn('run_once', d._attributes)
self.assertIn('check_mode', d._attributes)
self.assertIn('_run_once', d.__dict__)
self.assertIn('_check_mode', d.__dict__)
data = {'no_log': False,
'remote_user': None,
@ -122,9 +123,9 @@ class TestBase(unittest.TestCase):
d = self.ClassUnderTest()
d.deserialize(data)
self.assertNotIn('a_sentinel_with_an_unlikely_name', d._attributes)
self.assertIn('run_once', d._attributes)
self.assertIn('check_mode', d._attributes)
self.assertNotIn('_a_sentinel_with_an_unlikely_name', d.__dict__)
self.assertIn('_run_once', d.__dict__)
self.assertIn('_check_mode', d.__dict__)
def test_serialize_then_deserialize(self):
ds = {'environment': [],
@ -165,7 +166,7 @@ class TestBase(unittest.TestCase):
# environment is supposed to be a list. This
# seems like it shouldn't work?
ret = self.b.load_data(ds)
self.assertEqual(True, ret._attributes['environment'])
self.assertEqual(True, ret._environment)
def test_post_validate(self):
ds = {'environment': [],
@ -312,7 +313,7 @@ class ExampleException(Exception):
# naming fails me...
class ExampleParentBaseSubClass(base.Base):
_test_attr_parent_string = FieldAttribute(isa='string', default='A string attr for a class that may be a parent for testing')
test_attr_parent_string = FieldAttribute(isa='string', default='A string attr for a class that may be a parent for testing')
def __init__(self):
@ -324,9 +325,8 @@ class ExampleParentBaseSubClass(base.Base):
class ExampleSubClass(base.Base):
_test_attr_blip = FieldAttribute(isa='string', default='example sub class test_attr_blip',
inherit=False,
always_post_validate=True)
test_attr_blip = NonInheritableFieldAttribute(isa='string', default='example sub class test_attr_blip',
always_post_validate=True)
def __init__(self):
super(ExampleSubClass, self).__init__()
@ -339,39 +339,39 @@ class ExampleSubClass(base.Base):
class BaseSubClass(base.Base):
_name = FieldAttribute(isa='string', default='', always_post_validate=True)
_test_attr_bool = FieldAttribute(isa='bool', always_post_validate=True)
_test_attr_int = FieldAttribute(isa='int', always_post_validate=True)
_test_attr_float = FieldAttribute(isa='float', default=3.14159, always_post_validate=True)
_test_attr_list = FieldAttribute(isa='list', listof=string_types, always_post_validate=True)
_test_attr_list_no_listof = FieldAttribute(isa='list', always_post_validate=True)
_test_attr_list_required = FieldAttribute(isa='list', listof=string_types, required=True,
default=list, always_post_validate=True)
_test_attr_string = FieldAttribute(isa='string', default='the_test_attr_string_default_value')
_test_attr_string_required = FieldAttribute(isa='string', required=True,
default='the_test_attr_string_default_value')
_test_attr_percent = FieldAttribute(isa='percent', always_post_validate=True)
_test_attr_set = FieldAttribute(isa='set', default=set, always_post_validate=True)
_test_attr_dict = FieldAttribute(isa='dict', default=lambda: {'a_key': 'a_value'}, always_post_validate=True)
_test_attr_class = FieldAttribute(isa='class', class_type=ExampleSubClass)
_test_attr_class_post_validate = FieldAttribute(isa='class', class_type=ExampleSubClass,
always_post_validate=True)
_test_attr_unknown_isa = FieldAttribute(isa='not_a_real_isa', always_post_validate=True)
_test_attr_example = FieldAttribute(isa='string', default='the_default',
always_post_validate=True)
_test_attr_none = FieldAttribute(isa='string', always_post_validate=True)
_test_attr_preprocess = FieldAttribute(isa='string', default='the default for preprocess')
_test_attr_method = FieldAttribute(isa='string', default='some attr with a getter',
name = FieldAttribute(isa='string', default='', always_post_validate=True)
test_attr_bool = FieldAttribute(isa='bool', always_post_validate=True)
test_attr_int = FieldAttribute(isa='int', always_post_validate=True)
test_attr_float = FieldAttribute(isa='float', default=3.14159, always_post_validate=True)
test_attr_list = FieldAttribute(isa='list', listof=string_types, always_post_validate=True)
test_attr_list_no_listof = FieldAttribute(isa='list', always_post_validate=True)
test_attr_list_required = FieldAttribute(isa='list', listof=string_types, required=True,
default=list, always_post_validate=True)
test_attr_string = FieldAttribute(isa='string', default='the_test_attr_string_default_value')
test_attr_string_required = FieldAttribute(isa='string', required=True,
default='the_test_attr_string_default_value')
test_attr_percent = FieldAttribute(isa='percent', always_post_validate=True)
test_attr_set = FieldAttribute(isa='set', default=set, always_post_validate=True)
test_attr_dict = FieldAttribute(isa='dict', default=lambda: {'a_key': 'a_value'}, always_post_validate=True)
test_attr_class = FieldAttribute(isa='class', class_type=ExampleSubClass)
test_attr_class_post_validate = FieldAttribute(isa='class', class_type=ExampleSubClass,
always_post_validate=True)
test_attr_unknown_isa = FieldAttribute(isa='not_a_real_isa', always_post_validate=True)
test_attr_example = FieldAttribute(isa='string', default='the_default',
always_post_validate=True)
_test_attr_method_missing = FieldAttribute(isa='string', default='some attr with a missing getter',
always_post_validate=True)
test_attr_none = FieldAttribute(isa='string', always_post_validate=True)
test_attr_preprocess = FieldAttribute(isa='string', default='the default for preprocess')
test_attr_method = FieldAttribute(isa='string', default='some attr with a getter',
always_post_validate=True)
test_attr_method_missing = FieldAttribute(isa='string', default='some attr with a missing getter',
always_post_validate=True)
def _get_attr_test_attr_method(self):
return 'foo bar'
def _validate_test_attr_example(self, attr, name, value):
if not isinstance(value, str):
raise ExampleException('_test_attr_example is not a string: %s type=%s' % (value, type(value)))
raise ExampleException('test_attr_example is not a string: %s type=%s' % (value, type(value)))
def _post_validate_test_attr_example(self, attr, value, templar):
after_template_value = templar.template(value)
@ -380,21 +380,6 @@ class BaseSubClass(base.Base):
def _post_validate_test_attr_none(self, attr, value, templar):
return None
def _get_parent_attribute(self, attr, extend=False, prepend=False):
value = None
try:
value = self._attributes[attr]
if self._parent and (value is None or extend):
parent_value = getattr(self._parent, attr, None)
if extend:
value = self._extend_value(value, parent_value, prepend)
else:
value = parent_value
except KeyError:
pass
return value
# terrible name, but it is a TestBase subclass for testing subclasses of Base
class TestBaseSubClass(TestBase):
@ -420,7 +405,7 @@ class TestBaseSubClass(TestBase):
ds = {'test_attr_int': MOST_RANDOM_NUMBER}
bsc = self._base_validate(ds)
del bsc.test_attr_int
self.assertNotIn('test_attr_int', bsc._attributes)
self.assertNotIn('_test_attr_int', bsc.__dict__)
def test_attr_float(self):
roughly_pi = 4.0
@ -569,18 +554,18 @@ class TestBaseSubClass(TestBase):
string_list = ['foo', 'bar']
ds = {'test_attr_list': string_list}
bsc = self._base_validate(ds)
self.assertEqual(string_list, bsc._attributes['test_attr_list'])
self.assertEqual(string_list, bsc._test_attr_list)
def test_attr_list_none(self):
ds = {'test_attr_list': None}
bsc = self._base_validate(ds)
self.assertEqual(None, bsc._attributes['test_attr_list'])
self.assertEqual(None, bsc._test_attr_list)
def test_attr_list_no_listof(self):
test_list = ['foo', 'bar', 123]
ds = {'test_attr_list_no_listof': test_list}
bsc = self._base_validate(ds)
self.assertEqual(test_list, bsc._attributes['test_attr_list_no_listof'])
self.assertEqual(test_list, bsc._test_attr_list_no_listof)
def test_attr_list_required(self):
string_list = ['foo', 'bar']
@ -590,7 +575,7 @@ class TestBaseSubClass(TestBase):
fake_loader = DictDataLoader({})
templar = Templar(loader=fake_loader)
bsc.post_validate(templar)
self.assertEqual(string_list, bsc._attributes['test_attr_list_required'])
self.assertEqual(string_list, bsc._test_attr_list_required)
def test_attr_list_required_empty_string(self):
string_list = [""]

@ -44,7 +44,7 @@ class MixinForMocks(object):
self.mock_play = MagicMock(name='MockPlay')
self.mock_play._attributes = []
self.mock_play.collections = None
self.mock_play._collections = None
self.mock_iterator = MagicMock(name='MockIterator')
self.mock_iterator._play = self.mock_play

Loading…
Cancel
Save