Replace get_persistent_connection_options in task_executor with get_options (#74446)

Replace get_persistent_connection_options with get_options
Remove special case for network sub_plugin in _set_plugin_options
Try to avoid mock connection pretending to be persistent
Rename variables->options to reflect what they actually are
Gather options for ssh_type_conn on network_cli
Drop reliance on sub_plugin["type"]
pull/78593/head
Kate Case 2 years ago committed by GitHub
parent 2fdaee143a
commit bf1ef5a1f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,3 @@
---
bugfixes:
- Fix for network_cli not getting all relevant connection options

@ -89,7 +89,7 @@ class ConnectionProcess(object):
self.connection = None self.connection = None
self._ansible_playbook_pid = ansible_playbook_pid self._ansible_playbook_pid = ansible_playbook_pid
def start(self, variables): def start(self, options):
messages = list() messages = list()
result = {} result = {}
@ -104,7 +104,7 @@ class ConnectionProcess(object):
self.connection = connection_loader.get(self.play_context.connection, self.play_context, '/dev/null', self.connection = connection_loader.get(self.play_context.connection, self.play_context, '/dev/null',
task_uuid=self._task_uuid, ansible_playbook_pid=self._ansible_playbook_pid) task_uuid=self._task_uuid, ansible_playbook_pid=self._ansible_playbook_pid)
try: try:
self.connection.set_options(var_options=variables) self.connection.set_options(direct=options)
except ConnectionError as exc: except ConnectionError as exc:
messages.append(('debug', to_text(exc))) messages.append(('debug', to_text(exc)))
raise ConnectionError('Unable to decode JSON from response set_options. See the debug log for more information.') raise ConnectionError('Unable to decode JSON from response set_options. See the debug log for more information.')
@ -248,11 +248,11 @@ def main(args=None):
try: try:
# read the play context data via stdin, which means depickling it # read the play context data via stdin, which means depickling it
vars_data = read_stream(stdin) opts_data = read_stream(stdin)
init_data = read_stream(stdin) init_data = read_stream(stdin)
pc_data = pickle.loads(init_data, encoding='bytes') pc_data = pickle.loads(init_data, encoding='bytes')
variables = pickle.loads(vars_data, encoding='bytes') options = pickle.loads(opts_data, encoding='bytes')
play_context = PlayContext() play_context = PlayContext()
play_context.deserialize(pc_data) play_context.deserialize(pc_data)
@ -289,7 +289,7 @@ def main(args=None):
os.close(r) os.close(r)
wfd = os.fdopen(w, 'w') wfd = os.fdopen(w, 'w')
process = ConnectionProcess(wfd, play_context, socket_path, original_path, task_uuid, ansible_playbook_pid) process = ConnectionProcess(wfd, play_context, socket_path, original_path, task_uuid, ansible_playbook_pid)
process.start(variables) process.start(options)
except Exception: except Exception:
messages.append(('error', traceback.format_exc())) messages.append(('error', traceback.format_exc()))
rc = 1 rc = 1
@ -312,7 +312,7 @@ def main(args=None):
messages.append(('vvvv', 'found existing local domain socket, using it!')) messages.append(('vvvv', 'found existing local domain socket, using it!'))
conn = Connection(socket_path) conn = Connection(socket_path)
try: try:
conn.set_options(var_options=variables) conn.set_options(direct=options)
except ConnectionError as exc: except ConnectionError as exc:
messages.append(('debug', to_text(exc))) messages.append(('debug', to_text(exc)))
raise ConnectionError('Unable to decode JSON from response set_options. See the debug log for more information.') raise ConnectionError('Unable to decode JSON from response set_options. See the debug log for more information.')

@ -24,6 +24,7 @@ from ansible.module_utils._text import to_text, to_native
from ansible.module_utils.connection import write_to_file_descriptor from ansible.module_utils.connection import write_to_file_descriptor
from ansible.playbook.conditional import Conditional from ansible.playbook.conditional import Conditional
from ansible.playbook.task import Task from ansible.playbook.task import Task
from ansible.plugins import get_plugin_class
from ansible.plugins.loader import become_loader, cliconf_loader, connection_loader, httpapi_loader, netconf_loader, terminal_loader from ansible.plugins.loader import become_loader, cliconf_loader, connection_loader, httpapi_loader, netconf_loader, terminal_loader
from ansible.template import Templar from ansible.template import Templar
from ansible.utils.collection_loader import AnsibleCollectionConfig, AnsibleCollectionRef from ansible.utils.collection_loader import AnsibleCollectionConfig, AnsibleCollectionRef
@ -584,6 +585,17 @@ class TaskExecutor:
# feed back into pc to ensure plugins not using get_option can get correct value # feed back into pc to ensure plugins not using get_option can get correct value
self._connection._play_context = self._play_context.set_task_and_variable_override(task=self._task, variables=vars_copy, templar=templar) self._connection._play_context = self._play_context.set_task_and_variable_override(task=self._task, variables=vars_copy, templar=templar)
# for persistent connections, initialize socket path and start connection manager
if any(((self._connection.supports_persistence and C.USE_PERSISTENT_CONNECTIONS), self._connection.force_persistence)):
self._play_context.timeout = self._connection.get_option('persistent_command_timeout')
display.vvvv('attempting to start connection', host=self._play_context.remote_addr)
display.vvvv('using connection plugin %s' % self._connection.transport, host=self._play_context.remote_addr)
options = self._connection.get_options()
socket_path = start_connection(self._play_context, options, self._task._uuid)
display.vvvv('local domain socket path is %s' % socket_path, host=self._play_context.remote_addr)
setattr(self._connection, '_socket_path', socket_path)
# TODO: eventually remove this block as this should be a 'consequence' of 'forced_local' modules # TODO: eventually remove this block as this should be a 'consequence' of 'forced_local' modules
# special handling for python interpreter for network_os, default to ansible python unless overriden # special handling for python interpreter for network_os, default to ansible python unless overriden
if 'ansible_network_os' in cvars and 'ansible_python_interpreter' not in cvars: if 'ansible_network_os' in cvars and 'ansible_python_interpreter' not in cvars:
@ -992,32 +1004,8 @@ class TaskExecutor:
# Also backwards compat call for those still using play_context # Also backwards compat call for those still using play_context
self._play_context.set_attributes_from_plugin(connection) self._play_context.set_attributes_from_plugin(connection)
if any(((connection.supports_persistence and C.USE_PERSISTENT_CONNECTIONS), connection.force_persistence)):
self._play_context.timeout = connection.get_option('persistent_command_timeout')
display.vvvv('attempting to start connection', host=self._play_context.remote_addr)
display.vvvv('using connection plugin %s' % connection.transport, host=self._play_context.remote_addr)
options = self._get_persistent_connection_options(connection, cvars, templar)
socket_path = start_connection(self._play_context, options, self._task._uuid)
display.vvvv('local domain socket path is %s' % socket_path, host=self._play_context.remote_addr)
setattr(connection, '_socket_path', socket_path)
return connection return connection
def _get_persistent_connection_options(self, connection, final_vars, templar):
option_vars = C.config.get_plugin_vars('connection', connection._load_name)
plugin = connection._sub_plugin
if plugin.get('type'):
option_vars.extend(C.config.get_plugin_vars(plugin['type'], plugin['name']))
options = {}
for k in option_vars:
if k in final_vars:
options[k] = templar.template(final_vars[k])
return options
def _set_plugin_options(self, plugin_type, variables, templar, task_keys): def _set_plugin_options(self, plugin_type, variables, templar, task_keys):
try: try:
plugin = getattr(self._connection, '_%s' % plugin_type) plugin = getattr(self._connection, '_%s' % plugin_type)
@ -1025,6 +1013,10 @@ class TaskExecutor:
# Some plugins are assigned to private attrs, ``become`` is not # Some plugins are assigned to private attrs, ``become`` is not
plugin = getattr(self._connection, plugin_type) plugin = getattr(self._connection, plugin_type)
# network_cli's "real" connection plugin is not named connection
# to avoid the confusion of having connection.connection
if plugin_type == "ssh_type_conn":
plugin_type = "connection"
option_vars = C.config.get_plugin_vars(plugin_type, plugin._load_name) option_vars = C.config.get_plugin_vars(plugin_type, plugin._load_name)
options = {} options = {}
for k in option_vars: for k in option_vars:
@ -1094,6 +1086,15 @@ class TaskExecutor:
pass # some plugins don't support all base flags pass # some plugins don't support all base flags
self._play_context.prompt = self._connection.become.prompt self._play_context.prompt = self._connection.become.prompt
# deals with networking sub_plugins (network_cli/httpapi/netconf)
sub = getattr(self._connection, '_sub_plugin', None)
if sub is not None and sub.get('type') != 'external':
plugin_type = get_plugin_class(sub.get("obj"))
varnames.extend(self._set_plugin_options(plugin_type, variables, templar, task_keys))
sub_conn = getattr(self._connection, 'ssh_type_conn', None)
if sub_conn is not None:
varnames.extend(self._set_plugin_options("ssh_type_conn", variables, templar, task_keys))
return varnames return varnames
def _get_action_handler(self, connection, templar): def _get_action_handler(self, connection, templar):
@ -1156,7 +1157,7 @@ class TaskExecutor:
return handler, module return handler, module
def start_connection(play_context, variables, task_uuid): def start_connection(play_context, options, task_uuid):
''' '''
Starts the persistent connection Starts the persistent connection
''' '''
@ -1205,7 +1206,7 @@ def start_connection(play_context, variables, task_uuid):
try: try:
termios.tcsetattr(master, termios.TCSANOW, new) termios.tcsetattr(master, termios.TCSANOW, new)
write_to_file_descriptor(master, variables) write_to_file_descriptor(master, options)
write_to_file_descriptor(master, play_context.serialize()) write_to_file_descriptor(master, play_context.serialize())
(stdout, stderr) = p.communicate() (stdout, stderr) = p.communicate()

@ -334,6 +334,8 @@ class TestTaskExecutor(unittest.TestCase):
mock_play_context.update_vars.return_value = None mock_play_context.update_vars.return_value = None
mock_connection = MagicMock() mock_connection = MagicMock()
mock_connection.force_persistence = False
mock_connection.supports_persistence = False
mock_connection.set_host_overrides.return_value = None mock_connection.set_host_overrides.return_value = None
mock_connection._connect.return_value = None mock_connection._connect.return_value = None

Loading…
Cancel
Save