From 406b59aebac7244e492577409495572947c182e6 Mon Sep 17 00:00:00 2001 From: Nathaniel Case Date: Thu, 20 Sep 2018 09:56:43 -0400 Subject: [PATCH] Move persistent connections to only use registered variables (#45616) * Try to intuit proper plugins to send to ansible-connection * Move sub-plugins to init so that vars will be populated in executor * Fix connection unit tests --- lib/ansible/executor/task_executor.py | 24 +++++++++--- lib/ansible/plugins/connection/__init__.py | 31 +++++----------- lib/ansible/plugins/connection/httpapi.py | 37 +++++++++---------- lib/ansible/plugins/connection/napalm.py | 2 +- lib/ansible/plugins/connection/netconf.py | 19 +++++----- lib/ansible/plugins/connection/network_cli.py | 31 ++++++++-------- lib/ansible/plugins/netconf/__init__.py | 5 ++- .../plugins/connection/test_connection.py | 6 +++ test/units/plugins/connection/test_netconf.py | 15 ++------ .../plugins/connection/test_network_cli.py | 37 +++++++------------ 10 files changed, 98 insertions(+), 109 deletions(-) diff --git a/lib/ansible/executor/task_executor.py b/lib/ansible/executor/task_executor.py index 3874e2e9060..73c5ab076d7 100644 --- a/lib/ansible/executor/task_executor.py +++ b/lib/ansible/executor/task_executor.py @@ -842,17 +842,29 @@ class TaskExecutor: self._play_context.timeout = connection.get_option('persistent_command_timeout') display.vvvv('attempting to start connection', host=self._play_context.remote_addr) display.vvvv('using connection plugin %s' % connection.transport, host=self._play_context.remote_addr) - # We don't need to send the entire contents of variables to ansible-connection - filtered_vars = dict( - (key, value) for key, value in variables.items() - if key.startswith('ansible') and key != 'ansible_failed_task' - ) - socket_path = self._start_connection(filtered_vars) + + options = self._get_persistent_connection_options(connection, variables, templar) + socket_path = self._start_connection(options) display.vvvv('local domain socket path is %s' % socket_path, host=self._play_context.remote_addr) setattr(connection, '_socket_path', socket_path) return connection + def _get_persistent_connection_options(self, connection, variables, templar): + final_vars = combine_vars(variables, variables.get('ansible_delegated_vars', dict()).get(self._task.delegate_to, dict())) + + option_vars = C.config.get_plugin_vars('connection', connection._load_name) + for plugin in connection._sub_plugins: + if plugin['type'] != 'external': + option_vars.extend(C.config.get_plugin_vars(plugin['type'], plugin['name'])) + + options = {} + for k in option_vars: + if k in final_vars: + options[k] = templar.template(final_vars[k]) + + return options + def _set_connection_options(self, variables, templar): # Keep the pre-delegate values for these keys diff --git a/lib/ansible/plugins/connection/__init__.py b/lib/ansible/plugins/connection/__init__.py index b7cb3baf53f..4d4e3d5eca6 100644 --- a/lib/ansible/plugins/connection/__init__.py +++ b/lib/ansible/plugins/connection/__init__.py @@ -300,7 +300,7 @@ class NetworkConnectionBase(ConnectionBase): self._local = connection_loader.get('local', play_context, '/dev/null') self._local.set_options() - self._implementation_plugins = [] + self._sub_plugins = [] self._cached_variables = (None, None, None) # reconstruct the socket_path and set instance values accordingly @@ -312,16 +312,12 @@ class NetworkConnectionBase(ConnectionBase): return self.__dict__[name] except KeyError: if not name.startswith('_'): - for plugin in self._implementation_plugins: - method = getattr(plugin, name, None) + for plugin in self._sub_plugins: + method = getattr(plugin['obj'], name, None) if method is not None: return method raise AttributeError("'%s' object has no attribute '%s'" % (self.__class__.__name__, name)) - def _connect(self): - self.set_implementation_plugin_options(*self._cached_variables) - self._cached_variables = (None, None, None) - def exec_command(self, cmd, in_data=None, sudoable=True): return self._local.exec_command(cmd, in_data, sudoable) @@ -345,25 +341,16 @@ class NetworkConnectionBase(ConnectionBase): def close(self): if self._connected: self._connected = False - self._implementation_plugins = [] def set_options(self, task_keys=None, var_options=None, direct=None): super(NetworkConnectionBase, self).set_options(task_keys=task_keys, var_options=var_options, direct=direct) - if self._implementation_plugins: - self.set_implementation_plugin_options(task_keys, var_options, direct) - else: - self._cached_variables = (task_keys, var_options, direct) - - def set_implementation_plugin_options(self, task_keys=None, var_options=None, direct=None): - ''' - initialize implementation plugin options - ''' - for plugin in self._implementation_plugins: - try: - plugin.set_options(task_keys=task_keys, var_options=var_options, direct=direct) - except AttributeError: - pass + for plugin in self._sub_plugins: + if plugin['type'] != 'external': + try: + plugin['obj'].set_options(task_keys=task_keys, var_options=var_options, direct=direct) + except AttributeError: + pass def _update_connection_state(self): ''' diff --git a/lib/ansible/plugins/connection/httpapi.py b/lib/ansible/plugins/connection/httpapi.py index 7d56ddcf438..cd5bf324dcd 100644 --- a/lib/ansible/plugins/connection/httpapi.py +++ b/lib/ansible/plugins/connection/httpapi.py @@ -176,7 +176,22 @@ class Connection(NetworkConnectionBase): self._url = None self._auth = None - if not self._network_os: + if self._network_os: + + self.httpapi = httpapi_loader.get(self._network_os, self) + if self.httpapi: + self._sub_plugins.append({'type': 'httpapi', 'name': self._network_os, 'obj': self.httpapi}) + display.vvvv('loaded API plugin for network_os %s' % self._network_os) + else: + raise AnsibleConnectionFailure('unable to load API plugin for network_os %s' % self._network_os) + + self.cliconf = cliconf_loader.get(self._network_os, self) + if self.cliconf: + self._sub_plugins.append({'type': 'cliconf', 'name': self._network_os, 'obj': self.cliconf}) + display.vvvv('loaded cliconf plugin for network_os %s' % self._network_os) + else: + display.vvvv('unable to load cliconf for network_os %s' % self._network_os) + else: raise AnsibleConnectionFailure( 'Unable to automatically determine host network os. Please ' 'manually configure ansible_network_os value for this host' @@ -211,24 +226,8 @@ class Connection(NetworkConnectionBase): port = self.get_option('port') or (443 if protocol == 'https' else 80) self._url = '%s://%s:%s' % (protocol, host, port) - httpapi = httpapi_loader.get(self._network_os, self) - if httpapi: - display.vvvv('loaded API plugin for network_os %s' % self._network_os, host=host) - self._implementation_plugins.append(httpapi) - else: - raise AnsibleConnectionFailure('unable to load API plugin for network_os %s' % self._network_os) - - cliconf = cliconf_loader.get(self._network_os, self) - if cliconf: - display.vvvv('loaded cliconf plugin for network_os %s' % self._network_os, host=host) - self._implementation_plugins.append(cliconf) - else: - display.vvvv('unable to load cliconf for network_os %s' % self._network_os) - - super(Connection, self)._connect() - - httpapi.set_become(self._play_context) - httpapi.login(self.get_option('remote_user'), self.get_option('password')) + self.httpapi.set_become(self._play_context) + self.httpapi.login(self.get_option('remote_user'), self.get_option('password')) self._connected = True diff --git a/lib/ansible/plugins/connection/napalm.py b/lib/ansible/plugins/connection/napalm.py index 83d7237b0c5..242fb00403e 100644 --- a/lib/ansible/plugins/connection/napalm.py +++ b/lib/ansible/plugins/connection/napalm.py @@ -186,7 +186,7 @@ class Connection(NetworkConnectionBase): self.napalm.open() - self._implementation_plugins.append(self.napalm) + self._sub_plugins.append({'type': 'external', 'name': 'napalm', 'obj': self.napalm}) display.vvvv('created napalm device for network_os %s' % self._network_os, host=host) self._connected = True diff --git a/lib/ansible/plugins/connection/netconf.py b/lib/ansible/plugins/connection/netconf.py index b4cfd416200..096d9659bc8 100644 --- a/lib/ansible/plugins/connection/netconf.py +++ b/lib/ansible/plugins/connection/netconf.py @@ -217,6 +217,15 @@ class Connection(NetworkConnectionBase): super(Connection, self).__init__(play_context, new_stdin, *args, **kwargs) self._network_os = self._network_os or 'default' + + netconf = netconf_loader.get(self._network_os, self) + if netconf: + self._sub_plugins.append({'type': 'netconf', 'name': self._network_os, 'obj': netconf}) + display.display('loaded netconf plugin for network_os %s' % self._network_os, log_only=True) + else: + netconf = netconf_loader.get("default", self) + self._sub_plugins.append({'type': 'netconf', 'name': 'default', 'obj': netconf}) + display.display('unable to load netconf plugin for network_os %s, falling back to default plugin' % self._network_os) display.display('network_os is set to %s' % self._network_os, log_only=True) self._manager = None @@ -246,8 +255,6 @@ class Connection(NetworkConnectionBase): return super(Connection, self).exec_command(cmd, in_data, sudoable) def _connect(self): - super(Connection, self)._connect() - display.display('ssh connection done, starting ncclient', log_only=True) allow_agent = True @@ -300,14 +307,6 @@ class Connection(NetworkConnectionBase): self._connected = True - netconf = netconf_loader.get(self._network_os, self) - if netconf: - display.display('loaded netconf plugin for network_os %s' % self._network_os, log_only=True) - else: - netconf = netconf_loader.get("default", self) - display.display('unable to load netconf plugin for network_os %s, falling back to default plugin' % self._network_os) - self._implementation_plugins.append(netconf) - super(Connection, self)._connect() return 0, to_bytes(self._manager.session_id, errors='surrogate_or_strict'), b'' diff --git a/lib/ansible/plugins/connection/network_cli.py b/lib/ansible/plugins/connection/network_cli.py index 1f7d58ace5b..c27c9ad73cb 100644 --- a/lib/ansible/plugins/connection/network_cli.py +++ b/lib/ansible/plugins/connection/network_cli.py @@ -208,6 +208,21 @@ class Connection(NetworkConnectionBase): if self._play_context.verbosity > 3: logging.getLogger('paramiko').setLevel(logging.DEBUG) + if self._network_os: + + self.cliconf = cliconf_loader.get(self._network_os, self) + if self.cliconf: + display.vvvv('loaded cliconf plugin for network_os %s' % self._network_os) + self._sub_plugins.append({'type': 'cliconf', 'name': self._network_os, 'obj': self.cliconf}) + else: + display.vvvv('unable to load cliconf for network_os %s' % self._network_os) + else: + raise AnsibleConnectionFailure( + 'Unable to automatically determine host network os. Please ' + 'manually configure ansible_network_os value for this host' + ) + display.display('network_os is set to %s' % self._network_os, log_only=True) + def _get_log_channel(self): name = "p=%s u=%s | " % (os.getpid(), getpass.getuser()) name += "paramiko [%s]" % self._play_context.remote_addr @@ -270,13 +285,6 @@ class Connection(NetworkConnectionBase): Connects to the remote device and starts the terminal ''' if not self.connected: - if not self._network_os: - raise AnsibleConnectionFailure( - 'Unable to automatically determine host network os. Please ' - 'manually configure ansible_network_os value for this host' - ) - display.display('network_os is set to %s' % self._network_os, log_only=True) - self.paramiko_conn = connection_loader.get('paramiko', self._play_context, '/dev/null') self.paramiko_conn._set_log_channel(self._get_log_channel()) self.paramiko_conn.set_options(direct={'look_for_keys': not bool(self._play_context.password and not self._play_context.private_key_file)}) @@ -295,15 +303,6 @@ class Connection(NetworkConnectionBase): display.vvvv('loaded terminal plugin for network_os %s' % self._network_os, host=host) - self.cliconf = cliconf_loader.get(self._network_os, self) - if self.cliconf: - display.vvvv('loaded cliconf plugin for network_os %s' % self._network_os, host=host) - self._implementation_plugins.append(self.cliconf) - else: - display.vvvv('unable to load cliconf for network_os %s' % self._network_os) - - super(Connection, self)._connect() - self.receive(prompts=self._terminal.terminal_initial_prompt, answer=self._terminal.terminal_initial_answer, newline=self._terminal.terminal_inital_prompt_newline) diff --git a/lib/ansible/plugins/netconf/__init__.py b/lib/ansible/plugins/netconf/__init__.py index 1ccc767da18..2a8b20c5ff6 100644 --- a/lib/ansible/plugins/netconf/__init__.py +++ b/lib/ansible/plugins/netconf/__init__.py @@ -102,7 +102,10 @@ class NetconfBase(AnsiblePlugin): def __init__(self, connection): self._connection = connection - self.m = self._connection._manager + + @property + def m(self): + return self._connection._manager @ensure_connected def rpc(self, name): diff --git a/test/units/plugins/connection/test_connection.py b/test/units/plugins/connection/test_connection.py index a2f89655797..cef50807a5c 100644 --- a/test/units/plugins/connection/test_connection.py +++ b/test/units/plugins/connection/test_connection.py @@ -41,6 +41,7 @@ from ansible.plugins.connection.ssh import Connection as SSHConnection from ansible.plugins.connection.docker import Connection as DockerConnection # from ansible.plugins.connection.winrm import Connection as WinRmConnection from ansible.plugins.connection.network_cli import Connection as NetworkCliConnection +from ansible.plugins.connection.httpapi import Connection as HttpapiConnection PY3 = sys.version_info[0] == 3 @@ -162,11 +163,16 @@ class TestConnectionBaseClass(unittest.TestCase): # self.assertIsInstance(WinRmConnection(), WinRmConnection) def test_network_cli_connection_module(self): + self.play_context.network_os = 'eos' self.assertIsInstance(NetworkCliConnection(self.play_context, self.in_stream), NetworkCliConnection) def test_netconf_connection_module(self): self.assertIsInstance(NetconfConnection(self.play_context, self.in_stream), NetconfConnection) + def test_httpapi_connection_module(self): + self.play_context.network_os = 'eos' + self.assertIsInstance(HttpapiConnection(self.play_context, self.in_stream), HttpapiConnection) + def test_check_password_prompt(self): local = ( b'[sudo via ansible, key=ouzmdnewuhucvuaabtjmweasarviygqq] password: \n' diff --git a/test/units/plugins/connection/test_netconf.py b/test/units/plugins/connection/test_netconf.py index 850bf8898a0..ecb3d9984de 100644 --- a/test/units/plugins/connection/test_netconf.py +++ b/test/units/plugins/connection/test_netconf.py @@ -58,9 +58,7 @@ class TestNetconfConnectionClass(unittest.TestCase): def test_netconf_init(self): pc = PlayContext() - new_stdin = StringIO() - - conn = netconf.Connection(pc, new_stdin) + conn = connection_loader.get('netconf', pc, '/dev/null') self.assertEqual('default', conn._network_os) self.assertIsNone(conn._manager) @@ -69,14 +67,11 @@ class TestNetconfConnectionClass(unittest.TestCase): @patch("ansible.plugins.connection.netconf.netconf_loader") def test_netconf__connect(self, mock_netconf_loader): pc = PlayContext() - new_stdin = StringIO() - - conn = connection_loader.get('netconf', pc, new_stdin) + conn = connection_loader.get('netconf', pc, '/dev/null') mock_manager = MagicMock() mock_manager.session_id = '123456789' netconf.manager.connect = MagicMock(return_value=mock_manager) - conn._play_context.network_os = 'default' rc, out, err = conn._connect() @@ -87,9 +82,8 @@ class TestNetconfConnectionClass(unittest.TestCase): def test_netconf_exec_command(self): pc = PlayContext() - new_stdin = StringIO() + conn = connection_loader.get('netconf', pc, '/dev/null') - conn = netconf.Connection(pc, new_stdin) conn._connected = True mock_reply = MagicMock(name='reply') @@ -105,9 +99,8 @@ class TestNetconfConnectionClass(unittest.TestCase): def test_netconf_exec_command_invalid_request(self): pc = PlayContext() - new_stdin = StringIO() + conn = connection_loader.get('netconf', pc, '/dev/null') - conn = netconf.Connection(pc, new_stdin) conn._connected = True mock_manager = MagicMock(name='self._manager') diff --git a/test/units/plugins/connection/test_network_cli.py b/test/units/plugins/connection/test_network_cli.py index 9831b5d6e6b..d6938c51f2e 100644 --- a/test/units/plugins/connection/test_network_cli.py +++ b/test/units/plugins/connection/test_network_cli.py @@ -30,7 +30,6 @@ from ansible.compat.tests.mock import patch, MagicMock from ansible.errors import AnsibleConnectionFailure from ansible.playbook.play_context import PlayContext -from ansible.plugins.connection import network_cli from ansible.plugins.loader import connection_loader @@ -39,39 +38,30 @@ class TestConnectionClass(unittest.TestCase): @patch("ansible.plugins.connection.paramiko_ssh.Connection._connect") def test_network_cli__connect_error(self, mocked_super): pc = PlayContext() - new_stdin = StringIO() - + pc.network_os = 'ios' conn = connection_loader.get('network_cli', pc, '/dev/null') + conn.ssh = MagicMock() conn.receive = MagicMock() - conn._terminal = MagicMock() - pc.network_os = None + conn._network_os = 'does not exist' + self.assertRaises(AnsibleConnectionFailure, conn._connect) - @patch("ansible.plugins.connection.paramiko_ssh.Connection._connect") - def test_network_cli__invalid_os(self, mocked_super): + def test_network_cli__invalid_os(self): pc = PlayContext() - new_stdin = StringIO() - - conn = connection_loader.get('network_cli', pc, '/dev/null') - conn.ssh = MagicMock() - conn.receive = MagicMock() - conn._terminal = MagicMock() pc.network_os = None - self.assertRaises(AnsibleConnectionFailure, conn._connect) + + self.assertRaises(AnsibleConnectionFailure, connection_loader.get, 'network_cli', pc, '/dev/null') @patch("ansible.plugins.connection.network_cli.terminal_loader") @patch("ansible.plugins.connection.paramiko_ssh.Connection._connect") def test_network_cli__connect(self, mocked_super, mocked_terminal_loader): pc = PlayContext() pc.network_os = 'ios' - new_stdin = StringIO() - conn = connection_loader.get('network_cli', pc, '/dev/null') conn.ssh = MagicMock() conn.receive = MagicMock() - conn._terminal = MagicMock() conn._connect() self.assertTrue(conn._terminal.on_open_shell.called) @@ -88,8 +78,8 @@ class TestConnectionClass(unittest.TestCase): @patch("ansible.plugins.connection.paramiko_ssh.Connection.close") def test_network_cli_close(self, mocked_super): pc = PlayContext() - new_stdin = StringIO() - conn = network_cli.Connection(pc, new_stdin) + pc.network_os = 'ios' + conn = connection_loader.get('network_cli', pc, '/dev/null') terminal = MagicMock(supports_multiplexing=False) conn._terminal = terminal @@ -105,8 +95,8 @@ class TestConnectionClass(unittest.TestCase): @patch("ansible.plugins.connection.paramiko_ssh.Connection._connect") def test_network_cli_exec_command(self, mocked_super): pc = PlayContext() - new_stdin = StringIO() - conn = network_cli.Connection(pc, new_stdin) + pc.network_os = 'ios' + conn = connection_loader.get('network_cli', pc, '/dev/null') mock_send = MagicMock(return_value=b'command response') conn.send = mock_send @@ -124,8 +114,9 @@ class TestConnectionClass(unittest.TestCase): def test_network_cli_send(self): pc = PlayContext() - new_stdin = StringIO() - conn = network_cli.Connection(pc, new_stdin) + pc.network_os = 'ios' + conn = connection_loader.get('network_cli', pc, '/dev/null') + mock__terminal = MagicMock() mock__terminal.terminal_stdout_re = [re.compile(b'device#')] mock__terminal.terminal_stderr_re = [re.compile(b'^ERROR')]