@ -61,11 +61,28 @@ class CorruptMessageError(StreamError):
class TimeoutError ( StreamError ) :
' Raised when a timeout occurs on a stream. '
class CallError ( ContextError ) :
' Raised when .Call() fails '
def __init__ ( self , e ) :
name = ' %s . %s ' % ( type ( e ) . __module__ , type ( e ) . __name__ )
stack = ' ' . join ( traceback . format_stack ( sys . exc_info [ 2 ] ) )
ContextError . __init__ ( self , ' Call failed: %s : %s \n %s ' , name , e , stack )
#
# Helpers.
#
class Dead ( object ) :
def __eq__ ( self , other ) :
return type ( other ) is Dead
def __repr__ ( self ) :
return ' <Dead> '
_DEAD = Dead ( )
def write_all ( fd , s ) :
written = 0
while written < len ( s ) :
@ -117,45 +134,32 @@ class Formatter(logging.Formatter):
class Channel ( object ) :
def __init__ ( self , stream , handle ) :
self . _context = stream . _context
self . _stream = stream
def __init__ ( self , context , handle ) :
self . _context = context
self . _handle = handle
self . _queue = Queue . Queue ( )
self . _context . AddHandleCB ( self . _ Internal Receive, handle )
self . _context . AddHandleCB ( self . _ Receive, handle )
def _ Internal Receive( self , killed , data ) :
def _ Receive( self , data ) :
"""
Callback from the stream object ; appends a tuple of
( killed - or - closed , data ) to the internal queue and wakes the internal
event .
Args :
# Has the Stream object lost its connection?
killed : bool
data : (
# Has the remote Channel had Close() called?
bool ,
# The object passed to the remote Send()
object
)
Callback from the Stream ; appends data to the internal queue .
"""
LOG . debug ( ' %r ._ Internal Receive(%r , %r )' , self , killed , data )
self . _queue . put ( ( killed or data [ 0 ] , killed or data [ 1 ] ) )
LOG . debug ( ' %r ._Receive( %r ) ' , self , data )
self . _queue . put ( data )
def Close ( self ) :
"""
Indicate this channel is closed to the remote side .
"""
LOG . debug ( ' %r .Close() ' , self )
self . _ stream. Enqueue ( handle , ( True , None ) )
self . _context . Enqueue ( handle , _DEAD )
def Send ( self , data ) :
"""
Send ` data ` to the remote .
"""
LOG . debug ( ' %r .Send( %r ) ' , self , data )
self . _ stream . Enqueue ( handle , ( False , data ) )
self . _ context . Enqueue ( handle , data )
def Receive ( self , timeout = None ) :
"""
@ -164,12 +168,12 @@ class Channel(object):
"""
LOG . debug ( ' %r .Receive( %r ) ' , self , timeout )
try :
killed, data = self . _queue . get ( True , timeout )
data = self . _queue . get ( True , timeout )
except Queue . Empty :
return
LOG . debug ( ' %r .Receive() got killed=%r , data= %r ' , self , killed , data )
if kille d:
LOG . debug ( ' %r .Receive() got %r ' , self , data )
if data == _DEAD :
raise ChannelError ( ' Channel is closed. ' )
return data
@ -185,7 +189,7 @@ class Channel(object):
return
def __repr__ ( self ) :
return ' econtext. Channel(%r , %r ) ' % ( self . _ stream , self . _handle )
return ' Channel(%r , %r ) ' % ( self . _ context , self . _handle )
class SlaveModuleImporter ( object ) :
@ -212,7 +216,7 @@ class SlaveModuleImporter(object):
if ret is None :
raise ImportError ( ' Master does not have %r ' % ( fullname , ) )
kind, path, data = ret
path, data = ret
code = compile ( zlib . decompress ( data ) , path , ' exec ' )
module = imp . new_module ( fullname )
sys . modules [ fullname ] = module
@ -223,30 +227,33 @@ class SlaveModuleImporter(object):
class MasterModuleResponder ( object ) :
def __init__ ( self , context ) :
self . _context = context
self . _context . AddHandleCB ( self . GetModule , handle = GET_MODULE )
def GetModule ( self , killed, data) :
if kille d:
def GetModule ( self , data) :
if data == _DEAD :
return
_ , ( reply_to , fullname ) = data
LOG . debug ( ' SlaveModuleImporter.GetModule( %r , %r ) ' , killed , fullname )
mod = sys . modules . get ( fullname )
if mod :
source = zlib . compress ( inspect . getsource ( mod ) )
path = os . path . abspath ( mod . __file__ )
self . _context . Enqueue ( reply_to , ( ' source ' , path , source ) )
reply_to , fullname = data
LOG . debug ( ' SlaveModuleImporter.GetModule( %r , %r ) ' , reply_to , fullname )
try :
module = __import__ ( fullname )
source = zlib . compress ( inspect . getsource ( module ) )
self . _context . Enqueue ( reply_to , ( module . __file__ , source ) )
except Exception , e :
LOG . exception ( ' While importing %r ' , fullname )
self . _context . Enqueue ( reply_to , None )
class LogForwarder ( object ) :
def __init__ ( self , context ) :
self . _context = context
self . _context . AddHandleCB ( self . ForwardLog , handle = FORWARD_LOG )
def ForwardLog ( self , killed, data) :
if kille d:
def ForwardLog ( self , data) :
if data == _DEAD :
return
_ , ( s , ) = data
LOG . debug ( ' %r : %s ' , self . _context , s )
LOG . debug ( ' %r : %s ' , self . _context , data )
#
@ -305,7 +312,7 @@ class Stream(BasicStream):
def Pickle ( self , obj ) :
"""
Serialize ` obj ` using the pickler .
Serialize ` obj ` into a bytestring .
"""
self . _pickler . dump ( obj )
data = self . _pickler_file . getvalue ( )
@ -315,15 +322,14 @@ class Stream(BasicStream):
def Unpickle ( self , data ) :
"""
Un serialize ` data ` into an object using the unpickler .
De serialize ` data ` into an object .
"""
LOG . debug ( ' %r .Unpickle( %r ) ' , self , data )
self . _unpickler_file . write( data )
self . _unpickler_file . truncate( 0 )
self . _unpickler_file . seek ( 0 )
data = self . _unpickler . load ( )
self . _unpickler_file . write ( data )
self . _unpickler_file . seek ( 0 )
self . _unpickler_file . truncate ( 0 )
return data
return self . _unpickler . load ( )
def Receive ( self ) :
"""
@ -349,27 +355,28 @@ class Stream(BasicStream):
self . _rhmac . update ( self . _input_buf [ 20 : msg_len + 24 ] )
expected_mac = self . _rhmac . digest ( )
if msg_mac != expected_mac :
raise CorruptMessageError ( ' %r got invalid MAC: expected %r , got %r ' ,
raise CorruptMessageError ( ' %r invalid MAC: expected %r , got %r ' ,
self , msg_mac . encode ( ' hex ' ) ,
expected_mac . encode ( ' hex ' ) )
try :
handle , data = self . Unpickle ( self . _input_buf [ 24 : msg_len + 24 ] )
self . _input_buf = self . _input_buf [ msg_len + 24 : ]
handle = long ( handle )
except ( TypeError , ValueError ) , ex :
raise CorruptMessageError ( ' %r got invalid message: %s ' , self , ex )
self . _input_buf = self . _input_buf [ msg_len + 24 : ]
self . _Invoke ( handle , data )
LOG . debug ( ' %r .Receive(): decoded handle= %r ; data= %r ' ,
self , handle , data )
def _Invoke ( self , handle , data ) :
LOG . debug ( ' %r ._Invoke(): handle= %r ; data= %r ' , self , handle , data )
try :
persist , fn = self . _context . _handle_map [ handle ]
if not persist :
del self . _context . _handle_map [ handle ]
except KeyError , ex :
raise CorruptMessageError ( ' %r got invalid handle: %r ' , self , handle )
except ( TypeError , ValueError ) , ex :
raise CorruptMessageError ( ' %r got invalid message: %s ' , self , ex )
LOG . debug ( ' Calling %r ( %r , %r ) ' , fn , False , data )
fn ( False , data )
if not persist :
del self . _context . _handle_map [ handle ]
fn ( data )
def Transmit ( self ) :
"""
@ -424,13 +431,12 @@ class Stream(BasicStream):
self . write_side . fd = None
for handle , ( persist , fn ) in self . _context . _handle_map . iteritems ( ) :
LOG . debug ( ' %r .Disconnect(): killing %r : %r ' , self , handle , fn )
fn ( True , None )
fn ( _DEAD )
def Accept ( self , rfd , wfd ) :
self . read_side = Side ( self , os . dup ( rfd ) )
self . write_side = Side ( self , os . dup ( wfd ) )
self . _context . SetStream ( self )
self . _context . broker . Register ( self . _context )
def Connect ( self ) :
"""
@ -444,22 +450,19 @@ class Stream(BasicStream):
self . Enqueue ( 0 , self . _context . name )
def __repr__ ( self ) :
return ' econtext. %s (<context= %r >) ' % \
( self . __class__ . __name__ , self . _context )
return ' %s (<context= %r >) ' % ( self . __class__ . __name__ , self . _context )
class LocalStream ( Stream ) :
"""
Base for streams capable of starting new slaves .
"""
python_path = property (
lambda self : getattr ( self , ' _python_path ' , sys . executable ) ,
lambda self , path : setattr ( self , ' _python_path ' , path ) ,
doc = ' The path to the remote Python interpreter. ' )
#: The path to the remote Python interpreter.
python_path = sys . executable
def __init__ ( self , context ) :
super ( LocalStream , self ) . __init__ ( context )
self . _permitted_ modules = set ( [ ' exceptions ' ] )
self . _permitted_ classes = set ( [ ( ' econtext.core ' , ' CallError ' ) ] )
self . _unpickler . find_global = self . _FindGlobal
def _FindGlobal ( self , module_name , class_name ) :
@ -467,16 +470,16 @@ class LocalStream(Stream):
Return the class implementing ` module_name . class_name ` or raise
` StreamError ` if the module is not whitelisted .
"""
if module_name not in self . _permitted_ modul es:
if ( module_name , class_name ) not in self . _permitted_ class es:
raise StreamError ( ' context %r attempted to unpickle %r in module %r ' ,
self . _context , class_name , module_name )
return getattr ( sys . modules [ module_name ] , class_name )
def Allow Module( self , module _name) :
def Allow Class( self , module_name , class _name) :
"""
Add ` module_name ` to the list of permitted modules .
"""
self . _permitted_modules . add ( module_name )
self . _permitted_modules . add ( ( module_name , class_name ) )
# Hexed and passed to 'python -c'. It forks, dups 0->100, creates a pipe,
# then execs a new interpreter with a custom argv. CONTEXT_NAME is replaced
@ -517,10 +520,10 @@ class LocalStream(Stream):
self , self . read_side . fd )
source = inspect . getsource ( sys . modules [ __name__ ] )
source + = ' \n ExternalContext M ain(%r , %r , %r ) \n ' % (
source + = ' \n ExternalContext ().m ain(%r , %r , %r ) \n ' % (
self . _context . name ,
self . _context . broker. _listener . _listen_addr ,
self . _context . key
self . _context . key ,
self . _context . broker. log_level ,
)
compressed = zlib . compress ( source )
@ -530,10 +533,8 @@ class LocalStream(Stream):
class SSHStream ( LocalStream ) :
ssh_path = property (
lambda self : getattr ( self , ' _ssh_path ' , ' ssh ' ) ,
lambda self , path : setattr ( self , ' _ssh_path ' , path ) ,
doc = ' The path to the SSH binary. ' )
#: The path to the SSH binary.
ssh_path = ' ssh '
def GetBootCommand ( self ) :
bits = [ self . ssh_path ]
@ -563,10 +564,7 @@ class Context(object):
self . _lock = threading . Lock ( )
self . responder = MasterModuleResponder ( self )
self . AddHandleCB ( self . responder . GetModule , handle = GET_MODULE )
self . log_forwarder = LogForwarder ( self )
self . AddHandleCB ( self . log_forwarder . ForwardLog , handle = FORWARD_LOG )
def GetStream ( self ) :
return self . _stream
@ -577,10 +575,7 @@ class Context(object):
def AllocHandle ( self ) :
"""
Allocate a unique handle for this stream .
Returns :
long
Allocate a handle .
"""
self . _lock . acquire ( )
try :
@ -591,8 +586,8 @@ class Context(object):
def AddHandleCB ( self , fn , handle , persist = True ) :
"""
Register ` fn ( killed, obj) ` to run for each ` obj ` sent to ` handle ` . If
` persist ` is ` ` False ` ` then unregister after one delivery .
Register ` fn ( obj) ` to run for each ` obj ` sent to ` handle ` . If ` persist `
is ` ` False ` ` then unregister after one delivery .
"""
LOG . debug ( ' %r .AddHandleCB( %r , %r , persist= %r ) ' ,
self , fn , handle , persist )
@ -613,47 +608,43 @@ class Context(object):
queue = Queue . Queue ( )
def _Receive ( killed, data) :
LOG . debug ( ' %r ._Receive( %r , %r )' , self , killed , data )
queue . put ( ( killed , data ) )
def _Receive ( data) :
LOG . debug ( ' %r ._Receive( %r )' , self , data )
queue . put ( data )
self . AddHandleCB ( _Receive , reply_to , persist = False )
self . _stream . Enqueue ( handle , ( False , ( reply_to , ) + data ) )
self . _stream . Enqueue ( handle , ( reply_to , ) + data )
try :
killed, data = queue . get ( True , deadline )
data = queue . get ( True , deadline )
except Queue . Empty :
self . _stream . Disconnect ( )
raise TimeoutError ( ' deadline exceeded. ' )
if kille d:
if data == _DEAD :
raise StreamError ( ' lost connection during call. ' )
LOG . debug ( ' %r ._EnqueueAwaitReply(): got reply: %r ' , self , data )
return data
def CallWithDeadline ( self , fn, deadline , * args , * * kwargs ) :
LOG . debug ( ' %r .CallWithDeadline( %r , %r , *%r , ** %r ) ' ,
self , fn, deadline , args , kwargs )
def CallWithDeadline ( self , deadline, with_context , fn , * args , * * kwargs ) :
LOG . debug ( ' %r .CallWithDeadline( %r , %r , %r , *%r , ** %r ) ' ,
self , deadline, with_context , fn , args , kwargs )
if isinstance ( fn , types . MethodType ) and \
isinstance ( fn . im_self , ( type , types . ClassType ) ) :
fn_c lass = fn . im_self . __name__
k lass = fn . im_self . __name__
else :
fn_class = None
call = ( fn . __module__ , fn_class , fn . __name__ , args , kwargs )
success , result = self . EnqueueAwaitReply ( CALL_FUNCTION , deadline , call )
klass = None
if success :
return result
else :
exc_obj , traceback = result
exc_obj . real_traceback = traceback
raise exc_obj
call = ( with_context , fn . __module__ , klass , fn . __name__ , args , kwargs )
result = self . EnqueueAwaitReply ( CALL_FUNCTION , deadline , call )
if isinstance ( result , CallError ) :
raise result
return result
def Call ( self , fn , * args , * * kwargs ) :
return self . CallWithDeadline ( fn , None , * args , * * kwargs )
return self . CallWithDeadline ( None , False , fn , * args , * * kwargs )
def __repr__ ( self ) :
bits = map ( repr , filter ( None , [ self . name , self . hostname , self . username ] ) )
@ -668,6 +659,9 @@ class Waker(BasicStream):
self . write_side = Side ( self , wfd )
broker . AddStream ( self )
def __repr__ ( self ) :
return ' <Waker> '
def Wake ( self ) :
os . write ( self . write_side . fd , ' ' )
@ -703,6 +697,9 @@ class IoLogger(BasicStream):
self . write_side = Side ( self , wfd )
self . _broker . AddStream ( self )
def __repr__ ( self ) :
return ' <IoLogger %s fd %d > ' % ( self . _name , self . read_side . fd )
def _LogLines ( self ) :
while self . _buf . find ( ' \n ' ) != - 1 :
line , _ , self . _buf = self . _buf . partition ( ' \n ' )
@ -722,18 +719,20 @@ class Broker(object):
Context broker : this is responsible for keeping track of contexts , any
stream that is associated with them , and for I / O multiplexing .
"""
_waker = None
def __init__ ( self ) :
self . _dead = False
def __init__ ( self , log_level = logging . DEBUG ) :
self . log_level = log_level
self . _alive = True
self . _lock = threading . Lock ( )
self . _stopped = threading . Event ( )
self . _contexts = { }
self . _readers = set ( )
self . _writers = set ( )
self . _waker = None
self . _waker = Waker ( self )
self . _thread = threading . Thread ( target = self . _Loop , name = ' Broker ' )
self . _thread = threading . Thread ( target = self . _BrokerMain ,
name = ' econtext-broker ' )
self . _thread . start ( )
def CreateListener ( self , address = None , backlog = 30 ) :
@ -774,7 +773,7 @@ class Broker(object):
self . _contexts [ context . name ] = context
return context
def GetLocal ( self , name = ' default ' ) :
def GetLocal ( self , name = ' econtext-local ' ) :
"""
Get the named context running on the local machine , creating it if it
does not exist .
@ -799,6 +798,15 @@ class Broker(object):
stream . Connect ( )
return self . Register ( context )
def _CallAndUpdate ( self , stream , func ) :
try :
func ( )
except Exception , e :
LOG . exception ( ' %r crashed ' , stream )
stream . Disconnect ( )
self . _UpdateStream ( stream )
def _LoopOnce ( self ) :
LOG . debug ( ' %r .Loop() ' , self )
#LOG.debug('readers = %r', self._readers)
@ -808,28 +816,24 @@ class Broker(object):
rsides , wsides , _ = select . select ( self . _readers , self . _writers , ( ) )
for side in rsides :
LOG . debug ( ' %r : POLLIN for %r ' , self , side . stream )
side . stream . Receive ( )
self . _UpdateStream ( side . stream )
self . _CallAndUpdate ( side . stream , side . stream . Receive )
for side in wsides :
LOG . debug ( ' %r : POLLOUT for %r ' , self , side . stream )
side . stream . Transmit ( )
self . _UpdateStream ( side . stream )
self . _CallAndUpdate ( side . stream , side . stream . Transmit )
def _ Loop ( self ) :
def _ BrokerMain ( self ) :
"""
Handle stream events until Finalize ( ) is called .
"""
try :
while not self . _dead :
while self . _alive :
self . _LoopOnce ( )
for context in self . _contexts . itervalues ( ) :
stream = context . GetStream ( )
if stream :
stream . Disconnect ( )
self . _stopped . set ( )
except Exception :
LOG . exception ( ' Loop() crashed ' )
@ -837,65 +841,91 @@ class Broker(object):
"""
Wait for the broker to stop .
"""
self . _ stopped. wait ( )
self . _ thread. join ( )
def Finalize ( self ) :
"""
Tell all active streams to disconnect .
"""
self . _ dead = Tru e
self . _ alive = Fals e
self . _waker . Wake ( )
self . Wait ( )
def __repr__ ( self ) :
return ' econtext.Broker(<contexts= %s >) ' % ( self . _contexts . keys ( ) , )
def ExternalContextMain ( context_name , parent_addr , key ) :
syslog . openlog ( ' %s : %s ' % ( getpass . getuser ( ) , context_name ) , syslog . LOG_PID )
syslog . syslog ( ' initializing (parent= %s ) ' % ( os . getenv ( ' SSH_CLIENT ' ) , ) )
return ' Broker() '
class ExternalContext ( object ) :
def _FixupMainModule ( self ) :
global core
sys . modules [ ' econtext ' ] = sys . modules [ ' __main__ ' ]
sys . modules [ ' econtext.core ' ] = sys . modules [ ' __main__ ' ]
core = sys . modules [ ' __main__ ' ]
for klass in globals ( ) . itervalues ( ) :
if hasattr ( klass , ' __module__ ' ) :
klass . __module__ = ' econtext.core '
def _SetupLogging ( self , log_level ) :
logging . basicConfig ( level = log_level )
logging . getLogger ( ' ' ) . handlers [ 0 ] . formatter = Formatter ( False )
def _ReapFirstStage ( self ) :
os . wait ( )
os . dup2 ( 100 , 0 )
os . close ( 100 )
def _SetupMaster ( self , key ) :
self . broker = Broker ( )
self . context = Context ( self . broker , ' parent ' , key = key )
self . channel = Channel ( self . context , CALL_FUNCTION )
self . stream = Stream ( self . context )
self . stream . Accept ( 0 , 1 )
def _SetupImporter ( self ) :
self . importer = SlaveModuleImporter ( self . context )
sys . meta_path . append ( self . importer )
def _SetupStdio ( self ) :
self . stdout_log = IoLogger ( self . broker , ' stdout ' )
self . stderr_log = IoLogger ( self . broker , ' stderr ' )
os . dup2 ( self . stdout_log . write_side . fd , 1 )
os . dup2 ( self . stderr_log . write_side . fd , 2 )
os . close ( 0 )
def _DispatchCalls ( self ) :
for data in self . channel :
LOG . debug ( ' _DispatchCalls( %r ) ' , data )
reply_to , with_context , modname , klass , func , args , kwargs = data
if with_context :
args = ( self , ) + args
logging . basicConfig ( level = logging . INFO )
logging . getLogger ( ' ' ) . handlers [ 0 ] . formatter = Formatter ( False )
LOG . debug ( ' ExternalContextMain( %r , %r , %r ) ' , context_name , parent_addr , key )
os . wait ( ) # Reap the first stage.
os . dup2 ( 100 , 0 )
os . close ( 100 )
broker = Broker ( )
context = Context ( broker , ' parent ' , parent_addr = parent_addr , key = key )
stream = Stream ( context )
channel = Channel ( stream , CALL_FUNCTION )
#stdout_log = IoLogger(broker, 'stdout')
#stderr_log = IoLogger(broker, 'stderr')
stream . Accept ( 0 , 1 )
os . close ( 0 )
os . dup2 ( 2 , 1 )
#os.dup2(stdout_log.write_side.fd, 1)
#os.dup2(stderr_log.write_side.fd, 2)
# stream = context.SetStream(Stream(context))
# stream.
# stream.Connect()
broker . Register ( context )
importer = SlaveModuleImporter ( context )
sys . meta_path . append ( importer )
LOG . debug ( ' start recv ' )
for call_info in channel :
LOG . debug ( ' ExternalContextMain(): CALL_FUNCTION %r ' , call_info )
reply_to , mod_name , class_name , func_name , args , kwargs = call_info
try :
fn = getattr ( __import__ ( mod_name ) , func_name )
stream . Enqueue ( reply_to , ( True , fn ( * args , * * kwargs ) ) )
except Exception , e :
stream . Enqueue ( reply_to , ( False , ( e , traceback . extract_stack ( ) ) ) )
broker . Finalize ( )
LOG . debug ( ' ExternalContextMain exitting ' )
try :
obj = __import__ ( modname )
if klass :
obj = getattr ( obj , klass )
fn = getattr ( obj , func )
self . context . Enqueue ( reply_to , fn ( * args , * * kwargs ) )
except Exception , e :
self . context . Enqueue ( reply_to , CallError ( e ) )
def main ( self , context_name , key , log_level ) :
self . _FixupMainModule ( )
self . _SetupLogging ( log_level )
syslog . openlog ( ' %s : %s ' % ( getpass . getuser ( ) , context_name ) , syslog . LOG_PID )
syslog . syslog ( ' initializing (parent= %s ) ' % ( os . getenv ( ' SSH_CLIENT ' ) , ) )
LOG . debug ( ' ExternalContext.main( %r , %r ) ' , context_name , key )
self . _ReapFirstStage ( )
self . _SetupMaster ( key )
self . _SetupImporter ( )
#self._SetupStdio()
fd = open ( ' /dev/null ' , ' w ' )
os . dup2 ( fd . fileno ( ) , 1 )
os . dup2 ( fd . fileno ( ) , 2 )
self . broker . Register ( self . context )
self . _DispatchCalls ( )
self . broker . Wait ( )
LOG . debug ( ' ExternalContext.main() exitting ' )