Extend SSH Retry to put_file and fetch_file (#20187)

* Move retry logic into _ssh_retry decorator, and apply to exec_command, put_file and fetch_file

* Update tests to reflect change

* Move the _ssh_retry decorator onto _run, and update tests to reflect the change

* The piped transfer method should use exec_command instead of the removed _exec_command

* Rework tests to support selectors instead of select.select
pull/22217/head
Matt Martz 7 years ago committed by James Cammarata
parent 911600acf9
commit 1fe67f9f43

@ -29,6 +29,7 @@ import socket
import subprocess import subprocess
import time import time
from functools import wraps
from ansible import constants as C from ansible import constants as C
from ansible.compat import selectors from ansible.compat import selectors
from ansible.compat.six import PY3, text_type, binary_type from ansible.compat.six import PY3, text_type, binary_type
@ -51,6 +52,54 @@ except ImportError:
SSHPASS_AVAILABLE = None SSHPASS_AVAILABLE = None
def _ssh_retry(func):
    """
    Decorator to retry ssh/scp/sftp in the case of a connection failure

    The wrapped callable must return a sequence whose first item is the
    process return code and whose third item is its stderr output.

    Will retry if:
    * an exception is caught
    * ssh returns 255
    Will not retry if
    * the retry limit (C.ANSIBLE_SSH_RETRIES) has been reached; the failure
      from the final attempt is then re-raised to the caller
    """
    @wraps(func)
    def wrapped(self, *args, **kwargs):
        # Always make at least one attempt, even when the configured retry
        # count is zero or negative; otherwise the loop below would never run
        # and there would be nothing to return.
        total_tries = max(int(C.ANSIBLE_SSH_RETRIES), 0) + 1

        # The command is normally the first positional argument but may also
        # be supplied by keyword.  Wrapping it in a 1-tuple keeps the %
        # formatting correct when the command itself is a tuple.
        cmd = args[0] if args else kwargs.get('cmd')
        cmd_summary = "%s..." % (cmd,)

        for attempt in range(total_tries):
            try:
                return_tuple = func(self, *args, **kwargs)
                display.vvv(return_tuple, host=self.host)
                # 0 = success
                # 1-254 = remote command return code
                # 255 = failure from the ssh command itself
                if return_tuple[0] != 255:
                    return return_tuple
                raise AnsibleConnectionFailure("Failed to connect to the host via ssh: %s" % to_native(return_tuple[2]))
            except (AnsibleConnectionFailure, Exception) as e:
                # Out of attempts: surface the last failure unchanged.
                if attempt == total_tries - 1:
                    raise

                # Exponential backoff (0, 1, 3, 7, 15, ...) capped at 30 seconds.
                pause = min(2 ** attempt - 1, 30)

                if isinstance(e, AnsibleConnectionFailure):
                    msg = "ssh_retry: attempt: %d, ssh return code is 255. cmd (%s), pausing for %d seconds" % (attempt, cmd_summary, pause)
                else:
                    msg = "ssh_retry: attempt: %d, caught exception(%s) from cmd (%s), pausing for %d seconds" % (attempt, e, cmd_summary, pause)

                display.vv(msg, host=self.host)
                time.sleep(pause)

    return wrapped
class Connection(ConnectionBase): class Connection(ConnectionBase):
''' ssh based connections ''' ''' ssh based connections '''
@ -352,6 +401,7 @@ class Connection(ConnectionBase):
return b''.join(output), remainder return b''.join(output), remainder
@_ssh_retry
def _run(self, cmd, in_data, sudoable=True, checkrc=True): def _run(self, cmd, in_data, sudoable=True, checkrc=True):
''' '''
Starts the command and communicates with it until it ends. Starts the command and communicates with it until it ends.
@ -618,28 +668,6 @@ class Connection(ConnectionBase):
return (p.returncode, b_stdout, b_stderr) return (p.returncode, b_stdout, b_stderr)
def _exec_command(self, cmd, in_data=None, sudoable=True):
    ''' build the ssh invocation for cmd, run it, and return (returncode, stdout, stderr) '''

    super(Connection, self).exec_command(cmd, in_data=in_data, sudoable=sudoable)

    display.vvv(u"ESTABLISH SSH CONNECTION FOR USER: {0}".format(self._play_context.remote_user), host=self._play_context.remote_addr)

    # A tty is requested only when nothing is being pipelined in: feeding
    # module source to /usr/bin/python through a tty starts the interactive
    # interpreter, which rejects the modules ("unexpected indent", mainly
    # because of their empty lines).
    if not in_data and sudoable:
        ssh_prefix = ('ssh', '-tt')
    else:
        ssh_prefix = ('ssh',)
    ssh_args = ssh_prefix + (self.host, cmd)

    full_cmd = self._build_command(*ssh_args)
    (rc, out, err) = self._run(full_cmd, in_data, sudoable=sudoable)

    return (rc, out, err)
def _file_transport_command(self, in_path, out_path, sftp_action): def _file_transport_command(self, in_path, out_path, sftp_action):
# scp and sftp require square brackets for IPv6 addresses, but # scp and sftp require square brackets for IPv6 addresses, but
# accept them for hostnames and IPv4 addresses too. # accept them for hostnames and IPv4 addresses too.
@ -674,7 +702,6 @@ class Connection(ConnectionBase):
methods = ['sftp'] methods = ['sftp']
success = False success = False
res = None
for method in methods: for method in methods:
returncode = stdout = stderr = None returncode = stdout = stderr = None
if method == 'sftp': if method == 'sftp':
@ -693,77 +720,58 @@ class Connection(ConnectionBase):
if sftp_action == 'get': if sftp_action == 'get':
# we pass sudoable=False to disable pty allocation, which # we pass sudoable=False to disable pty allocation, which
# would end up mixing stdout/stderr and screwing with newlines # would end up mixing stdout/stderr and screwing with newlines
(returncode, stdout, stderr) = self._exec_command('dd if=%s bs=%s' % (in_path, BUFSIZE), sudoable=False) (returncode, stdout, stderr) = self.exec_command('dd if=%s bs=%s' % (in_path, BUFSIZE), sudoable=False)
out_file = open(to_bytes(out_path, errors='surrogate_or_strict'), 'wb+') out_file = open(to_bytes(out_path, errors='surrogate_or_strict'), 'wb+')
out_file.write(stdout) out_file.write(stdout)
out_file.close() out_file.close()
else: else:
in_data = open(to_bytes(in_path, errors='surrogate_or_strict'), 'rb').read() in_data = open(to_bytes(in_path, errors='surrogate_or_strict'), 'rb').read()
in_data = to_bytes(in_data, nonstring='passthru') in_data = to_bytes(in_data, nonstring='passthru')
(returncode, stdout, stderr) = self._exec_command('dd of=%s bs=%s' % (out_path, BUFSIZE), in_data=in_data) (returncode, stdout, stderr) = self.exec_command('dd of=%s bs=%s' % (out_path, BUFSIZE), in_data=in_data)
# Check the return code and rollover to next method if failed # Check the return code and rollover to next method if failed
if returncode == 0: if returncode == 0:
success = True return (returncode, stdout, stderr)
break
else: else:
# If not in smart mode, the data will be printed by the raise below # If not in smart mode, the data will be printed by the raise below
if len(methods) > 1: if len(methods) > 1:
display.warning(msg='%s transfer mechanism failed on %s. Use ANSIBLE_DEBUG=1 to see detailed information' % (method, host)) display.warning(msg='%s transfer mechanism failed on %s. Use ANSIBLE_DEBUG=1 to see detailed information' % (method, host))
display.debug(msg='%s' % to_native(stdout)) display.debug(msg='%s' % to_native(stdout))
display.debug(msg='%s' % to_native(stderr)) display.debug(msg='%s' % to_native(stderr))
res = (returncode, stdout, stderr)
if not success: if returncode == 255:
raise AnsibleError("failed to transfer file {0} to {1}:\n{2}\n{3}"\ raise AnsibleConnectionFailure("Failed to connect to the host via %s: %s" % (method, to_native(stderr)))
.format(to_native(in_path), to_native(out_path), to_native(res[1]), to_native(res[2]))) else:
raise AnsibleError("failed to transfer file to {0} {1}:\n{2}\n{3}"\
.format(to_native(in_path), to_native(out_path), to_native(stdout), to_native(stderr)))
# #
# Main public methods # Main public methods
# #
def exec_command(self, *args, **kwargs): def exec_command(self, cmd, in_data=None, sudoable=True):
""" ''' run a command on the remote host '''
Wrapper around _exec_command to retry in the case of an ssh failure
Will retry if:
* an exception is caught
* ssh returns 255
Will not retry if
* remaining_tries is <2
* retries limit reached
"""
remaining_tries = int(C.ANSIBLE_SSH_RETRIES) + 1 super(Connection, self).exec_command(cmd, in_data=in_data, sudoable=sudoable)
cmd_summary = "%s..." % args[0]
for attempt in range(remaining_tries):
try:
return_tuple = self._exec_command(*args, **kwargs)
# 0 = success
# 1-254 = remote command return code
# 255 = failure from the ssh command itself
if return_tuple[0] != 255:
break
else:
raise AnsibleConnectionFailure("Failed to connect to the host via ssh: %s" % to_native(return_tuple[2]))
except (AnsibleConnectionFailure, Exception) as e:
if attempt == remaining_tries - 1:
raise
else:
pause = 2 ** attempt - 1
if pause > 30:
pause = 30
if isinstance(e, AnsibleConnectionFailure): display.vvv(u"ESTABLISH SSH CONNECTION FOR USER: {0}".format(self._play_context.remote_user), host=self._play_context.remote_addr)
msg = "ssh_retry: attempt: %d, ssh return code is 255. cmd (%s), pausing for %d seconds" % (attempt, cmd_summary, pause)
else:
msg = "ssh_retry: attempt: %d, caught exception(%s) from cmd (%s), pausing for %d seconds" % (attempt, e, cmd_summary, pause)
display.vv(msg, host=self.host)
time.sleep(pause) # we can only use tty when we are not pipelining the modules. piping
continue # data into /usr/bin/python inside a tty automatically invokes the
# python interactive-mode but the modules are not compatible with the
# interactive-mode ("unexpected indent" mainly because of empty lines)
return return_tuple ssh_executable = self._play_context.ssh_executable
if not in_data and sudoable:
args = (ssh_executable, '-tt', self.host, cmd)
else:
args = (ssh_executable, self.host, cmd)
cmd = self._build_command(*args)
(returncode, stdout, stderr) = self._run(cmd, in_data, sudoable=sudoable)
return (returncode, stdout, stderr)
def put_file(self, in_path, out_path): def put_file(self, in_path, out_path):
''' transfer a file from local to remote ''' ''' transfer a file from local to remote '''
@ -774,7 +782,7 @@ class Connection(ConnectionBase):
if not os.path.exists(to_bytes(in_path, errors='surrogate_or_strict')): if not os.path.exists(to_bytes(in_path, errors='surrogate_or_strict')):
raise AnsibleFileNotFound("file or module does not exist: {0}".format(to_native(in_path))) raise AnsibleFileNotFound("file or module does not exist: {0}".format(to_native(in_path)))
self._file_transport_command(in_path, out_path, 'put') return self._file_transport_command(in_path, out_path, 'put')
def fetch_file(self, in_path, out_path): def fetch_file(self, in_path, out_path):
''' fetch a file from remote to local ''' ''' fetch a file from remote to local '''
@ -782,7 +790,7 @@ class Connection(ConnectionBase):
super(Connection, self).fetch_file(in_path, out_path) super(Connection, self).fetch_file(in_path, out_path)
display.vvv(u"FETCH {0} TO {1}".format(in_path, out_path), host=self.host) display.vvv(u"FETCH {0} TO {1}".format(in_path, out_path), host=self.host)
self._file_transport_command(in_path, out_path, 'get') return self._file_transport_command(in_path, out_path, 'get')
def reset(self): def reset(self):
# If we have a persistent ssh connection (ControlPersist), we can ask it to stop listening. # If we have a persistent ssh connection (ControlPersist), we can ask it to stop listening.

@ -25,7 +25,7 @@ from io import StringIO
import pytest import pytest
from ansible.compat.tests import unittest from ansible.compat.tests import unittest
from ansible.compat.tests.mock import patch, MagicMock from ansible.compat.tests.mock import patch, MagicMock, PropertyMock
from ansible import constants as C from ansible import constants as C
from ansible.compat.selectors import SelectorKey, EVENT_READ from ansible.compat.selectors import SelectorKey, EVENT_READ
@ -72,7 +72,7 @@ class TestConnectionBaseClass(unittest.TestCase):
conn = ssh.Connection(pc, new_stdin) conn = ssh.Connection(pc, new_stdin)
conn._build_command('ssh') conn._build_command('ssh')
def test_plugins_connection_ssh__exec_command(self): def test_plugins_connection_ssh_exec_command(self):
pc = PlayContext() pc = PlayContext()
new_stdin = StringIO() new_stdin = StringIO()
conn = ssh.Connection(pc, new_stdin) conn = ssh.Connection(pc, new_stdin)
@ -82,8 +82,8 @@ class TestConnectionBaseClass(unittest.TestCase):
conn._run = MagicMock() conn._run = MagicMock()
conn._run.return_value = (0, 'stdout', 'stderr') conn._run.return_value = (0, 'stdout', 'stderr')
res, stdout, stderr = conn._exec_command('ssh') res, stdout, stderr = conn.exec_command('ssh')
res, stdout, stderr = conn._exec_command('ssh', 'this is some data') res, stdout, stderr = conn.exec_command('ssh', 'this is some data')
def test_plugins_connection_ssh__examine_output(self): def test_plugins_connection_ssh__examine_output(self):
pc = PlayContext() pc = PlayContext()
@ -193,36 +193,8 @@ class TestConnectionBaseClass(unittest.TestCase):
self.assertTrue(conn._flags['become_nopasswd_error']) self.assertTrue(conn._flags['become_nopasswd_error'])
@patch('time.sleep') @patch('time.sleep')
def test_plugins_connection_ssh_exec_command(self, mock_sleep):
pc = PlayContext()
new_stdin = StringIO()
conn = ssh.Connection(pc, new_stdin)
conn._build_command = MagicMock()
conn._exec_command = MagicMock()
C.ANSIBLE_SSH_RETRIES = 9
# test a regular, successful execution
conn._exec_command.return_value = (0, b'stdout', b'')
res = conn.exec_command('ssh', 'some data')
self.assertEquals(res, (0, b'stdout', b''), msg='exec_command did not return what the _exec_command helper returned')
# test a retry, followed by success
conn._exec_command.return_value = None
conn._exec_command.side_effect = [(255, '', ''), (0, b'stdout', b'')]
res = conn.exec_command('ssh', 'some data')
self.assertEquals(res, (0, b'stdout', b''), msg='exec_command did not return what the _exec_command helper returned')
# test multiple failures
conn._exec_command.side_effect = [(255, b'', b'')] * 10
self.assertRaises(AnsibleConnectionFailure, conn.exec_command, 'ssh', 'some data')
# test other failure from exec_command
conn._exec_command.side_effect = [Exception('bad')] * 10
self.assertRaises(Exception, conn.exec_command, 'ssh', 'some data')
@patch('os.path.exists') @patch('os.path.exists')
def test_plugins_connection_ssh_put_file(self, mock_ospe): def test_plugins_connection_ssh_put_file(self, mock_ospe, mock_sleep):
pc = PlayContext() pc = PlayContext()
new_stdin = StringIO() new_stdin = StringIO()
conn = ssh.Connection(pc, new_stdin) conn = ssh.Connection(pc, new_stdin)
@ -234,6 +206,8 @@ class TestConnectionBaseClass(unittest.TestCase):
conn._run.return_value = (0, '', '') conn._run.return_value = (0, '', '')
conn.host = "some_host" conn.host = "some_host"
C.ANSIBLE_SSH_RETRIES = 9
# Test with C.DEFAULT_SCP_IF_SSH set to smart # Test with C.DEFAULT_SCP_IF_SSH set to smart
# Test when SFTP works # Test when SFTP works
C.DEFAULT_SCP_IF_SSH = 'smart' C.DEFAULT_SCP_IF_SSH = 'smart'
@ -276,7 +250,8 @@ class TestConnectionBaseClass(unittest.TestCase):
conn._run.return_value = (0, 'stdout', '') conn._run.return_value = (0, 'stdout', '')
self.assertRaises(AnsibleFileNotFound, conn.put_file, '/path/to/bad/file', '/remote/path/to/file') self.assertRaises(AnsibleFileNotFound, conn.put_file, '/path/to/bad/file', '/remote/path/to/file')
def test_plugins_connection_ssh_fetch_file(self): @patch('time.sleep')
def test_plugins_connection_ssh_fetch_file(self, mock_sleep):
pc = PlayContext() pc = PlayContext()
new_stdin = StringIO() new_stdin = StringIO()
conn = ssh.Connection(pc, new_stdin) conn = ssh.Connection(pc, new_stdin)
@ -287,6 +262,8 @@ class TestConnectionBaseClass(unittest.TestCase):
conn._run.return_value = (0, '', '') conn._run.return_value = (0, '', '')
conn.host = "some_host" conn.host = "some_host"
C.ANSIBLE_SSH_RETRIES = 9
# Test with C.DEFAULT_SCP_IF_SSH set to smart # Test with C.DEFAULT_SCP_IF_SSH set to smart
# Test when SFTP works # Test when SFTP works
C.DEFAULT_SCP_IF_SSH = 'smart' C.DEFAULT_SCP_IF_SSH = 'smart'
@ -535,3 +512,120 @@ class TestSSHConnectionRun(object):
assert self.mock_selector.register.called is True assert self.mock_selector.register.called is True
assert self.mock_selector.register.call_count == 2 assert self.mock_selector.register.call_count == 2
assert self.conn._send_initial_data.called is False assert self.conn._send_initial_data.called is False
@pytest.mark.usefixtures('mock_run_env')
class TestSSHConnectionRetries(object):
    """
    Exercise the ssh retry behaviour through exec_command, put_file and
    fetch_file.

    The tests use self.conn, self.mock_popen, self.mock_popen_res and
    self.mock_selector; these are presumably installed by the mock_run_env
    fixture (confirm against its definition).
    """

    def _stub_build_command(self):
        ''' make _build_command return a fixed command so no real ssh arguments are assembled '''
        self.conn._build_command = MagicMock()
        self.conn._build_command.return_value = 'ssh'

    def _script_failure_then_success(self):
        ''' script the mocked process: empty output with rc 255 first, then real output with rc 0 '''
        # Pin the retry count explicitly so the test does not depend on a
        # value leaked by a previously executed test.
        C.ANSIBLE_SSH_RETRIES = 9

        stdout = self.mock_popen_res.stdout
        stderr = self.mock_popen_res.stderr

        def stdout_ready():
            return [(SelectorKey(stdout, 1001, [EVENT_READ], None), EVENT_READ)]

        def stderr_ready():
            return [(SelectorKey(stderr, 1002, [EVENT_READ], None), EVENT_READ)]

        stdout.read.side_effect = [b"", b"my_stdout\n", b"second_line"]
        stderr.read.side_effect = [b"", b"my_stderr"]
        # returncode is read more than once per run, hence the repeated values.
        type(self.mock_popen_res).returncode = PropertyMock(side_effect=[255] * 3 + [0] * 4)
        # An empty event list ends each run of the selector loop.
        self.mock_selector.select.side_effect = [
            # first (failing) run
            stdout_ready(),
            stderr_ready(),
            [],
            # second (successful) run
            stdout_ready(),
            stdout_ready(),
            stderr_ready(),
            [],
        ]
        self.mock_selector.get_map.side_effect = lambda: True

        self._stub_build_command()

    @patch('time.sleep')
    def test_retry_then_success(self, mock_sleep):
        self._script_failure_then_success()

        return_code, b_stdout, b_stderr = self.conn.exec_command('ssh', 'some data')
        assert return_code == 0
        assert b_stdout == b'my_stdout\nsecond_line'
        assert b_stderr == b'my_stderr'

    @patch('time.sleep')
    def test_multiple_failures(self, mock_sleep):
        C.ANSIBLE_SSH_RETRIES = 9

        self.mock_popen_res.stdout.read.side_effect = [b""] * 11
        self.mock_popen_res.stderr.read.side_effect = [b""] * 11
        type(self.mock_popen_res).returncode = PropertyMock(side_effect=[255] * 30)

        self.mock_selector.select.side_effect = [
            [(SelectorKey(self.mock_popen_res.stdout, 1001, [EVENT_READ], None), EVENT_READ)],
            [(SelectorKey(self.mock_popen_res.stderr, 1002, [EVENT_READ], None), EVENT_READ)],
            [],
        ] * 10
        self.mock_selector.get_map.side_effect = lambda: True

        self._stub_build_command()

        # 9 retries + the initial attempt = 10 process launches before giving up
        pytest.raises(AnsibleConnectionFailure, self.conn.exec_command, 'ssh', 'some data')
        assert self.mock_popen.call_count == 10

    @patch('time.sleep')
    def test_arbitrary_exceptions(self, mock_sleep):
        C.ANSIBLE_SSH_RETRIES = 9

        self._stub_build_command()

        # Any exception raised while launching the process is retried as well
        self.mock_popen.side_effect = [Exception('bad')] * 10
        pytest.raises(Exception, self.conn.exec_command, 'ssh', 'some data')
        assert self.mock_popen.call_count == 10

    @patch('time.sleep')
    @patch('ansible.plugins.connection.ssh.os')
    def test_put_file_retries(self, os_mock, time_mock):
        os_mock.path.exists.return_value = True
        self._script_failure_then_success()

        return_code, b_stdout, b_stderr = self.conn.put_file('/path/to/in/file', '/path/to/dest/file')
        assert return_code == 0
        assert b_stdout == b"my_stdout\nsecond_line"
        assert b_stderr == b"my_stderr"
        assert self.mock_popen.call_count == 2

    @patch('time.sleep')
    @patch('ansible.plugins.connection.ssh.os')
    def test_fetch_file_retries(self, os_mock, time_mock):
        os_mock.path.exists.return_value = True
        self._script_failure_then_success()

        return_code, b_stdout, b_stderr = self.conn.fetch_file('/path/to/in/file', '/path/to/dest/file')
        assert return_code == 0
        assert b_stdout == b"my_stdout\nsecond_line"
        assert b_stderr == b"my_stderr"
        assert self.mock_popen.call_count == 2

Loading…
Cancel
Save