testlib: Wait for sshd before running tests

On Ubuntu 17.10 something (probably Docker) appears to be accepting
connections, before sshd is fully ready. This results in a race
condition, and hence connection errors for the first few tests (2-3 on
my laptop).

testlib.wait_for_port() checks not only that the port can be connected
to, but also something resembling the sshd banner is sent.

Fixes #51
wip-fakessh-exit-status
Alex Willmer 7 years ago committed by David Wilson
parent f8a84616d7
commit dfc7b85504

@ -1,7 +1,10 @@
import os import os
import random import random
import re
import socket
import sys import sys
import time
import unittest import unittest
import urlparse import urlparse
@ -27,6 +30,75 @@ def data_path(suffix):
return path return path
def wait_for_port(
host,
port,
pattern=None,
connect_timeout=0.5,
receive_timeout=0.5,
overall_timeout=5.0,
sleep=0.1,
):
"""Attempt to connect to host/port, for upto overall_timeout seconds.
If a regex pattern is supplied try to find it in the initial data.
Return None on success, or raise on error.
"""
start = time.time()
end = start + overall_timeout
addr = (host, port)
while time.time() < end:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(connect_timeout)
try:
sock.connect(addr)
except socket.error:
# Failed to connect. So wait then retry.
time.sleep(sleep)
continue
if not pattern:
# Success: We connected & there's no banner check to perform.
sock.shutdown(socket.SHUTD_RDWR)
sock.close()
return
sock.settimeout(receive_timeout)
data = ''
found = False
while time.time() < end:
try:
resp = sock.recv(1024)
except socket.timeout:
# Server stayed up, but had no data. Retry the recv().
continue
if not resp:
# Server went away. Wait then retry the connection.
time.sleep(sleep)
break
data += resp
if re.search(pattern, data):
found = True
break
sock.shutdown(socket.SHUT_RDWR)
sock.close()
if found:
# Success: We received the banner & found the desired pattern
return
else:
# Failure: The overall timeout expired
if pattern:
raise socket.timeout('Timed out while searching for %r from %s:%s'
% (pattern, host, port))
else:
raise socket.timeout('Timed out while connecting to %s:%s'
% (host, port))
class TestCase(unittest.TestCase): class TestCase(unittest.TestCase):
def assertRaises(self, exc, func, *args, **kwargs): def assertRaises(self, exc, func, *args, **kwargs):
"""Like regular assertRaises, except return the exception that was """Like regular assertRaises, except return the exception that was
@ -61,6 +133,9 @@ class DockerizedSshDaemon(object):
parsed = urlparse.urlparse(self.docker.api.base_url) parsed = urlparse.urlparse(self.docker.api.base_url)
return parsed.netloc.partition(':')[0] return parsed.netloc.partition(':')[0]
def wait_for_sshd(self):
wait_for_port(self.get_host(), int(self.port), pattern='OpenSSH')
def close(self): def close(self):
self.container.stop() self.container.stop()
self.container.remove() self.container.remove()
@ -86,6 +161,7 @@ class DockerMixin(RouterMixin):
def setUpClass(cls): def setUpClass(cls):
super(DockerMixin, cls).setUpClass() super(DockerMixin, cls).setUpClass()
cls.dockerized_ssh = DockerizedSshDaemon() cls.dockerized_ssh = DockerizedSshDaemon()
cls.dockerized_ssh.wait_for_sshd()
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):

Loading…
Cancel
Save