CallError rather than trying to preserve exceptions.
Dead sentinel value instead of killed everywhere.
Simplify Channel.
MasterModuleResponder error response.
Simplify Unpickle().
Simplify reprs everywhere.
AllowClass() instead of AllowModule().
Get rid of needless property().
Split ExternalContextMain up into class.
econtext.utils module.
pull/35/head
David Wilson 8 years ago
parent cd9b93dd17
commit da77cb5870

@ -61,11 +61,28 @@ class CorruptMessageError(StreamError):
class TimeoutError(StreamError):
'Raised when a timeout occurs on a stream.'
class CallError(ContextError):
'Raised when .Call() fails'
def __init__(self, e):
name = '%s.%s' % (type(e).__module__, type(e).__name__)
stack = ''.join(traceback.format_stack(sys.exc_info[2]))
ContextError.__init__(self, 'Call failed: %s: %s\n%s', name, e, stack)
#
# Helpers.
#
class Dead(object):
def __eq__(self, other):
return type(other) is Dead
def __repr__(self):
return '<Dead>'
_DEAD = Dead()
def write_all(fd, s):
written = 0
while written < len(s):
@ -117,45 +134,32 @@ class Formatter(logging.Formatter):
class Channel(object):
def __init__(self, stream, handle):
self._context = stream._context
self._stream = stream
def __init__(self, context, handle):
self._context = context
self._handle = handle
self._queue = Queue.Queue()
self._context.AddHandleCB(self._InternalReceive, handle)
self._context.AddHandleCB(self._Receive, handle)
def _InternalReceive(self, killed, data):
def _Receive(self, data):
"""
Callback from the stream object; appends a tuple of
(killed-or-closed, data) to the internal queue and wakes the internal
event.
Args:
# Has the Stream object lost its connection?
killed: bool
data: (
# Has the remote Channel had Close() called?
bool,
# The object passed to the remote Send()
object
)
Callback from the Stream; appends data to the internal queue.
"""
LOG.debug('%r._InternalReceive(%r, %r)', self, killed, data)
self._queue.put((killed or data[0], killed or data[1]))
LOG.debug('%r._Receive(%r)', self, data)
self._queue.put(data)
def Close(self):
"""
Indicate this channel is closed to the remote side.
"""
LOG.debug('%r.Close()', self)
self._stream.Enqueue(handle, (True, None))
self._context.Enqueue(handle, _DEAD)
def Send(self, data):
"""
Send `data` to the remote.
"""
LOG.debug('%r.Send(%r)', self, data)
self._stream.Enqueue(handle, (False, data))
self._context.Enqueue(handle, data)
def Receive(self, timeout=None):
"""
@ -164,12 +168,12 @@ class Channel(object):
"""
LOG.debug('%r.Receive(%r)', self, timeout)
try:
killed, data = self._queue.get(True, timeout)
data = self._queue.get(True, timeout)
except Queue.Empty:
return
LOG.debug('%r.Receive() got killed=%r, data=%r', self, killed, data)
if killed:
LOG.debug('%r.Receive() got %r', self, data)
if data == _DEAD:
raise ChannelError('Channel is closed.')
return data
@ -185,7 +189,7 @@ class Channel(object):
return
def __repr__(self):
return 'econtext.Channel(%r, %r)' % (self._stream, self._handle)
return 'Channel(%r, %r)' % (self._context, self._handle)
class SlaveModuleImporter(object):
@ -212,7 +216,7 @@ class SlaveModuleImporter(object):
if ret is None:
raise ImportError('Master does not have %r' % (fullname,))
kind, path, data = ret
path, data = ret
code = compile(zlib.decompress(data), path, 'exec')
module = imp.new_module(fullname)
sys.modules[fullname] = module
@ -223,30 +227,33 @@ class SlaveModuleImporter(object):
class MasterModuleResponder(object):
def __init__(self, context):
self._context = context
self._context.AddHandleCB(self.GetModule, handle=GET_MODULE)
def GetModule(self, killed, data):
if killed:
def GetModule(self, data):
if data == _DEAD:
return
_, (reply_to, fullname) = data
LOG.debug('SlaveModuleImporter.GetModule(%r, %r)', killed, fullname)
mod = sys.modules.get(fullname)
if mod:
source = zlib.compress(inspect.getsource(mod))
path = os.path.abspath(mod.__file__)
self._context.Enqueue(reply_to, ('source', path, source))
reply_to, fullname = data
LOG.debug('SlaveModuleImporter.GetModule(%r, %r)', reply_to, fullname)
try:
module = __import__(fullname)
source = zlib.compress(inspect.getsource(module))
self._context.Enqueue(reply_to, (module.__file__, source))
except Exception, e:
LOG.exception('While importing %r', fullname)
self._context.Enqueue(reply_to, None)
class LogForwarder(object):
def __init__(self, context):
self._context = context
self._context.AddHandleCB(self.ForwardLog, handle=FORWARD_LOG)
def ForwardLog(self, killed, data):
if killed:
def ForwardLog(self, data):
if data == _DEAD:
return
_, (s,) = data
LOG.debug('%r: %s', self._context, s)
LOG.debug('%r: %s', self._context, data)
#
@ -305,7 +312,7 @@ class Stream(BasicStream):
def Pickle(self, obj):
"""
Serialize `obj` using the pickler.
Serialize `obj` into a bytestring.
"""
self._pickler.dump(obj)
data = self._pickler_file.getvalue()
@ -315,15 +322,14 @@ class Stream(BasicStream):
def Unpickle(self, data):
"""
Unserialize `data` into an object using the unpickler.
Deserialize `data` into an object.
"""
LOG.debug('%r.Unpickle(%r)', self, data)
self._unpickler_file.write(data)
self._unpickler_file.truncate(0)
self._unpickler_file.seek(0)
data = self._unpickler.load()
self._unpickler_file.write(data)
self._unpickler_file.seek(0)
self._unpickler_file.truncate(0)
return data
return self._unpickler.load()
def Receive(self):
"""
@ -349,27 +355,28 @@ class Stream(BasicStream):
self._rhmac.update(self._input_buf[20:msg_len+24])
expected_mac = self._rhmac.digest()
if msg_mac != expected_mac:
raise CorruptMessageError('%r got invalid MAC: expected %r, got %r',
raise CorruptMessageError('%r invalid MAC: expected %r, got %r',
self, msg_mac.encode('hex'),
expected_mac.encode('hex'))
try:
handle, data = self.Unpickle(self._input_buf[24:msg_len+24])
self._input_buf = self._input_buf[msg_len+24:]
handle = long(handle)
except (TypeError, ValueError), ex:
raise CorruptMessageError('%r got invalid message: %s', self, ex)
self._input_buf = self._input_buf[msg_len+24:]
self._Invoke(handle, data)
LOG.debug('%r.Receive(): decoded handle=%r; data=%r',
self, handle, data)
def _Invoke(self, handle, data):
LOG.debug('%r._Invoke(): handle=%r; data=%r', self, handle, data)
try:
persist, fn = self._context._handle_map[handle]
if not persist:
del self._context._handle_map[handle]
except KeyError, ex:
raise CorruptMessageError('%r got invalid handle: %r', self, handle)
except (TypeError, ValueError), ex:
raise CorruptMessageError('%r got invalid message: %s', self, ex)
LOG.debug('Calling %r (%r, %r)', fn, False, data)
fn(False, data)
if not persist:
del self._context._handle_map[handle]
fn(data)
def Transmit(self):
"""
@ -424,13 +431,12 @@ class Stream(BasicStream):
self.write_side.fd = None
for handle, (persist, fn) in self._context._handle_map.iteritems():
LOG.debug('%r.Disconnect(): killing %r: %r', self, handle, fn)
fn(True, None)
fn(_DEAD)
def Accept(self, rfd, wfd):
self.read_side = Side(self, os.dup(rfd))
self.write_side = Side(self, os.dup(wfd))
self._context.SetStream(self)
self._context.broker.Register(self._context)
def Connect(self):
"""
@ -444,22 +450,19 @@ class Stream(BasicStream):
self.Enqueue(0, self._context.name)
def __repr__(self):
return 'econtext.%s(<context=%r>)' %\
(self.__class__.__name__, self._context)
return '%s(<context=%r>)' % (self.__class__.__name__, self._context)
class LocalStream(Stream):
"""
Base for streams capable of starting new slaves.
"""
python_path = property(
lambda self: getattr(self, '_python_path', sys.executable),
lambda self, path: setattr(self, '_python_path', path),
doc='The path to the remote Python interpreter.')
#: The path to the remote Python interpreter.
python_path = sys.executable
def __init__(self, context):
super(LocalStream, self).__init__(context)
self._permitted_modules = set(['exceptions'])
self._permitted_classes = set([('econtext.core', 'CallError')])
self._unpickler.find_global = self._FindGlobal
def _FindGlobal(self, module_name, class_name):
@ -467,16 +470,16 @@ class LocalStream(Stream):
Return the class implementing `module_name.class_name` or raise
`StreamError` if the module is not whitelisted.
"""
if module_name not in self._permitted_modules:
if (module_name, class_name) not in self._permitted_classes:
raise StreamError('context %r attempted to unpickle %r in module %r',
self._context, class_name, module_name)
return getattr(sys.modules[module_name], class_name)
def AllowModule(self, module_name):
def AllowClass(self, module_name, class_name):
"""
Add `module_name` to the list of permitted modules.
"""
self._permitted_modules.add(module_name)
self._permitted_modules.add((module_name, class_name))
# Hexed and passed to 'python -c'. It forks, dups 0->100, creates a pipe,
# then execs a new interpreter with a custom argv. CONTEXT_NAME is replaced
@ -517,10 +520,10 @@ class LocalStream(Stream):
self, self.read_side.fd)
source = inspect.getsource(sys.modules[__name__])
source += '\nExternalContextMain(%r, %r, %r)\n' % (
source += '\nExternalContext().main(%r, %r, %r)\n' % (
self._context.name,
self._context.broker._listener._listen_addr,
self._context.key
self._context.key,
self._context.broker.log_level,
)
compressed = zlib.compress(source)
@ -530,10 +533,8 @@ class LocalStream(Stream):
class SSHStream(LocalStream):
ssh_path = property(
lambda self: getattr(self, '_ssh_path', 'ssh'),
lambda self, path: setattr(self, '_ssh_path', path),
doc='The path to the SSH binary.')
#: The path to the SSH binary.
ssh_path = 'ssh'
def GetBootCommand(self):
bits = [self.ssh_path]
@ -563,10 +564,7 @@ class Context(object):
self._lock = threading.Lock()
self.responder = MasterModuleResponder(self)
self.AddHandleCB(self.responder.GetModule, handle=GET_MODULE)
self.log_forwarder = LogForwarder(self)
self.AddHandleCB(self.log_forwarder.ForwardLog, handle=FORWARD_LOG)
def GetStream(self):
return self._stream
@ -577,10 +575,7 @@ class Context(object):
def AllocHandle(self):
"""
Allocate a unique handle for this stream.
Returns:
long
Allocate a handle.
"""
self._lock.acquire()
try:
@ -591,8 +586,8 @@ class Context(object):
def AddHandleCB(self, fn, handle, persist=True):
"""
Register `fn(killed, obj)` to run for each `obj` sent to `handle`. If
`persist` is ``False`` then unregister after one delivery.
Register `fn(obj)` to run for each `obj` sent to `handle`. If `persist`
is ``False`` then unregister after one delivery.
"""
LOG.debug('%r.AddHandleCB(%r, %r, persist=%r)',
self, fn, handle, persist)
@ -613,47 +608,43 @@ class Context(object):
queue = Queue.Queue()
def _Receive(killed, data):
LOG.debug('%r._Receive(%r, %r)', self, killed, data)
queue.put((killed, data))
def _Receive(data):
LOG.debug('%r._Receive(%r)', self, data)
queue.put(data)
self.AddHandleCB(_Receive, reply_to, persist=False)
self._stream.Enqueue(handle, (False, (reply_to,) + data))
self._stream.Enqueue(handle, (reply_to,) + data)
try:
killed, data = queue.get(True, deadline)
data = queue.get(True, deadline)
except Queue.Empty:
self._stream.Disconnect()
raise TimeoutError('deadline exceeded.')
if killed:
if data == _DEAD:
raise StreamError('lost connection during call.')
LOG.debug('%r._EnqueueAwaitReply(): got reply: %r', self, data)
return data
def CallWithDeadline(self, fn, deadline, *args, **kwargs):
LOG.debug('%r.CallWithDeadline(%r, %r, *%r, **%r)',
self, fn, deadline, args, kwargs)
def CallWithDeadline(self, deadline, with_context, fn, *args, **kwargs):
LOG.debug('%r.CallWithDeadline(%r, %r, %r, *%r, **%r)',
self, deadline, with_context, fn, args, kwargs)
if isinstance(fn, types.MethodType) and \
isinstance(fn.im_self, (type, types.ClassType)):
fn_class = fn.im_self.__name__
klass = fn.im_self.__name__
else:
fn_class = None
call = (fn.__module__, fn_class, fn.__name__, args, kwargs)
success, result = self.EnqueueAwaitReply(CALL_FUNCTION, deadline, call)
klass = None
if success:
return result
else:
exc_obj, traceback = result
exc_obj.real_traceback = traceback
raise exc_obj
call = (with_context, fn.__module__, klass, fn.__name__, args, kwargs)
result = self.EnqueueAwaitReply(CALL_FUNCTION, deadline, call)
if isinstance(result, CallError):
raise result
return result
def Call(self, fn, *args, **kwargs):
return self.CallWithDeadline(fn, None, *args, **kwargs)
return self.CallWithDeadline(None, False, fn, *args, **kwargs)
def __repr__(self):
bits = map(repr, filter(None, [self.name, self.hostname, self.username]))
@ -668,6 +659,9 @@ class Waker(BasicStream):
self.write_side = Side(self, wfd)
broker.AddStream(self)
def __repr__(self):
return '<Waker>'
def Wake(self):
os.write(self.write_side.fd, ' ')
@ -703,6 +697,9 @@ class IoLogger(BasicStream):
self.write_side = Side(self, wfd)
self._broker.AddStream(self)
def __repr__(self):
return '<IoLogger %s fd %d>' % (self._name, self.read_side.fd)
def _LogLines(self):
while self._buf.find('\n') != -1:
line, _, self._buf = self._buf.partition('\n')
@ -722,18 +719,20 @@ class Broker(object):
Context broker: this is responsible for keeping track of contexts, any
stream that is associated with them, and for I/O multiplexing.
"""
_waker = None
def __init__(self):
self._dead = False
def __init__(self, log_level=logging.DEBUG):
self.log_level = log_level
self._alive = True
self._lock = threading.Lock()
self._stopped = threading.Event()
self._contexts = {}
self._readers = set()
self._writers = set()
self._waker = None
self._waker = Waker(self)
self._thread = threading.Thread(target=self._Loop, name='Broker')
self._thread = threading.Thread(target=self._BrokerMain,
name='econtext-broker')
self._thread.start()
def CreateListener(self, address=None, backlog=30):
@ -774,7 +773,7 @@ class Broker(object):
self._contexts[context.name] = context
return context
def GetLocal(self, name='default'):
def GetLocal(self, name='econtext-local'):
"""
Get the named context running on the local machine, creating it if it
does not exist.
@ -799,6 +798,15 @@ class Broker(object):
stream.Connect()
return self.Register(context)
def _CallAndUpdate(self, stream, func):
    """
    Invoke `func` (a bound I/O method of `stream`, e.g. its Receive or
    Transmit), disconnecting the stream if it raises, then refresh the
    broker's reader/writer registration for the stream.

    Args:
        stream: the stream whose I/O callback is being run.
        func: zero-argument callable to invoke.
    """
    try:
        func()
    # NOTE(review): broad catch is deliberate here — a crashing stream must
    # not take down the broker loop; it is logged and disconnected instead.
    # The exception object was previously bound but never used, so the
    # binding is dropped.
    except Exception:
        LOG.exception('%r crashed', stream)
        stream.Disconnect()
    # Always re-evaluate read/write interest, even after a disconnect,
    # so the dead stream is removed from the select() sets.
    self._UpdateStream(stream)
def _LoopOnce(self):
LOG.debug('%r.Loop()', self)
#LOG.debug('readers = %r', self._readers)
@ -808,28 +816,24 @@ class Broker(object):
rsides, wsides, _ = select.select(self._readers, self._writers, ())
for side in rsides:
LOG.debug('%r: POLLIN for %r', self, side.stream)
side.stream.Receive()
self._UpdateStream(side.stream)
self._CallAndUpdate(side.stream, side.stream.Receive)
for side in wsides:
LOG.debug('%r: POLLOUT for %r', self, side.stream)
side.stream.Transmit()
self._UpdateStream(side.stream)
self._CallAndUpdate(side.stream, side.stream.Transmit)
def _Loop(self):
def _BrokerMain(self):
"""
Handle stream events until Finalize() is called.
"""
try:
while not self._dead:
while self._alive:
self._LoopOnce()
for context in self._contexts.itervalues():
stream = context.GetStream()
if stream:
stream.Disconnect()
self._stopped.set()
except Exception:
LOG.exception('Loop() crashed')
@ -837,65 +841,91 @@ class Broker(object):
"""
Wait for the broker to stop.
"""
self._stopped.wait()
self._thread.join()
def Finalize(self):
"""
Tell all active streams to disconnect.
"""
self._dead = True
self._alive = False
self._waker.Wake()
self.Wait()
def __repr__(self):
return 'econtext.Broker(<contexts=%s>)' % (self._contexts.keys(),)
def ExternalContextMain(context_name, parent_addr, key):
syslog.openlog('%s:%s' % (getpass.getuser(), context_name), syslog.LOG_PID)
syslog.syslog('initializing (parent=%s)' % (os.getenv('SSH_CLIENT'),))
return 'Broker()'
class ExternalContext(object):
def _FixupMainModule(self):
global core
sys.modules['econtext'] = sys.modules['__main__']
sys.modules['econtext.core'] = sys.modules['__main__']
core = sys.modules['__main__']
for klass in globals().itervalues():
if hasattr(klass, '__module__'):
klass.__module__ = 'econtext.core'
def _SetupLogging(self, log_level):
logging.basicConfig(level=log_level)
logging.getLogger('').handlers[0].formatter = Formatter(False)
def _ReapFirstStage(self):
os.wait()
os.dup2(100, 0)
os.close(100)
def _SetupMaster(self, key):
self.broker = Broker()
self.context = Context(self.broker, 'parent', key=key)
self.channel = Channel(self.context, CALL_FUNCTION)
self.stream = Stream(self.context)
self.stream.Accept(0, 1)
def _SetupImporter(self):
self.importer = SlaveModuleImporter(self.context)
sys.meta_path.append(self.importer)
def _SetupStdio(self):
self.stdout_log = IoLogger(self.broker, 'stdout')
self.stderr_log = IoLogger(self.broker, 'stderr')
os.dup2(self.stdout_log.write_side.fd, 1)
os.dup2(self.stderr_log.write_side.fd, 2)
os.close(0)
def _DispatchCalls(self):
for data in self.channel:
LOG.debug('_DispatchCalls(%r)', data)
reply_to, with_context, modname, klass, func, args, kwargs = data
if with_context:
args = (self,) + args
logging.basicConfig(level=logging.INFO)
logging.getLogger('').handlers[0].formatter = Formatter(False)
LOG.debug('ExternalContextMain(%r, %r, %r)', context_name, parent_addr, key)
os.wait() # Reap the first stage.
os.dup2(100, 0)
os.close(100)
broker = Broker()
context = Context(broker, 'parent', parent_addr=parent_addr, key=key)
stream = Stream(context)
channel = Channel(stream, CALL_FUNCTION)
#stdout_log = IoLogger(broker, 'stdout')
#stderr_log = IoLogger(broker, 'stderr')
stream.Accept(0, 1)
os.close(0)
os.dup2(2, 1)
#os.dup2(stdout_log.write_side.fd, 1)
#os.dup2(stderr_log.write_side.fd, 2)
# stream = context.SetStream(Stream(context))
# stream.
# stream.Connect()
broker.Register(context)
importer = SlaveModuleImporter(context)
sys.meta_path.append(importer)
LOG.debug('start recv')
for call_info in channel:
LOG.debug('ExternalContextMain(): CALL_FUNCTION %r', call_info)
reply_to, mod_name, class_name, func_name, args, kwargs = call_info
try:
fn = getattr(__import__(mod_name), func_name)
stream.Enqueue(reply_to, (True, fn(*args, **kwargs)))
except Exception, e:
stream.Enqueue(reply_to, (False, (e, traceback.extract_stack())))
broker.Finalize()
LOG.debug('ExternalContextMain exitting')
try:
obj = __import__(modname)
if klass:
obj = getattr(obj, klass)
fn = getattr(obj, func)
self.context.Enqueue(reply_to, fn(*args, **kwargs))
except Exception, e:
self.context.Enqueue(reply_to, CallError(e))
def main(self, context_name, key, log_level):
self._FixupMainModule()
self._SetupLogging(log_level)
syslog.openlog('%s:%s' % (getpass.getuser(), context_name), syslog.LOG_PID)
syslog.syslog('initializing (parent=%s)' % (os.getenv('SSH_CLIENT'),))
LOG.debug('ExternalContext.main(%r, %r)', context_name, key)
self._ReapFirstStage()
self._SetupMaster(key)
self._SetupImporter()
#self._SetupStdio()
fd = open('/dev/null', 'w')
os.dup2(fd.fileno(), 1)
os.dup2(fd.fileno(), 2)
self.broker.Register(self.context)
self._DispatchCalls()
self.broker.Wait()
LOG.debug('ExternalContext.main() exitting')

@ -0,0 +1,14 @@
import functools

import econtext
def with_broker(func):
    """
    Decorator: call `func` with a freshly constructed ``econtext.Broker`` as
    its first positional argument, guaranteeing ``Finalize()`` runs afterwards
    even if `func` raises.

    Args:
        func: callable accepting ``(broker, *args, **kwargs)``.
    Returns:
        Wrapper with the same call signature minus the broker argument.
    """
    # functools.wraps copies __name__ (a.k.a. func_name in Python 2) as the
    # original hand-written assignment did, and additionally preserves
    # __doc__, __module__ and __dict__.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        broker = econtext.Broker()
        try:
            return func(broker, *args, **kwargs)
        finally:
            broker.Finalize()
    return wrapper
Loading…
Cancel
Save