@ -1,7 +1,10 @@
import os
import random
import re
import socket
import sys
import time
import unittest
import urlparse
@ -27,6 +30,75 @@ def data_path(suffix):
return path
def wait_for_port (
host ,
port ,
pattern = None ,
connect_timeout = 0.5 ,
receive_timeout = 0.5 ,
overall_timeout = 5.0 ,
sleep = 0.1 ,
) :
""" Attempt to connect to host/port, for upto overall_timeout seconds.
If a regex pattern is supplied try to find it in the initial data .
Return None on success , or raise on error .
"""
start = time . time ( )
end = start + overall_timeout
addr = ( host , port )
while time . time ( ) < end :
sock = socket . socket ( socket . AF_INET , socket . SOCK_STREAM )
sock . settimeout ( connect_timeout )
try :
sock . connect ( addr )
except socket . error :
# Failed to connect. So wait then retry.
time . sleep ( sleep )
continue
if not pattern :
# Success: We connected & there's no banner check to perform.
sock . shutdown ( socket . SHUTD_RDWR )
sock . close ( )
return
sock . settimeout ( receive_timeout )
data = ' '
found = False
while time . time ( ) < end :
try :
resp = sock . recv ( 1024 )
except socket . timeout :
# Server stayed up, but had no data. Retry the recv().
continue
if not resp :
# Server went away. Wait then retry the connection.
time . sleep ( sleep )
break
data + = resp
if re . search ( pattern , data ) :
found = True
break
sock . shutdown ( socket . SHUT_RDWR )
sock . close ( )
if found :
# Success: We received the banner & found the desired pattern
return
else :
# Failure: The overall timeout expired
if pattern :
raise socket . timeout ( ' Timed out while searching for %r from %s : %s '
% ( pattern , host , port ) )
else :
raise socket . timeout ( ' Timed out while connecting to %s : %s '
% ( host , port ) )
class TestCase ( unittest . TestCase ) :
def assertRaises ( self , exc , func , * args , * * kwargs ) :
""" Like regular assertRaises, except return the exception that was
@ -61,6 +133,9 @@ class DockerizedSshDaemon(object):
parsed = urlparse . urlparse ( self . docker . api . base_url )
return parsed . netloc . partition ( ' : ' ) [ 0 ]
def wait_for_sshd ( self ) :
wait_for_port ( self . get_host ( ) , int ( self . port ) , pattern = ' OpenSSH ' )
def close ( self ) :
self . container . stop ( )
self . container . remove ( )
@ -86,6 +161,7 @@ class DockerMixin(RouterMixin):
def setUpClass ( cls ) :
super ( DockerMixin , cls ) . setUpClass ( )
cls . dockerized_ssh = DockerizedSshDaemon ( )
cls . dockerized_ssh . wait_for_sshd ( )
@classmethod
def tearDownClass ( cls ) :