From dc23667cc2d56e243ba97436d67120ccfbb9d78c Mon Sep 17 00:00:00 2001 From: Peter Sprygada Date: Mon, 5 Dec 2016 21:42:09 -0500 Subject: [PATCH] add back reverted change to network_cli (#18761) This adds back the change to the network_cli plugin. Ths change adds the ensure_connect decorator to the open_shell() method to make sure the connection is valid before trying to open a shell. The issue was due to the addition of the decorator that will call _connect() when there is no connection. The _connect() method should have been mocked in the test case. This commit fixes the test case as well Change was originally reverted in https://github.com/ansible/ansible/commit/c414ded69a6b12fc22b4b2756e24cce9926bc84c --- lib/ansible/plugins/connection/network_cli.py | 43 ++++++++++++++++--- .../plugins/connection/test_network_cli.py | 19 +++++--- 2 files changed, 48 insertions(+), 14 deletions(-) diff --git a/lib/ansible/plugins/connection/network_cli.py b/lib/ansible/plugins/connection/network_cli.py index 403358738ea..b65c2ba7b6a 100644 --- a/lib/ansible/plugins/connection/network_cli.py +++ b/lib/ansible/plugins/connection/network_cli.py @@ -27,11 +27,12 @@ import datetime from ansible.errors import AnsibleConnectionFailure from ansible.module_utils.six.moves import StringIO from ansible.plugins import terminal_loader +from ansible.plugins.connection import ensure_connect from ansible.plugins.connection.paramiko_ssh import Connection as _Connection class Connection(_Connection): - ''' CLI SSH based connections on Paramiko ''' + ''' CLI (shell) SSH connections on Paramiko ''' transport = 'network_cli' has_pipelining = False @@ -47,6 +48,7 @@ class Connection(_Connection): self._history = list() def update_play_context(self, play_context): + """Updates the play context information for the connection""" if self._play_context.become is False and play_context.become is True: auth_pass = play_context.become_pass self._terminal.on_authorize(passwd=auth_pass) @@ -57,6 +59,7 @@ class Connection(_Connection): self._play_context = play_context def _connect(self): + """Connections to the device and sets the terminal type""" super(Connection, self)._connect() network_os = self._play_context.network_os @@ -78,7 +81,9 @@ class Connection(_Connection): return (0, 'connected', '') - def open_shell(self, timeout=10): + @ensure_connect + def open_shell(self): + """Opens the vty shell on the connection""" self._shell = self.ssh.invoke_shell() self._shell.settimeout(self._play_context.timeout) @@ -87,16 +92,16 @@ class Connection(_Connection): if self._shell: self._terminal.on_open_shell() - if hasattr(self._play_context, 'become'): - if self._play_context.become: - auth_pass = self._play_context.become_pass - self._terminal.on_authorize(passwd=auth_pass) + if getattr(self._play_context, 'become', None): + auth_pass = self._play_context.become_pass + self._terminal.on_authorize(passwd=auth_pass) def close(self): self.close_shell() super(Connection, self).close() def close_shell(self): + """Closes the vty shell if the device supports multiplexing""" if self._shell: self._terminal.on_close_shell() @@ -107,6 +112,7 @@ class Connection(_Connection): return (0, 'shell closed', '') def receive(self, obj=None): + """Handles receiving of output from command""" recv = StringIO() handled = False @@ -130,6 +136,7 @@ class Connection(_Connection): return self._sanitize(resp, obj) def send(self, obj): + """Sends the command to the device in the opened shell""" try: command = obj['command'] self._history.append(command) @@ -139,11 +146,13 @@ class Connection(_Connection): raise AnsibleConnectionFailure("timeout trying to send command: %s" % command.strip()) def _strip(self, data): + """Removes ANSI codes from device response""" for regex in self._terminal.ansi_re: data = regex.sub('', data) return data def _handle_prompt(self, resp, obj): + """Matches the command prompt and responds""" prompt = re.compile(obj['prompt'], re.I) answer = obj['answer'] match = prompt.search(resp) @@ -152,6 +161,7 @@ class Connection(_Connection): return True def _sanitize(self, resp, obj=None): + """Removes elements from the response before returning to the caller""" cleaned = [] command = obj.get('command') if obj else None for line in resp.splitlines(): @@ -161,6 +171,7 @@ class Connection(_Connection): return str("\n".join(cleaned)).strip() def _find_prompt(self, response): + """Searches the buffered response for a matching command prompt""" for regex in self._terminal.terminal_errors_re: if regex.search(response): raise AnsibleConnectionFailure(response) @@ -173,10 +184,28 @@ class Connection(_Connection): return True def alarm_handler(self, signum, frame): + """Alarm handler raised in case of command timeout """ self.close_shell() def exec_command(self, cmd): - ''' {'command': , 'prompt': , 'answer': } ''' + """Executes the cmd on in the shell and returns the output + + The method accepts two forms of cmd. The first form is as a + string that represents the command to be executed in the shell. The + second form is as a JSON string with additional keyword. + + Keywords supported for cmd: + * command - the command string to execute + * prompt - the expected prompt generated by executing command + * response - the string to respond to the prompt with + + :arg cmd: the string that represents the command to be executed + which can be a single command or a json encoded string + :returns: a tuple of (return code, stdout, stderr). The return + code is an integer and stdout and stderr are strings + """ + # TODO: add support for timeout to the cmd to handle non return + # commands such as a system restart try: obj = json.loads(cmd) diff --git a/test/units/plugins/connection/test_network_cli.py b/test/units/plugins/connection/test_network_cli.py index 627a0a144a8..fd3e0219051 100644 --- a/test/units/plugins/connection/test_network_cli.py +++ b/test/units/plugins/connection/test_network_cli.py @@ -33,7 +33,6 @@ from ansible.playbook.play_context import PlayContext from ansible.plugins.connection import network_cli - class TestConnectionClass(unittest.TestCase): @patch("ansible.plugins.connection.network_cli.terminal_loader") @@ -69,21 +68,27 @@ class TestConnectionClass(unittest.TestCase): conn.ssh = MagicMock() conn.receive = MagicMock() - terminal = MagicMock() - conn._terminal = terminal + mock_terminal = MagicMock() + conn._terminal = mock_terminal + + mock__connect = MagicMock() + conn._connect = mock__connect conn.open_shell() - self.assertTrue(terminal.on_open_shell.called) - self.assertFalse(terminal.on_authorize.called) + self.assertTrue(mock__connect.called) + self.assertTrue(mock_terminal.on_open_shell.called) + self.assertFalse(mock_terminal.on_authorize.called) + + mock_terminal.reset_mock() - terminal.reset_mock() conn._play_context.become = True conn._play_context.become_pass = 'password' conn.open_shell() - terminal.on_authorize.assert_called_with(passwd='password') + self.assertTrue(mock__connect.called) + mock_terminal.on_authorize.assert_called_with(passwd='password') def test_network_cli_close_shell(self): pc = PlayContext()