* Tidy up docstrings.
* Start work on IoLogger and LogForwarder.
* Fix race in Stream.Accept / AddHandleCb CALL_FUNCTION
* Fix stdout slave stream corruption.
* Fix HOSTNAME vs. socket.gethostname().
* Remove silly locking.
* Move handle management to Context for later reconnects.
pull/35/head
David Wilson 8 years ago
parent 07ba2de7b0
commit 89e282734c

@@ -1,8 +1,8 @@
#!/usr/bin/env python2.5
'''
Python External Execution Contexts.
'''
"""
Python external execution contexts.
"""
import Queue
import cPickle
@@ -37,6 +37,7 @@ LOG = logging.getLogger('econtext')
GET_MODULE = 100L
CALL_FUNCTION = 101L
FORWARD_LOG = 102L
#
@@ -76,15 +77,10 @@ def write_all(fd, s):
def CreateChild(*args):
'''
Create a child process whose stdin/stdout is connected to a socket.
Args:
*args: executable name and process arguments.
Returns:
pid, sock
'''
"""
Create a child process whose stdin/stdout is connected to a socket,
returning `(pid, socket_obj)`.
"""
parentfp, childfp = socket.socketpair()
pid = os.fork()
if not pid:
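# [Editor's sketch, not part of this commit] The body of CreateChild()
# is truncated by the hunk above; its shape is: socketpair(), fork(),
# child dup2()s its socket half over stdin/stdout and execs. A
# standalone approximation (POSIX only; names below are hypothetical):
import os, socket

def create_child(*args):
    parentfp, childfp = socket.socketpair()
    pid = os.fork()
    if not pid:
        os.dup2(childfp.fileno(), 0)
        os.dup2(childfp.fileno(), 1)
        childfp.close()
        parentfp.close()
        os.execvp(args[0], args)
    childfp.close()
    return pid, parentfp

pid, sock = create_child('echo', 'hi')
print(sock.recv(64))        # b'hi\n'
os.waitpid(pid, 0)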
@@ -122,13 +118,14 @@ class Formatter(logging.Formatter):
class Channel(object):
def __init__(self, stream, handle):
self._context = stream._context
self._stream = stream
self._handle = handle
self._queue = Queue.Queue()
self._stream.AddHandleCB(self._InternalReceive, handle)
self._context.AddHandleCB(self._InternalReceive, handle)
def _InternalReceive(self, killed, data):
'''
"""
Callback from the stream object; appends a tuple of
(killed-or-closed, data) to the internal queue and wakes the internal
event.
@@ -142,35 +139,29 @@ class Channel(object):
# The object passed to the remote Send()
object
)
'''
"""
LOG.debug('%r._InternalReceive(%r, %r)', self, killed, data)
self._queue.put((killed or data[0], killed or data[1]))
def Close(self):
'''
"""
Indicate this channel is closed to the remote side.
'''
"""
LOG.debug('%r.Close()', self)
self._stream.Enqueue(self._handle, (True, None))
def Send(self, data):
'''
Send the given object to the remote side.
'''
"""
Send `data` to the remote.
"""
LOG.debug('%r.Send(%r)', self, data)
self._stream.Enqueue(self._handle, (False, data))
def Receive(self, timeout=None):
'''
Receive the next object to arrive on this channel, or return if the
optional timeout is reached.
Args:
timeout: float
Returns:
object
'''
"""
Receive an object from the remote, or return ``None`` if `timeout` is
reached.
"""
LOG.debug('%r.Receive(%r)', self, timeout)
try:
killed, data = self._queue.get(True, timeout)
@@ -183,10 +174,10 @@ class Channel(object):
return data
def __iter__(self):
'''
"""
Return an iterator that yields objects arriving on this channel, until
the channel dies or is closed.
'''
"""
while True:
try:
yield self.Receive()
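# [Editor's sketch, not part of this commit] Channel is a queue fed by
# a handle callback: each delivery is a (closed, data) pair, and
# iteration stops at the first closed marker. The same pattern,
# standalone (Python 3 spelling of Queue; MiniChannel is hypothetical):
import queue

class MiniChannel(object):
    def __init__(self):
        self._queue = queue.Queue()

    def deliver(self, killed, data):        # cf. _InternalReceive
        self._queue.put((killed or data[0], None if killed else data[1]))

    def receive(self, timeout=None):        # cf. Receive
        closed, data = self._queue.get(True, timeout)
        if closed:
            raise EOFError('channel closed')
        return data

    def __iter__(self):
        while True:
            try:
                yield self.receive()
            except EOFError:
                return

ch = MiniChannel()
for payload in ('hello', 'world'):
    ch.deliver(False, (False, payload))     # cf. remote Send()
ch.deliver(False, (True, None))             # cf. remote Close()
assert list(ch) == ['hello', 'world']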
@@ -198,18 +189,13 @@ class Channel(object):
class SlaveModuleImporter(object):
'''
"""
Import protocol implementation that fetches modules from the parent
process.
'''
:param context: Context to communicate via.
"""
def __init__(self, context):
'''
Initialise a new instance.
Args:
context: Context instance this importer will communicate via.
'''
self._context = context
def find_module(self, fullname, path=None):
@@ -235,20 +221,32 @@ class SlaveModuleImporter(object):
class MasterModuleResponder(object):
def __init__(self, stream):
self._stream = stream
def __init__(self, context):
self._context = context
def GetModule(self, killed, data):
if killed:
return
_, (reply_handle, fullname) = data
_, (reply_to, fullname) = data
LOG.debug('MasterModuleResponder.GetModule(%r, %r)', killed, fullname)
mod = sys.modules.get(fullname)
if mod:
source = zlib.compress(inspect.getsource(mod))
path = os.path.abspath(mod.__file__)
self._stream.Enqueue(reply_handle, ('source', path, source))
self._context.Enqueue(reply_to, ('source', path, source))
class LogForwarder(object):
def __init__(self, context):
self._context = context
def ForwardLog(self, killed, data):
if killed:
return
_, (s,) = data
LOG.debug('%r: %s', self._context, s)
#
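# [Editor's sketch, not part of this commit] The slave-side sender for
# FORWARD_LOG is not in these hunks ("start work on ... LogForwarder").
# One plausible counterpart is a logging.Handler that ships each
# formatted record to the parent, matching the `_, (s,) = data` unpack
# in ForwardLog() above; everything below is hypothetical:
import logging

FORWARD_LOG = 102

class ForwardingHandler(logging.Handler):
    def __init__(self, context):
        logging.Handler.__init__(self)
        self._context = context

    def emit(self, record):
        # Payload is (closed_flag, (line,)), per the ForwardLog unpack.
        self._context.Enqueue(FORWARD_LOG, (False, (self.format(record),)))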
@@ -284,25 +282,21 @@ class BasicStream(object):
class Stream(BasicStream):
"""
Initialize a new Stream instance.
:param context: Context to communicate with.
"""
_input_buf = ''
_output_buf = ''
def __init__(self, context):
'''
Initialize a new Stream instance.
Args:
context: econtext.Context
'''
self._context = context
self._lock = threading.Lock()
self._rhmac = hmac.new(context.key, digestmod=sha.new)
self._whmac = self._rhmac.copy()
self._last_handle = 1000L
self._handle_map = {}
self._pickler_file = cStringIO.StringIO()
self._pickler = cPickle.Pickler(self._pickler_file, protocol=2)
@@ -310,15 +304,9 @@ class Stream(BasicStream):
self._unpickler = cPickle.Unpickler(self._unpickler_file)
def Pickle(self, obj):
'''
Serialize the given object using the pickler.
Args:
obj: object
Returns:
str
'''
"""
Serialize `obj` using the pickler.
"""
self._pickler.dump(obj)
data = self._pickler_file.getvalue()
self._pickler_file.seek(0)
@@ -326,15 +314,9 @@ class Stream(BasicStream):
return data
def Unpickle(self, data):
'''
Unserialize the given string using the unpickler.
Args:
data: str
Returns:
object
'''
"""
Unserialize `data` into an object using the unpickler.
"""
LOG.debug('%r.Unpickle(%r)', self, data)
self._unpickler_file.write(data)
self._unpickler_file.seek(0)
@@ -343,38 +325,11 @@ class Stream(BasicStream):
self._unpickler_file.truncate(0)
return data
def AllocHandle(self):
'''
Allocate a unique handle for this stream.
Returns:
long
'''
self._lock.acquire()
try:
self._last_handle += 1L
return self._last_handle
finally:
self._lock.release()
def AddHandleCB(self, fn, handle, persist=True):
'''
Invoke a function for all messages with the given handle.
Args:
fn: callable
handle: long
persist: False to only receive a single message.
'''
LOG.debug('%r.AddHandleCB(%r, %r, persist=%r)',
self, fn, handle, persist)
self._handle_map[handle] = persist, fn
def Receive(self):
'''
"""
Handle the next complete message on the stream. Raise
CorruptMessageError or IOError on failure.
'''
"""
LOG.debug('%r.Receive()', self)
buf = os.read(self.read_side.fd, 4096)
@@ -405,9 +360,9 @@ class Stream(BasicStream):
LOG.debug('%r.Receive(): decoded handle=%r; data=%r',
self, handle, data)
persist, fn = self._handle_map[handle]
persist, fn = self._context._handle_map[handle]
if not persist:
del self._handle_map[handle]
del self._context._handle_map[handle]
except KeyError, ex:
raise CorruptMessageError('%r got invalid handle: %r', self, handle)
except (TypeError, ValueError), ex:
@@ -417,15 +372,9 @@ class Stream(BasicStream):
fn(False, data)
def Transmit(self):
'''
"""
Transmit buffered messages.
Returns:
bool: more data left in buffer?
Raises:
IOError
'''
"""
LOG.debug('%r.Transmit()', self)
written = os.write(self.write_side.fd, self._output_buf[:4096])
self._output_buf = self._output_buf[written:]
@@ -434,31 +383,26 @@ class Stream(BasicStream):
return bool(self._output_buf)
def Enqueue(self, handle, obj):
'''
Serialize an object, send it to the given handle, and tell our context's
broker we have output.
Args:
handle: long
obj: object
'''
"""
Enqueue `obj` to `handle`, and tell the broker we have output.
"""
LOG.debug('%r.Enqueue(%r, %r)', self, handle, obj)
encoded = self.Pickle((handle, obj))
msg = struct.pack('>L', len(encoded)) + encoded
self._lock.acquire()
try:
encoded = self.Pickle((handle, obj))
msg = struct.pack('>L', len(encoded)) + encoded
self._whmac.update(msg)
self._output_buf += self._whmac.digest() + msg
finally:
self._lock.release()
self._context.broker.UpdateStream(self, wake=True)
self._context.broker.UpdateStream(self)
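# [Editor's sketch, not part of this commit] Enqueue() frames each
# message as: rolling SHA-1 HMAC digest (20 bytes) + big-endian length
# + pickled (handle, obj). One encode/verify round, standalone (modern
# spellings of sha/cPickle; Receive()'s digest check is only partly
# visible in the hunks above, so the verify side mirrors the sender):
import hashlib, hmac, pickle, struct

key = b'shared-secret'                      # cf. context.key
whmac = hmac.new(key, digestmod=hashlib.sha1)
rhmac = whmac.copy()                        # receiver starts in step

encoded = pickle.dumps((101, (False, 'hi')), protocol=2)
msg = struct.pack('>L', len(encoded)) + encoded
whmac.update(msg)
frame = whmac.digest() + msg                # as built in Enqueue()

digest, rest = frame[:20], frame[20:]
rhmac.update(rest)                          # both HMACs see every msg
assert hmac.compare_digest(digest, rhmac.digest())
(length,) = struct.unpack('>L', rest[:4])
assert pickle.loads(rest[4:4 + length]) == (101, (False, 'hi'))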
def Disconnect(self):
'''
"""
Close our associated file descriptor and tell registered callbacks the
connection has been destroyed.
'''
"""
LOG.debug('%r.Disconnect()', self)
if self._context.GetStream() is self:
self._context.SetStream(None)
@@ -478,28 +422,20 @@ class Stream(BasicStream):
self.read_side.fd = None
self.write_side.fd = None
for handle, (persist, fn) in self._handle_map.iteritems():
LOG.debug('%r.Disconnect(): stale callback handle=%r; fn=%r',
self, handle, fn)
for handle, (persist, fn) in self._context._handle_map.iteritems():
LOG.debug('%r.Disconnect(): killing %r: %r', self, handle, fn)
fn(True, None)
@classmethod
def Accept(cls, context, rfd, wfd):
'''
'''
stream = cls(context)
stream.read_side = Side(stream, os.dup(rfd))
stream.write_side = Side(stream, os.dup(wfd))
context.SetStream(stream)
context.broker.Register(context)
return stream
def Accept(self, rfd, wfd):
self.read_side = Side(self, os.dup(rfd))
self.write_side = Side(self, os.dup(wfd))
self._context.SetStream(self)
self._context.broker.Register(self._context)
def Connect(self):
'''
"""
Connect to a Broker at the address specified in our associated Context.
'''
"""
LOG.debug('%r.Connect()', self)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.read_side = Side(self, sock.fileno())
@@ -513,10 +449,9 @@ class Stream(BasicStream):
class LocalStream(Stream):
'''
"""
Base for streams capable of starting new slaves.
'''
"""
python_path = property(
lambda self: getattr(self, '_python_path', sys.executable),
lambda self, path: setattr(self, '_python_path', path),
@@ -527,33 +462,20 @@ class LocalStream(Stream):
self._permitted_modules = set(['exceptions'])
self._unpickler.find_global = self._FindGlobal
self.responder = MasterModuleResponder(self)
self.AddHandleCB(self.responder.GetModule, handle=GET_MODULE)
def _FindGlobal(self, module_name, class_name):
'''
See the cPickle documentation: given a module and class name, determine
whether class referred to is safe for unpickling.
Args:
module_name: str
class_name: str
Returns:
classobj or type
'''
"""
Return the class implementing `module_name.class_name` or raise
`StreamError` if the module is not whitelisted.
"""
if module_name not in self._permitted_modules:
raise StreamError('context %r attempted to unpickle %r in module %r',
self._context, class_name, module_name)
return getattr(sys.modules[module_name], class_name)
def AllowModule(self, module_name):
'''
Add the given module to the list of permitted modules.
Args:
module_name: str
'''
"""
Add `module_name` to the list of permitted modules.
"""
self._permitted_modules.add(module_name)
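# [Editor's sketch, not part of this commit] _FindGlobal is installed
# as the unpickler's find_global hook (a cPickle feature; Python 3
# spells it Unpickler.find_class), so every GLOBAL opcode passes
# through the whitelist. A standalone demonstration:
import collections, io, pickle, sys

class WhitelistUnpickler(pickle.Unpickler):
    permitted = {'builtins'}                # cf. _permitted_modules

    def find_class(self, module_name, class_name):
        if module_name not in self.permitted:
            raise pickle.UnpicklingError(
                'attempted to unpickle %s in module %s'
                % (class_name, module_name))
        __import__(module_name)
        return getattr(sys.modules[module_name], class_name)

blob = pickle.dumps(collections.OrderedDict(), protocol=2)
try:
    WhitelistUnpickler(io.BytesIO(blob)).load()
except pickle.UnpicklingError as exc:
    print('blocked:', exc)                  # collections not permitted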
# Hexed and passed to 'python -c'. It forks, dups 0->100, creates a pipe,
@@ -618,13 +540,14 @@ class SSHStream(LocalStream):
if self._context.username:
bits += ['-l', self._context.username]
bits.append(self._context.hostname)
return bits + map(commands.mkarg, super(SSHStream, self).GetBootCommand())
base = super(SSHStream, self).GetBootCommand()
return bits + map(commands.mkarg, base)
class Context(object):
'''
"""
Represents a remote context regardless of connection method.
'''
"""
def __init__(self, broker, name=None, hostname=None, username=None, key=None,
parent_addr=None):
@@ -635,6 +558,16 @@ class Context(object):
self.parent_addr = parent_addr
self.key = key or ('%016x' % random.getrandbits(128))
self._last_handle = 1000L
self._handle_map = {}
self._lock = threading.Lock()
self.responder = MasterModuleResponder(self)
self.AddHandleCB(self.responder.GetModule, handle=GET_MODULE)
self.log_forwarder = LogForwarder(self)
self.AddHandleCB(self.log_forwarder.ForwardLog, handle=FORWARD_LOG)
def GetStream(self):
return self._stream
@@ -642,15 +575,41 @@ class Context(object):
self._stream = stream
return stream
def AllocHandle(self):
"""
Allocate a unique handle for this context.
Returns:
long
"""
self._lock.acquire()
try:
self._last_handle += 1L
return self._last_handle
finally:
self._lock.release()
def AddHandleCB(self, fn, handle, persist=True):
"""
Register `fn(killed, obj)` to run for each `obj` sent to `handle`. If
`persist` is ``False`` then unregister after one delivery.
"""
LOG.debug('%r.AddHandleCB(%r, %r, persist=%r)',
self, fn, handle, persist)
self._handle_map[handle] = persist, fn
def Enqueue(self, handle, obj):
self._stream.Enqueue(handle, obj)
def EnqueueAwaitReply(self, handle, deadline, data):
'''
Send a message to the given handle and wait for a response with an
optional timeout. The message contains (reply_handle, data), where
reply_handle is the handle on which this function expects its reply.
'''
reply_handle = self._stream.AllocHandle()
"""
Send `data` to `handle` and wait for a response with an optional
timeout. The message contains `(reply_to, data)`, where `reply_to` is
the handle on which this function expects its reply.
"""
reply_to = self.AllocHandle()
LOG.debug('%r.EnqueueAwaitReply(%r, %r, %r) -> reply handle %d',
self, handle, deadline, data, reply_handle)
self, handle, deadline, data, reply_to)
queue = Queue.Queue()
@@ -658,8 +617,8 @@ class Context(object):
LOG.debug('%r._Receive(%r, %r)', self, killed, data)
queue.put((killed, data))
self._stream.AddHandleCB(_Receive, reply_handle, persist=False)
self._stream.Enqueue(handle, (False, (reply_handle,) + data))
self.AddHandleCB(_Receive, reply_to, persist=False)
self._stream.Enqueue(handle, (False, (reply_to,) + data))
try:
killed, data = queue.get(True, deadline)
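# [Editor's sketch, not part of this commit] The request/reply shape
# of EnqueueAwaitReply(), reduced to one process: allocate a fresh
# reply handle, register a one-shot callback feeding a local queue,
# send (reply_to,) + data, then block with the deadline. All names
# below are hypothetical:
import queue

handle_map = {}                             # cf. Context._handle_map
counter = [1000]

def alloc_handle():
    counter[0] += 1
    return counter[0]

def add_handle_cb(fn, handle, persist=True):
    handle_map[handle] = (persist, fn)

def dispatch(handle, data):                 # cf. Stream.Receive()
    persist, fn = handle_map[handle]
    if not persist:
        del handle_map[handle]
    fn(False, data)

def responder(killed, data):                # toy service on handle 100
    reply_to, payload = data
    dispatch(reply_to, payload.upper())
add_handle_cb(responder, 100)

def enqueue_await_reply(handle, deadline, data):
    reply_to = alloc_handle()
    q = queue.Queue()
    add_handle_cb(lambda k, d: q.put(d), reply_to, persist=False)
    dispatch(handle, (reply_to,) + data)
    return q.get(True, deadline)

print(enqueue_await_reply(100, 5.0, ('ping',)))   # -> PING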
@@ -730,18 +689,44 @@ class Listener(BasicStream):
def Receive(self):
sock, addr = self._sock.accept()
context = Context(self._broker, name=addr)
Stream.Accept(context, sock.fileno())
Stream(context).Accept(sock.fileno(), sock.fileno())
class IoLogger(BasicStream):
_buf = ''
def __init__(self, broker, name):
self._broker = broker
self._name = name
rfd, wfd = os.pipe()
self.read_side = Side(self, rfd)
self.write_side = Side(self, wfd)
self._broker.AddStream(self)
def _LogLines(self):
while self._buf.find('\n') != -1:
line, _, self._buf = self._buf.partition('\n')
LOG.debug('%s: %s', self._name, line.rstrip('\n'))
def Receive(self):
buf = os.read(self.read_side.fd, 4096)
if not buf:
return self.Disconnect()
self._buf += buf
self._LogLines()
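# [Editor's sketch, not part of this commit] IoLogger's intended use
# shows up, commented out, in ExternalContextMain() below: dup2() the
# pipe's write end over fd 1/2 so slave output is logged line by line.
# The buffering mechanics, broker-free:
import os

rfd, wfd = os.pipe()
os.write(wfd, b'first line\npartial')       # stands in for child stdout

buf = os.read(rfd, 4096)                    # cf. IoLogger.Receive()
while b'\n' in buf:
    line, _, buf = buf.partition(b'\n')     # cf. _LogLines()
    print('stdout:', line.decode())
# b'partial' stays buffered until its newline arrives.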
class Broker(object):
'''
"""
Context broker: this is responsible for keeping track of contexts, any
stream that is associated with them, and for I/O multiplexing.
'''
"""
def __init__(self):
self._dead = False
self._lock = threading.Lock()
self._stopped = threading.Event()
self._contexts = {}
self._readers = set()
self._writers = set()
@@ -752,16 +737,13 @@ class Broker(object):
self._thread.start()
def CreateListener(self, address=None, backlog=30):
'''
Create a socket to accept connections from newly spawned contexts.
Args:
address: The IPv4 address tuple to listen on.
backlog: Number of connections to accept while broker thread is busy.
'''
"""
Listen on `address` for connections from newly spawned contexts.
"""
self._listener = Listener(self, address, backlog)
def UpdateStream(self, stream, wake=False):
LOG.debug('UpdateStream(%r, wake=%s)', stream, wake)
def _UpdateStream(self, stream):
LOG.debug('_UpdateStream(%r)', stream)
if stream.ReadMore() and stream.read_side.fileno():
self._readers.add(stream.read_side)
else:
@@ -772,22 +754,19 @@ class Broker(object):
else:
self._writers.discard(stream.write_side)
if wake:
def UpdateStream(self, stream):
LOG.debug('UpdateStream(%r)', stream)
self._UpdateStream(stream)
if self._waker:
self._waker.Wake()
def AddStream(self, stream):
self._lock.acquire()
try:
if self._waker:
self._waker.Wake()
self.UpdateStream(stream)
finally:
self._lock.release()
self.UpdateStream(stream)
def Register(self, context):
'''
"""
Put a context under control of this broker.
'''
"""
LOG.debug('%r.Register(%r) -> r=%r w=%r', self, context,
context.GetStream().read_side,
context.GetStream().write_side)
@@ -795,26 +774,22 @@ class Broker(object):
self._contexts[context.name] = context
return context
def GetLocal(self, name):
'''
Return the named local context, or create it if it doesn't exist.
Args:
name: 'my-local-context'
Returns:
econtext.Context
'''
def GetLocal(self, name='default'):
"""
Get the named context running on the local machine, creating it if it
does not exist.
"""
context = Context(self, name)
context.SetStream(LocalStream(context)).Connect()
return self.Register(context)
def GetRemote(self, hostname, username, name=None, python_path=None):
'''
Return the named remote context, or create it if it doesn't exist.
'''
"""
Get the named remote context, creating it if it does not exist.
"""
if name is None:
name = 'econtext[%s@%s:%d]' %\
(username, os.getenv('HOSTNAME'), os.getpid())
(username, socket.gethostname(), os.getpid())
context = Context(self, name, hostname, username)
stream = SSHStream(context)
@@ -826,9 +801,6 @@ class Broker(object):
def _LoopOnce(self):
LOG.debug('%r.Loop()', self)
self._lock.acquire()
self._lock.release()
#LOG.debug('readers = %r', self._readers)
#LOG.debug('rfds = %r', [r.fileno() for r in self._readers])
#LOG.debug('writers = %r', self._writers)
@@ -837,17 +809,17 @@ class Broker(object):
for side in rsides:
LOG.debug('%r: POLLIN for %r', self, side.stream)
side.stream.Receive()
self.UpdateStream(side.stream)
self._UpdateStream(side.stream)
for side in wsides:
LOG.debug('%r: POLLOUT for %r', self, side.stream)
side.stream.Transmit()
self.UpdateStream(side.stream)
self._UpdateStream(side.stream)
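# [Editor's sketch, not part of this commit] _LoopOnce() is one
# select() pass over Side objects (the select call itself is elided by
# the hunk above, so select.select over the sets is an assumption);
# after each event the stream's registration is recomputed so
# _readers/_writers keep tracking ReadMore()/WriteMore():
import select

def loop_once(readers, writers, update_stream, timeout=None):
    # Side objects are assumed to expose fileno(), as select requires.
    rsides, wsides, _ = select.select(readers, writers, [], timeout)
    for side in rsides:
        side.stream.Receive()               # may enqueue output
        update_stream(side.stream)
    for side in wsides:
        side.stream.Transmit()              # flushes <= 4096 bytes
        update_stream(side.stream)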
def _Loop(self):
'''
"""
Handle stream events until Finalize() is called.
'''
"""
try:
while not self._dead:
self._LoopOnce()
@@ -856,17 +828,24 @@ class Broker(object):
stream = context.GetStream()
if stream:
stream.Disconnect()
self._stopped.set()
except Exception:
LOG.exception('Loop() crashed')
def Wait(self):
"""
Wait for the broker to stop.
"""
self._stopped.wait()
def Finalize(self):
'''
"""
Tell all active streams to disconnect.
'''
"""
self._dead = True
self._waker.Wake()
self._lock.acquire()
self._lock.release()
self.Wait()
def __repr__(self):
return 'econtext.Broker(<contexts=%s>)' % (self._contexts.keys(),)
@@ -887,9 +866,17 @@ def ExternalContextMain(context_name, parent_addr, key):
broker = Broker()
context = Context(broker, 'parent', parent_addr=parent_addr, key=key)
stream = Stream.Accept(context, 0, 1)
stream = Stream(context)
channel = Channel(stream, CALL_FUNCTION)
#stdout_log = IoLogger(broker, 'stdout')
#stderr_log = IoLogger(broker, 'stderr')
stream.Accept(0, 1)
os.close(0)
os.close(1)
os.dup2(2, 1)
#os.dup2(stdout_log.write_side.fd, 1)
#os.dup2(stderr_log.write_side.fd, 2)
# stream = context.SetStream(Stream(context))
# stream.
@@ -899,15 +886,16 @@ def ExternalContextMain(context_name, parent_addr, key):
importer = SlaveModuleImporter(context)
sys.meta_path.append(importer)
for call_info in Channel(stream, CALL_FUNCTION):
LOG.debug('start recv')
for call_info in channel:
LOG.debug('ExternalContextMain(): CALL_FUNCTION %r', call_info)
reply_handle, mod_name, class_name, func_name, args, kwargs = call_info
reply_to, mod_name, class_name, func_name, args, kwargs = call_info
try:
fn = getattr(__import__(mod_name), func_name)
stream.Enqueue(reply_handle, (True, fn(*args, **kwargs)))
stream.Enqueue(reply_to, (True, fn(*args, **kwargs)))
except Exception, e:
stream.Enqueue(reply_handle, (False, (e, traceback.extract_stack())))
stream.Enqueue(reply_to, (False, (e, traceback.extract_stack())))
broker.Finalize()
LOG.error('ExternalContextMain exiting')
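# [Editor's sketch, not part of this commit] Master-side usage implied
# by the pieces above; the CALL_FUNCTION payload mirrors the unpack in
# ExternalContextMain(), and this needs a live slave to actually run:
broker = Broker()
try:
    context = broker.GetLocal()             # or broker.GetRemote(...)
    # Slave unpacks (reply_to, mod_name, class_name, func_name, args,
    # kwargs); EnqueueAwaitReply() prepends reply_to itself.
    result = context.EnqueueAwaitReply(
        CALL_FUNCTION, None, ('posix', None, 'getcwd', (), {}))
finally:
    broker.Finalize()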
