Remove (de)serialize methods (#85724)

* move task specific from_attrs to Task
* Keep deserialize on PC, add tests
pull/85824/head
Martin Krizek 3 months ago committed by GitHub
parent 36f00cdf1a
commit 065f202d30
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -243,7 +243,7 @@ def main(args=None):
options = pickle.loads(opts_data, encoding='bytes') options = pickle.loads(opts_data, encoding='bytes')
play_context = PlayContext() play_context = PlayContext()
play_context.deserialize(pc_data) play_context.from_attrs(pc_data)
except Exception as e: except Exception as e:
rc = 1 rc = 1

@ -1221,7 +1221,7 @@ def start_connection(play_context, options, task_uuid):
) )
write_to_stream(p.stdin, options) write_to_stream(p.stdin, options)
write_to_stream(p.stdin, play_context.serialize()) write_to_stream(p.stdin, play_context.dump_attrs())
(stdout, stderr) = p.communicate() (stdout, stderr) = p.communicate()

@ -659,8 +659,8 @@ class FieldAttributeBase:
attrs = {} attrs = {}
for (name, attribute) in self.fattributes.items(): for (name, attribute) in self.fattributes.items():
attr = getattr(self, name) attr = getattr(self, name)
if attribute.isa == 'class' and hasattr(attr, 'serialize'): if attribute.isa == 'class':
attrs[name] = attr.serialize() attrs[name] = attr.dump_attrs()
else: else:
attrs[name] = attr attrs[name] = attr
return attrs return attrs
@ -674,60 +674,13 @@ class FieldAttributeBase:
attribute = self.fattributes[attr] attribute = self.fattributes[attr]
if attribute.isa == 'class' and isinstance(value, dict): if attribute.isa == 'class' and isinstance(value, dict):
obj = attribute.class_type() obj = attribute.class_type()
obj.deserialize(value) obj.from_attrs(value)
setattr(self, attr, obj) setattr(self, attr, obj)
else: else:
setattr(self, attr, value) setattr(self, attr, value)
else: else:
setattr(self, attr, value) # overridden dump_attrs in derived types may dump attributes which are not field attributes setattr(self, attr, value) # overridden dump_attrs in derived types may dump attributes which are not field attributes
# from_attrs is only used to create a finalized task
# from attrs from the Worker/TaskExecutor
# Those attrs are finalized and squashed in the TE
# and controller side use needs to reflect that
self._finalized = True
self._squashed = True
def serialize(self):
    """
    Serialize the object derived from the base object into
    a dictionary of values. This only serializes the field
    attributes for the object, so this may need to be overridden
    for any classes which wish to add additional items not stored
    as field attributes.
    """
    # NOTE: local renamed from `repr`, which shadowed the builtin.
    data = self.dump_attrs()

    # preserve object identity and lifecycle flags across the round-trip
    data['uuid'] = self._uuid
    data['finalized'] = self._finalized
    data['squashed'] = self._squashed

    return data
def deserialize(self, data):
    """
    Load field attribute values from a dictionary produced by serialize().

    Attributes absent from ``data`` are reset to their context defaults.
    As with serialize(), subclasses with non-field state must override
    and extend this method.
    """
    if not isinstance(data, dict):
        raise AnsibleAssertionError('data (%s) should be a dict but is a %s' % (data, type(data)))

    for name in self.fattributes:
        if name in data:
            setattr(self, name, data[name])
        else:
            self.set_to_context(name)

    # restore identity and lifecycle flags
    self._uuid = data.get('uuid')
    self._finalized = data.get('finalized', False)
    self._squashed = data.get('squashed', False)
class Base(FieldAttributeBase): class Base(FieldAttributeBase):

@ -27,7 +27,6 @@ from ansible.playbook.collectionsearch import CollectionSearch
from ansible.playbook.delegatable import Delegatable from ansible.playbook.delegatable import Delegatable
from ansible.playbook.helpers import load_list_of_tasks from ansible.playbook.helpers import load_list_of_tasks
from ansible.playbook.notifiable import Notifiable from ansible.playbook.notifiable import Notifiable
from ansible.playbook.role import Role
from ansible.playbook.taggable import Taggable from ansible.playbook.taggable import Taggable
@ -220,65 +219,6 @@ class Block(Base, Conditional, CollectionSearch, Taggable, Notifiable, Delegatab
new_me.validate() new_me.validate()
return new_me return new_me
def serialize(self):
    """
    Override of the default serialize method: the task lists
    (block/rescue/always) are deliberately excluded so nested task
    objects are not serialized along with the block.
    """
    skip = ('block', 'rescue', 'always')
    data = {attr: getattr(self, attr) for attr in self.fattributes if attr not in skip}
    data['dep_chain'] = self.get_dep_chain()

    if self._role is not None:
        data['role'] = self._role.serialize()

    if self._parent is not None:
        # the parent is serialized without its own task lists as well
        data['parent'] = self._parent.copy(exclude_tasks=True).serialize()
        data['parent_type'] = self._parent.__class__.__name__

    return data
def deserialize(self, data):
    """
    Override of the default deserialize method, to match the above overridden
    serialize method
    """

    # import is here to avoid import loops
    from ansible.playbook.task_include import TaskInclude
    from ansible.playbook.handler_task_include import HandlerTaskInclude

    # we don't want the full set of attributes (the task lists), as that
    # would lead to a serialize/deserialize loop
    for attr in self.fattributes:
        if attr in data and attr not in ('block', 'rescue', 'always'):
            setattr(self, attr, data.get(attr))

    self._dep_chain = data.get('dep_chain', None)

    # if there was a serialized role, unpack it too
    role_data = data.get('role')
    if role_data:
        r = Role()
        r.deserialize(role_data)
        self._role = r

    parent_data = data.get('parent')
    if parent_data:
        parent_type = data.get('parent_type')
        # NOTE(review): if parent_type is not one of these three class
        # names, `p` is never bound and p.deserialize() below raises
        # NameError — presumably unreachable for data produced by
        # serialize(), but worth confirming.
        if parent_type == 'Block':
            p = Block()
        elif parent_type == 'TaskInclude':
            p = TaskInclude()
        elif parent_type == 'HandlerTaskInclude':
            p = HandlerTaskInclude()
        p.deserialize(parent_data)
        self._parent = p
        # the restored parent supplies the dependency chain, overriding
        # the 'dep_chain' value loaded above
        self._dep_chain = self._parent.get_dep_chain()
def set_loader(self, loader): def set_loader(self, loader):
self._loader = loader self._loader = loader
if self._parent: if self._parent:

@ -71,8 +71,3 @@ class Handler(Task):
def is_host_notified(self, host): def is_host_notified(self, host):
return host in self.notified_hosts return host in self.notified_hosts
def serialize(self):
    """Serialize like a regular Task, tagging the result as a handler."""
    data = super(Handler, self).serialize()
    data['is_handler'] = True
    return data

@ -399,36 +399,6 @@ class Play(Base, Taggable, CollectionSearch):
tasklist.append(task) tasklist.append(task)
return tasklist return tasklist
def serialize(self):
    """Serialize the play, including its roles and action-group bookkeeping."""
    data = super(Play, self).serialize()

    data['roles'] = [role.serialize() for role in self.get_roles()]
    data['included_path'] = self._included_path
    data['action_groups'] = self._action_groups
    data['group_actions'] = self._group_actions

    return data
def deserialize(self, data):
    """Restore a play from a dictionary produced by serialize()."""
    super(Play, self).deserialize(data)

    self._included_path = data.get('included_path', None)
    self._action_groups = data.get('action_groups', {})
    self._group_actions = data.get('group_actions', {})
    if 'roles' in data:
        role_data = data.get('roles', [])
        roles = []
        for role in role_data:
            r = Role()
            r.deserialize(role)
            roles.append(r)

        # assign the reconstructed Role objects, then drop the raw key so
        # the serialized dicts cannot clobber them later
        setattr(self, 'roles', roles)
        del data['roles']
def copy(self): def copy(self):
new_me = super(Play, self).copy() new_me = super(Play, self).copy()
new_me.role_cache = self.role_cache.copy() new_me.role_cache = self.role_cache.copy()

@ -325,3 +325,7 @@ class PlayContext(Base):
variables[var_opt] = var_val variables[var_opt] = var_val
except AttributeError: except AttributeError:
continue continue
def deserialize(self, data):
    """
    Deprecated backward-compatibility shim — do not use.

    Some network connection plugins still call PlayContext.deserialize()
    directly; delegate to from_attrs(), the supported replacement.
    """
    self.from_attrs(data)

@ -655,65 +655,6 @@ class Role(Base, Conditional, Taggable, CollectionSearch, Delegatable):
return block_list return block_list
def serialize(self, include_deps=True):
    """
    Serialize the role, its metadata and (optionally) its dependencies.

    :param include_deps: when True, direct dependencies are serialized
        recursively. Parents are always serialized with
        include_deps=False to keep the recursion bounded.
    """
    res = super(Role, self).serialize()

    res['_role_name'] = self._role_name
    res['_role_path'] = self._role_path
    res['_role_vars'] = self._role_vars
    res['_role_params'] = self._role_params
    res['_default_vars'] = self._default_vars
    # copies so later mutation of the live role does not alter the result
    res['_had_task_run'] = self._had_task_run.copy()
    res['_completed'] = self._completed.copy()
    res['_metadata'] = self._metadata.serialize()

    if include_deps:
        deps = []
        for role in self.get_direct_dependencies():
            deps.append(role.serialize())
        res['_dependencies'] = deps

    parents = []
    for parent in self._parents:
        parents.append(parent.serialize(include_deps=False))
    res['_parents'] = parents

    return res
def deserialize(self, data, include_deps=True):
    """
    Restore a role from a dictionary produced by serialize().

    :param include_deps: when True, the '_dependencies' entries are
        rebuilt as Role objects; mirrors the include_deps flag on
        serialize().
    """
    self._role_name = data.get('_role_name', '')
    self._role_path = data.get('_role_path', '')
    self._role_vars = data.get('_role_vars', dict())
    self._role_params = data.get('_role_params', dict())
    self._default_vars = data.get('_default_vars', dict())
    self._had_task_run = data.get('_had_task_run', dict())
    self._completed = data.get('_completed', dict())

    if include_deps:
        deps = []
        for dep in data.get('_dependencies', []):
            r = Role()
            r.deserialize(dep)
            deps.append(r)
        setattr(self, '_dependencies', deps)

    # parents were serialized with include_deps=False; restore them the same way
    parent_data = data.get('_parents', [])
    parents = []
    for parent in parent_data:
        r = Role()
        r.deserialize(parent, include_deps=False)
        parents.append(r)
    setattr(self, '_parents', parents)

    metadata_data = data.get('_metadata')
    if metadata_data:
        m = RoleMetadata()
        m.deserialize(metadata_data)
        self._metadata = m

    # field attributes last, after the role-specific state above
    super(Role, self).deserialize(data)
def set_loader(self, loader): def set_loader(self, loader):
self._loader = loader self._loader = loader
for parent in self._parents: for parent in self._parents:

@ -105,13 +105,3 @@ class RoleMetadata(Base, CollectionSearch):
collection_search_list=collection_search_list) collection_search_list=collection_search_list)
except AssertionError as ex: except AssertionError as ex:
raise AnsibleParserError("A malformed list of role dependencies was encountered.", obj=self._ds) from ex raise AnsibleParserError("A malformed list of role dependencies was encountered.", obj=self._ds) from ex
def serialize(self):
    """Return the role metadata fields as a plain dictionary."""
    return {
        'allow_duplicates': self._allow_duplicates,
        'dependencies': self._dependencies,
    }
def deserialize(self, data):
    """Restore role metadata fields from a dictionary, defaulting missing keys."""
    self.allow_duplicates = data.get('allow_duplicates', False)
    self.dependencies = data.get('dependencies', [])

@ -36,7 +36,6 @@ from ansible.playbook.conditional import Conditional
from ansible.playbook.delegatable import Delegatable from ansible.playbook.delegatable import Delegatable
from ansible.playbook.loop_control import LoopControl from ansible.playbook.loop_control import LoopControl
from ansible.playbook.notifiable import Notifiable from ansible.playbook.notifiable import Notifiable
from ansible.playbook.role import Role
from ansible.playbook.taggable import Taggable from ansible.playbook.taggable import Taggable
from ansible._internal import _task from ansible._internal import _task
from ansible._internal._templating import _marker_behaviors from ansible._internal._templating import _marker_behaviors
@ -504,53 +503,6 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl
return new_me return new_me
def serialize(self):
    """
    Serialize the task; parent/role linkage is only included while the
    task has not yet been finalized or squashed.
    """
    data = super(Task, self).serialize()

    if not self._squashed and not self._finalized:
        if self._parent:
            data['parent'] = self._parent.serialize()
            data['parent_type'] = self._parent.__class__.__name__
        if self._role:
            data['role'] = self._role.serialize()

    # NOTE(review): indentation reconstructed from a flattened diff —
    # these two assignments are taken to be unconditional; confirm
    # against the pre-removal upstream source.
    data['implicit'] = self.implicit
    data['_resolved_action'] = self._resolved_action

    return data
def deserialize(self, data):
    """Restore a task from a dictionary produced by serialize()."""

    # import is here to avoid import loops
    from ansible.playbook.task_include import TaskInclude
    from ansible.playbook.handler_task_include import HandlerTaskInclude

    parent_data = data.get('parent', None)
    if parent_data:
        parent_type = data.get('parent_type')
        # NOTE(review): an unknown parent_type leaves `p` unbound and
        # raises NameError below — presumably unreachable for data
        # produced by serialize().
        if parent_type == 'Block':
            p = Block()
        elif parent_type == 'TaskInclude':
            p = TaskInclude()
        elif parent_type == 'HandlerTaskInclude':
            p = HandlerTaskInclude()
        p.deserialize(parent_data)
        self._parent = p
        # drop the raw key so the Base deserialize below does not see it
        del data['parent']

    role_data = data.get('role')
    if role_data:
        r = Role()
        r.deserialize(role_data)
        self._role = r
        del data['role']

    self.implicit = data.get('implicit', False)
    self._resolved_action = data.get('_resolved_action')

    super(Task, self).deserialize(data)
def set_loader(self, loader): def set_loader(self, loader):
""" """
Sets the loader on this object and recursively on parent, child objects. Sets the loader on this object and recursively on parent, child objects.
@ -628,6 +580,16 @@ class Task(Base, Conditional, Taggable, CollectionSearch, Notifiable, Delegatabl
attrs.update(_resolved_action=self._resolved_action) attrs.update(_resolved_action=self._resolved_action)
return attrs return attrs
def from_attrs(self, attrs):
    """Load attrs produced by dump_attrs(), marking the task finalized/squashed."""
    super().from_attrs(attrs)
    # from_attrs is only used to create a finalized task
    # from attrs from the Worker/TaskExecutor.
    # Those attrs are finalized and squashed in the TE,
    # and controller-side use needs to reflect that.
    self._finalized = True
    self._squashed = True
def _resolve_conditional( def _resolve_conditional(
self, self,
conditional: list[str | bool], conditional: list[str | bool],

@ -584,7 +584,7 @@ class StrategyBase:
self._variable_manager.set_nonpersistent_facts( self._variable_manager.set_nonpersistent_facts(
original_host.name, original_host.name,
dict( dict(
ansible_failed_task=original_task.serialize(), ansible_failed_task=original_task.dump_attrs(),
ansible_failed_result=task_result._return_data, ansible_failed_result=task_result._return_data,
), ),
) )

@ -0,0 +1,2 @@
shippable/posix/group5
context/controller

@ -0,0 +1,72 @@
from __future__ import annotations
DOCUMENTATION = """
options:
persistent_connect_timeout:
type: int
default: 30
ini:
- section: persistent_connection
key: connect_timeout
env:
- name: ANSIBLE_PERSISTENT_CONNECT_TIMEOUT
vars:
- name: ansible_connect_timeout
persistent_command_timeout:
type: int
default: 30
ini:
- section: persistent_connection
key: command_timeout
env:
- name: ANSIBLE_PERSISTENT_COMMAND_TIMEOUT
vars:
- name: ansible_command_timeout
persistent_log_messages:
type: boolean
ini:
- section: persistent_connection
key: log_messages
env:
- name: ANSIBLE_PERSISTENT_LOG_MESSAGES
vars:
- name: ansible_persistent_log_messages
"""
import json
import os
import pickle
from ansible.playbook.play_context import PlayContext
from ansible.plugins.connection import NetworkConnectionBase
class Connection(NetworkConnectionBase):
    """Minimal persistent connection plugin used by the integration test."""

    transport = 'persistent'
    supports_persistence = True

    def _connect(self):
        # nothing to set up; just mark the connection established
        self._connected = True

    def update_play_context(self, pc_data):
        """
        This is to ensure that the PlayContext.deserialize method remains functional,
        preventing it from breaking the network connection plugins that rely on it.

        See:
        https://github.com/ansible-collections/ansible.netcommon/blob/50fafb6875bb2f57e932a7a50123513b48bd4fd5/plugins/connection/httpapi.py#L258
        """
        # pickle.loads on controller-produced, in-process data only —
        # never reaches untrusted input
        pc = self._play_context = PlayContext()
        pc.deserialize(
            pickle.loads(
                pc_data.encode(errors='surrogateescape')
            )
        )

    def get_capabilities(self, *args, **kwargs):
        # report the daemon's pids plus the deserialized play context so the
        # test can verify the round-trip
        return json.dumps({
            'pid': os.getpid(),
            'ppid': os.getppid(),
            **self._play_context.dump_attrs()
        })

@ -0,0 +1 @@
localhost0 ansible_connection=persistent ansible_python_interpreter={{ansible_playbook_python}}

@ -0,0 +1,16 @@
from __future__ import annotations
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.connection import Connection
def main():
    """Fetch the persistent connection's capabilities and return them as module results."""
    # empty argument spec: this test module takes no options
    module = AnsibleModule({})
    # talk to the already-running persistent connection over its local socket
    connection = Connection(module._socket_path)
    capabilities = module.from_json(connection.get_capabilities())
    module.exit_json(**capabilities)


if __name__ == '__main__':
    main()

@ -0,0 +1,5 @@
- hosts: all
gather_facts: false
tasks:
- persistent:
- persistent:

@ -0,0 +1,5 @@
#!/usr/bin/env bash
# Integration test entry point: fail fast, echo commands, and run the
# playbook against the persistent-connection inventory.
set -eux

ansible-playbook -i inventory playbook.yml -v "$@"

@ -89,53 +89,6 @@ class TestBase(unittest.TestCase):
copy = b.copy() copy = b.copy()
self._assert_copy(b, copy) self._assert_copy(b, copy)
def test_serialize(self):
    """serialize() on a validated object must yield a plain dict."""
    # NOTE: removed a dead `ds = {}` assignment that was immediately overwritten
    ds = {'environment': [],
          'vars': self.assorted_vars
          }
    b = self._base_validate(ds)

    ret = b.serialize()
    self.assertIsInstance(ret, dict)
def test_deserialize(self):
    """deserialize() must populate field attributes and ignore unknown keys."""
    data = {}

    d = self.ClassUnderTest()
    d.deserialize(data)
    # even with empty data, context defaults are materialized on the instance
    self.assertIn('_run_once', d.__dict__)
    self.assertIn('_check_mode', d.__dict__)

    data = {'no_log': False,
            'remote_user': None,
            'vars': self.assorted_vars,
            'environment': [],
            'run_once': False,
            'connection': None,
            'ignore_errors': False,
            'port': 22,
            'a_sentinel_with_an_unlikely_name': ['sure, a list']}

    d = self.ClassUnderTest()
    d.deserialize(data)
    # keys that are not field attributes must not be stored on the instance
    self.assertNotIn('_a_sentinel_with_an_unlikely_name', d.__dict__)
    self.assertIn('_run_once', d.__dict__)
    self.assertIn('_check_mode', d.__dict__)
def test_serialize_then_deserialize(self):
    """A serialize/deserialize round-trip must preserve the serialized form."""
    ds = {'environment': [],
          'vars': self.assorted_vars}
    b = self._base_validate(ds)
    copy = b.copy()
    ret = b.serialize()
    # round-trip into the same instance and into a fresh one
    b.deserialize(ret)
    c = self.ClassUnderTest()
    c.deserialize(ret)
    # TODO: not a great test, but coverage...
    self.maxDiff = None
    self.assertDictEqual(b.serialize(), copy.serialize())
    self.assertDictEqual(c.serialize(), copy.serialize())
def test_post_validate_empty(self): def test_post_validate_empty(self):
fake_loader = DictDataLoader({}) fake_loader = DictDataLoader({})
templar = TemplateEngine(loader=fake_loader) templar = TemplateEngine(loader=fake_loader)
@ -176,14 +129,6 @@ class TestBase(unittest.TestCase):
b = self._base_validate(ds) b = self._base_validate(ds)
self.assertEqual(b.port, 'some_port') self.assertEqual(b.port, 'some_port')
def test_squash(self):
    """squash() must flip the serialized 'squashed' flag from False to True."""
    before = self.b.serialize()
    self.b.squash()
    after = self.b.serialize()
    # TODO: assert something about the rest of the serialized data
    self.assertFalse(before['squashed'])
    self.assertTrue(after['squashed'])
def test_vars(self): def test_vars(self):
# vars as a dict. # vars as a dict.
ds = {'environment': [], ds = {'environment': [],

Loading…
Cancel
Save