diff --git a/econtext/core.py b/econtext/core.py index f7f016b4..3bf859d1 100644 --- a/econtext/core.py +++ b/econtext/core.py @@ -21,7 +21,6 @@ import signal import socket import struct import sys -import syslog import textwrap import threading import traceback @@ -29,51 +28,45 @@ import types import zlib -# -# Module-level data. -# - LOG = logging.getLogger('econtext') +IOLOG = logging.getLogger('econtext.io') +RLOG = logging.getLogger('econtext.ctx') GET_MODULE = 100L CALL_FUNCTION = 101L FORWARD_LOG = 102L -# -# Exceptions. -# - class ContextError(Exception): - 'Raised when a problem occurs with a context.' + """Raised when a problem occurs with a context.""" def __init__(self, fmt, *args): Exception.__init__(self, fmt % args) + class ChannelError(ContextError): - 'Raised when a channel dies or has been closed.' + """Raised when a channel dies or has been closed.""" + class StreamError(ContextError): - 'Raised when a stream cannot be established.' + """Raised when a stream cannot be established.""" + class CorruptMessageError(StreamError): - 'Raised when a corrupt message is received on a stream.' + """Raised when a corrupt message is received on a stream.""" + class TimeoutError(StreamError): - 'Raised when a timeout occurs on a stream.' + """Raised when a timeout occurs on a stream.""" -class CallError(ContextError): - 'Raised when .Call() fails' +class CallError(ContextError): + """Raised when .Call() fails""" def __init__(self, e): name = '%s.%s' % (type(e).__module__, type(e).__name__) stack = ''.join(traceback.format_stack(sys.exc_info[2])) ContextError.__init__(self, 'Call failed: %s: %s\n%s', name, e, stack) -# -# Helpers. -# - class Dead(object): def __eq__(self, other): return type(other) is Dead @@ -94,10 +87,8 @@ def write_all(fd, s): def CreateChild(*args): - """ - Create a child process whose stdin/stdout is connected to a socket, - returning `(pid, socket_obj)`. - """ + """Create a child process whose stdin/stdout is connected to a socket, + returning `(pid, socket_obj)`.""" parentfp, childfp = socket.socketpair() pid = os.fork() if not pid: @@ -114,25 +105,6 @@ def CreateChild(*args): return pid, parentfp -class Formatter(logging.Formatter): - FMT = '%(asctime)s %(levelname).1s %(name)s: %(message)s' - DATEFMT = '%H:%M:%S' - - def __init__(self, parent): - self.parent = parent - super(Formatter, self).__init__(self.FMT, self.DATEFMT) - - def format(self, record): - s = super(Formatter, self).format(record) - if 1: - p = '' - elif self.parent: - p = '\x1b[32m' - else: - p = '\x1b[36m' - return p + ('{%s} %s' % (os.getpid(), s)) - - class Channel(object): def __init__(self, context, handle): self._context = context @@ -141,47 +113,37 @@ class Channel(object): self._context.AddHandleCB(self._Receive, handle) def _Receive(self, data): - """ - Callback from the Stream; appends data to the internal queue. - """ - LOG.debug('%r._Receive(%r)', self, data) + """Callback from the Stream; appends data to the internal queue.""" + IOLOG.debug('%r._Receive(%r)', self, data) self._queue.put(data) def Close(self): - """ - Indicate this channel is closed to the remote side. - """ - LOG.debug('%r.Close()', self) + """Indicate this channel is closed to the remote side.""" + IOLOG.debug('%r.Close()', self) self._context.Enqueue(handle, _DEAD) def Send(self, data): - """ - Send `data` to the remote. - """ - LOG.debug('%r.Send(%r)', self, data) + """Send `data` to the remote.""" + IOLOG.debug('%r.Send(%r)', self, data) self._context.Enqueue(handle, data) def Receive(self, timeout=None): - """ - Receive an object from the remote, or return ``None`` if `timeout` is - reached. - """ - LOG.debug('%r.Receive(timeout=%r)', self, timeout) + """Receive an object from the remote, or return ``None`` if `timeout` + is reached.""" + IOLOG.debug('%r.Receive(timeout=%r)', self, timeout) try: data = self._queue.get(True, timeout) except Queue.Empty: return - LOG.debug('%r.Receive() got %r', self, data) + IOLOG.debug('%r.Receive() got %r', self, data) if data == _DEAD: raise ChannelError('Channel is closed.') return data def __iter__(self): - """ - Return an iterator that yields objects arriving on this channel, until - the channel dies or is closed. - """ + """Iterate objects arriving on this channel, until the channel dies or + is closed.""" while True: try: yield self.Receive() @@ -244,21 +206,37 @@ class MasterModuleResponder(object): self._context.Enqueue(reply_to, None) +class LogHandler(logging.Handler): + def __init__(self, context): + logging.Handler.__init__(self) + self.context = context + self.local = threading.local() + + def emit(self, rec): + if rec.name == 'econtext.io' or \ + getattr(self.local, 'in_commit', False): + return + + self.local.in_commit = True + try: + msg = self.format(rec) + self.context.Enqueue(FORWARD_LOG, (rec.name, rec.levelno, msg)) + finally: + self.local.in_commit = False + + class LogForwarder(object): def __init__(self, context): self._context = context self._context.AddHandleCB(self.ForwardLog, handle=FORWARD_LOG) + self._log = RLOG.getChild(self._context.name) def ForwardLog(self, data): if data == _DEAD: return - LOG.debug('%r: %s', self._context, data) - - -# -# Stream implementations. -# + name, level, s = data + self._log.log(level, '%s: %s', name, s) class Side(object): @@ -300,43 +278,24 @@ class Stream(BasicStream): def __init__(self, context): self._context = context self._lock = threading.Lock() - self._rhmac = hmac.new(context.key, digestmod=sha.new) self._whmac = self._rhmac.copy() - self._pickler_file = cStringIO.StringIO() - self._pickler = cPickle.Pickler(self._pickler_file, protocol=2) - - self._unpickler_file = cStringIO.StringIO() - self._unpickler = cPickle.Unpickler(self._unpickler_file) - - def Pickle(self, obj): - """ - Serialize `obj` into a bytestring. - """ - self._pickler_file.truncate(0) - self._pickler_file.seek(0) - self._pickler.clear_memo() - self._pickler.dump(obj) - return self._pickler_file.getvalue() + _FindGlobal = None def Unpickle(self, data): - """ - Deserialize `data` into an object. - """ - LOG.debug('%r.Unpickle(%r)', self, data) - self._unpickler_file.truncate(0) - self._unpickler_file.seek(0) - self._unpickler_file.write(data) - self._unpickler_file.seek(0) - return self._unpickler.load() + """Deserialize `data` into an object.""" + IOLOG.debug('%r.Unpickle(%r)', self, data) + fp = cStringIO.StringIO(data) + unpickler = cPickle.Unpickler(fp) + if self._FindGlobal: + unpickler.find_global = self._FindGlobal + return unpickler.load() def Receive(self): - """ - Handle the next complete message on the stream. Raise - CorruptMessageError or IOError on failure. - """ - LOG.debug('%r.Receive()', self) + """Handle the next complete message on the stream. Raise + CorruptMessageError or IOError on failure.""" + IOLOG.debug('%r.Receive()', self) buf = os.read(self.read_side.fd, 4096) if not buf: @@ -349,7 +308,7 @@ class Stream(BasicStream): msg_mac = self._input_buf[:20] msg_len = struct.unpack('>L', self._input_buf[20:24])[0] if len(self._input_buf) < msg_len-24: - LOG.debug('Input too short') + IOLOG.debug('Input too short') return self._rhmac.update(self._input_buf[20:msg_len+24]) @@ -379,10 +338,8 @@ class Stream(BasicStream): fn(data) def Transmit(self): - """ - Transmit buffered messages. - """ - LOG.debug('%r.Transmit()', self) + """Transmit buffered messages.""" + IOLOG.debug('%r.Transmit()', self) written = os.write(self.write_side.fd, self._output_buf[:4096]) self._output_buf = self._output_buf[written:] @@ -390,13 +347,11 @@ class Stream(BasicStream): return bool(self._output_buf) def Enqueue(self, handle, obj): - """ - Enqueue `obj` to `handle`, and tell the broker we have output. - """ - LOG.debug('%r.Enqueue(%r, %r)', self, handle, obj) + """Enqueue `obj` to `handle`, and tell the broker we have output.""" + IOLOG.debug('%r.Enqueue(%r, %r)', self, handle, obj) self._lock.acquire() try: - encoded = self.Pickle((handle, obj)) + encoded = cPickle.dumps((handle, obj), protocol=2) msg = struct.pack('>L', len(encoded)) + encoded self._whmac.update(msg) self._output_buf += self._whmac.digest() + msg @@ -405,10 +360,8 @@ class Stream(BasicStream): self._context.broker.UpdateStream(self) def Disconnect(self): - """ - Close our associated file descriptor and tell registered callbacks the - connection has been destroyed. - """ + """Close our associated file descriptor and tell registered callbacks + the connection has been destroyed.""" LOG.debug('%r.Disconnect()', self) if self._context.GetStream() is self: self._context.SetStream(None) @@ -438,9 +391,8 @@ class Stream(BasicStream): self._context.SetStream(self) def Connect(self): - """ - Connect to a Broker at the address specified in our associated Context. - """ + """Connect to a Broker at the address specified in our associated + Context.""" LOG.debug('%r.Connect()', self) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.read_side = Side(self, sock.fileno()) @@ -462,27 +414,22 @@ class LocalStream(Stream): def __init__(self, context): super(LocalStream, self).__init__(context) self._permitted_classes = set([('econtext.core', 'CallError')]) - self._unpickler.find_global = self._FindGlobal def _FindGlobal(self, module_name, class_name): - """ - Return the class implementing `module_name.class_name` or raise - `StreamError` if the module is not whitelisted. - """ + """Return the class implementing `module_name.class_name` or raise + `StreamError` if the module is not whitelisted.""" if (module_name, class_name) not in self._permitted_classes: - raise StreamError('context %r attempted to unpickle %r in module %r', + raise StreamError('%r attempted to unpickle %r in module %r', self._context, class_name, module_name) return getattr(sys.modules[module_name], class_name) def AllowClass(self, module_name, class_name): - """ - Add `module_name` to the list of permitted modules. - """ + """Add `module_name` to the list of permitted modules.""" self._permitted_modules.add((module_name, class_name)) - # Hexed and passed to 'python -c'. It forks, dups 0->100, creates a pipe, - # then execs a new interpreter with a custom argv. CONTEXT_NAME is replaced - # with the context name. Optimized for size. + # base64'd and passed to 'python -c'. It forks, dups 0->100, creates a + # pipe, then execs a new interpreter with a custom argv. CONTEXT_NAME is + # replaced with the context name. Optimized for size. def _FirstStage(): import os,sys,zlib R,W=os.pipe() @@ -491,7 +438,7 @@ class LocalStream(Stream): os.dup2(R,0) os.close(R) os.close(W) - os.execv(sys.executable,(CONTEXT_NAME,)) + os.execv(sys.executable,('econtext:'+CONTEXT_NAME,)) else: os.fdopen(W,'wb',0).write(zlib.decompress(sys.stdin.read(input()))) print 'OK' @@ -524,8 +471,8 @@ class LocalStream(Stream): self._context.key, self._context.broker.log_level, ) - compressed = zlib.compress(source) + compressed = zlib.compress(source) preamble = str(len(compressed)) + '\n' + compressed write_all(self.write_side.fd, preamble) assert os.read(self.read_side.fd, 3) == 'OK\n' @@ -546,11 +493,10 @@ class SSHStream(LocalStream): class Context(object): """ - Represents a remote context regardless of connection method. + Represent a remote context regardless of connection method. """ - - def __init__(self, broker, name=None, hostname=None, username=None, key=None, - parent_addr=None): + def __init__(self, broker, name=None, hostname=None, username=None, + key=None, parent_addr=None): self.broker = broker self.name = name self.hostname = hostname @@ -573,9 +519,7 @@ class Context(object): return stream def AllocHandle(self): - """ - Allocate a handle. - """ + """Allocate a handle.""" self._lock.acquire() try: self._last_handle += 1L @@ -584,23 +528,19 @@ class Context(object): self._lock.release() def AddHandleCB(self, fn, handle, persist=True): - """ - Register `fn(obj)` to run for each `obj` sent to `handle`. If `persist` - is ``False`` then unregister after one delivery. - """ - LOG.debug('%r.AddHandleCB(%r, %r, persist=%r)', - self, fn, handle, persist) + """Register `fn(obj)` to run for each `obj` sent to `handle`. If + `persist` is ``False`` then unregister after one delivery.""" + IOLOG.debug('%r.AddHandleCB(%r, %r, persist=%r)', + self, fn, handle, persist) self._handle_map[handle] = persist, fn def Enqueue(self, handle, obj): self._stream.Enqueue(handle, obj) def EnqueueAwaitReply(self, handle, deadline, data): - """ - Send `data` to `handle` and wait for a response with an optional + """Send `data` to `handle` and wait for a response with an optional timeout. The message contains `(reply_to, data)`, where `reply_to` is - the handle on which this function expects its reply. - """ + the handle on which this function expects its reply.""" reply_to = self.AllocHandle() LOG.debug('%r.EnqueueAwaitReply(%r, %r, %r) -> reply handle %d', self, handle, deadline, data, reply_to) @@ -650,13 +590,21 @@ class Context(object): return 'Context(%s)' % ', '.join(bits) +class ParentContext(Context): + def SetStream(self, stream): + super(ParentContext, self).SetStream(stream) + if stream is None: + LOG.debug('Parent stream is gone, dying.') + self.broker.Finalize(wait=False) + + class Waker(BasicStream): def __init__(self, broker): self._broker = broker rfd, wfd = os.pipe() self.read_side = Side(self, rfd) self.write_side = Side(self, wfd) - broker.AddStream(self) + broker.UpdateStream(self) def __repr__(self): return '' @@ -665,7 +613,6 @@ class Waker(BasicStream): os.write(self.write_side.fd, ' ') def Receive(self): - LOG.debug('%r: waking %r', self, self._broker) os.read(self.read_side.fd, 1) @@ -677,7 +624,7 @@ class Listener(BasicStream): self._sock.listen(backlog) self._listen_addr = self._sock.getsockname() self.read_side = Side(self, self._sock.fileno()) - broker.AddStream(self) + broker.UpdateStream(self) def Receive(self): sock, addr = self._sock.accept() @@ -691,10 +638,12 @@ class IoLogger(BasicStream): def __init__(self, broker, name): self._broker = broker self._name = name + self._log = logging.getLogger(name) rfd, wfd = os.pipe() + self.read_side = Side(self, rfd) self.write_side = Side(self, wfd) - self._broker.AddStream(self) + self._broker.UpdateStream(self) def __repr__(self): return '' % (self._name, self.read_side.fd) @@ -702,9 +651,10 @@ class IoLogger(BasicStream): def _LogLines(self): while self._buf.find('\n') != -1: line, _, self._buf = self._buf.partition('\n') - LOG.debug('%s: %s', self._name, line.rstrip('\n')) + self._log.debug('%s: %s', self._name, line.rstrip('\n')) def Receive(self): + LOG.debug('%r.Receive()', self) buf = os.read(self.read_side.fd, 4096) if not buf: return self.Disconnect() @@ -720,7 +670,7 @@ class Broker(object): """ _waker = None - def __init__(self, log_level=logging.INFO): + def __init__(self, log_level=logging.DEBUG): self.log_level = log_level self._alive = True @@ -735,13 +685,11 @@ class Broker(object): self._thread.start() def CreateListener(self, address=None, backlog=30): - """ - Listen on `address `for connections from newly spawned contexts. - """ + """Listen on `address `for connections from newly spawned contexts.""" self._listener = Listener(self, address, backlog) def _UpdateStream(self, stream): - LOG.debug('_UpdateStream(%r)', stream) + IOLOG.debug('_UpdateStream(%r)', stream) self._lock.acquire() try: if stream.ReadMore() and stream.read_side.fileno(): @@ -757,38 +705,28 @@ class Broker(object): self._lock.release() def UpdateStream(self, stream): - LOG.debug('UpdateStream(%r)', stream) self._UpdateStream(stream) if self._waker: self._waker.Wake() - def AddStream(self, stream): - self.UpdateStream(stream) - def Register(self, context): - """ - Put a context under control of this broker. - """ + """Put a context under control of this broker.""" LOG.debug('%r.Register(%r) -> r=%r w=%r', self, context, context.GetStream().read_side, context.GetStream().write_side) - self.AddStream(context.GetStream()) + self.UpdateStream(context.GetStream()) self._contexts[context.name] = context return context - def GetLocal(self, name='econtext-local'): - """ - Get the named context running on the local machine, creating it if it - does not exist. - """ + def GetLocal(self, name='default'): + """Get the named context running on the local machine, creating it if + it does not exist.""" context = Context(self, name) context.SetStream(LocalStream(context)).Connect() return self.Register(context) def GetRemote(self, hostname, username, name=None, python_path=None): - """ - Get the named remote context, creating it if it does not exist. - """ + """Get the named remote context, creating it if it does not exist.""" if name is None: name = 'econtext[%s@%s:%d]' %\ (username, socket.gethostname(), os.getpid()) @@ -810,24 +748,20 @@ class Broker(object): self._UpdateStream(stream) def _LoopOnce(self): - LOG.debug('%r.Loop()', self) - #LOG.debug('readers = %r', self._readers) - #LOG.debug('rfds = %r', [r.fileno() for r in self._readers]) - #LOG.debug('writers = %r', self._writers) - #LOG.debug('wfds = %r', [w.fileno() for w in self._writers]) + IOLOG.debug('%r.Loop()', self) + IOLOG.debug('readers = %r', [(r.fileno(), r) for r in self._readers]) + IOLOG.debug('writers = %r', [(w.fileno(), w) for w in self._writers]) rsides, wsides, _ = select.select(self._readers, self._writers, ()) for side in rsides: - LOG.debug('%r: POLLIN for %r', self, side.stream) + IOLOG.debug('%r: POLLIN for %r', self, side.stream) self._CallAndUpdate(side.stream, side.stream.Receive) for side in wsides: - LOG.debug('%r: POLLOUT for %r', self, side.stream) + IOLOG.debug('%r: POLLOUT for %r', self, side.stream) self._CallAndUpdate(side.stream, side.stream.Transmit) def _BrokerMain(self): - """ - Handle stream events until Finalize() is called. - """ + """Handle events until Finalize() is called.""" try: while self._alive: self._LoopOnce() @@ -840,18 +774,15 @@ class Broker(object): LOG.exception('Loop() crashed') def Wait(self): - """ - Wait for the broker to stop. - """ + """Wait for the broker to stop.""" self._thread.join() - def Finalize(self): - """ - Tell all active streams to disconnect. - """ + def Finalize(self, wait=True): + """Disconect all streams and wait for broker to stop.""" self._alive = False self._waker.Wake() - self.Wait() + if wait: + self.Wait() def __repr__(self): return 'Broker()' @@ -868,10 +799,6 @@ class ExternalContext(object): if hasattr(klass, '__module__'): klass.__module__ = 'econtext.core' - def _SetupLogging(self, log_level): - logging.basicConfig(level=log_level, filename='slave.txt') - logging.getLogger('').handlers[0].formatter = Formatter(False) - def _ReapFirstStage(self): os.wait() os.dup2(100, 0) @@ -879,11 +806,20 @@ class ExternalContext(object): def _SetupMaster(self, key): self.broker = Broker() - self.context = Context(self.broker, 'parent', key=key) + self.context = ParentContext(self.broker, 'parent', key=key) self.channel = Channel(self.context, CALL_FUNCTION) self.stream = Stream(self.context) self.stream.Accept(0, 1) + def _SetupLogging(self, log_level): + logging.basicConfig(level=log_level, stream=open('slave', 'w', 1)) + return + logging.basicConfig(level=log_level) + root = logging.getLogger() + root.setLevel(log_level) + root.handlers = [LogHandler(self.context)] + LOG.info('Connected to %s', self.context) + def _SetupImporter(self): self.importer = SlaveModuleImporter(self.context) sys.meta_path.append(self.importer) @@ -893,9 +829,17 @@ class ExternalContext(object): self.stderr_log = IoLogger(self.broker, 'stderr') os.dup2(self.stdout_log.write_side.fd, 1) os.dup2(self.stderr_log.write_side.fd, 2) - os.close(0) + + fp = file('/dev/null') + try: + os.dup2(fp.fileno(), 0) + finally: + fp.close() def _DispatchCalls(self): + #signal.alarm(10) + signal.signal(signal.SIGINT, lambda *_: self.broker.Finalize()) + for data in self.channel: LOG.debug('_DispatchCalls(%r)', data) reply_to, with_context, modname, klass, func, args, kwargs = data @@ -912,24 +856,16 @@ class ExternalContext(object): self.context.Enqueue(reply_to, CallError(e)) def main(self, context_name, key, log_level): - self._FixupMainModule() - self._SetupLogging(log_level) - - syslog.openlog('%s:%s' % (getpass.getuser(), context_name), syslog.LOG_PID) - syslog.syslog('initializing (parent=%s)' % (os.getenv('SSH_CLIENT'),)) - LOG.debug('ExternalContext.main(%r, %r)', context_name, key) - + import stack self._ReapFirstStage() + self._FixupMainModule() self._SetupMaster(key) + self._SetupLogging(log_level) self._SetupImporter() - #self._SetupStdio() - os.dup2(2, 1) - if 0: - fd = open('/dev/null', 'w') - os.dup2(fd.fileno(), 1) - os.dup2(fd.fileno(), 2) + self._SetupStdio() self.broker.Register(self.context) + LOG.info('entering dispatchcalls') self._DispatchCalls() self.broker.Wait() LOG.debug('ExternalContext.main() exitting')