From 0b0e828c04e005d982fd37ff92af66ecbdcb9deb Mon Sep 17 00:00:00 2001 From: David Wilson Date: Wed, 10 Aug 2016 19:58:49 +0100 Subject: [PATCH] Fixes/improvements Get rid of syslog. Get rid of section comments. Move IOLOG to separate logger to avoid infinite loop. Change function docstring style. Delete log Formatter. Implement LogHandler to forward logs to parent. Delete Pickle(), simplify Unpickle(). Have slave Finalize() when parent disconnects. Delete AddStream(). --- econtext/core.py | 360 +++++++++++++++++++---------------------------- 1 file changed, 148 insertions(+), 212 deletions(-) diff --git a/econtext/core.py b/econtext/core.py index f7f016b4..3bf859d1 100644 --- a/econtext/core.py +++ b/econtext/core.py @@ -21,7 +21,6 @@ import signal import socket import struct import sys -import syslog import textwrap import threading import traceback @@ -29,51 +28,45 @@ import types import zlib -# -# Module-level data. -# - LOG = logging.getLogger('econtext') +IOLOG = logging.getLogger('econtext.io') +RLOG = logging.getLogger('econtext.ctx') GET_MODULE = 100L CALL_FUNCTION = 101L FORWARD_LOG = 102L -# -# Exceptions. -# - class ContextError(Exception): - 'Raised when a problem occurs with a context.' + """Raised when a problem occurs with a context.""" def __init__(self, fmt, *args): Exception.__init__(self, fmt % args) + class ChannelError(ContextError): - 'Raised when a channel dies or has been closed.' + """Raised when a channel dies or has been closed.""" + class StreamError(ContextError): - 'Raised when a stream cannot be established.' + """Raised when a stream cannot be established.""" + class CorruptMessageError(StreamError): - 'Raised when a corrupt message is received on a stream.' + """Raised when a corrupt message is received on a stream.""" + class TimeoutError(StreamError): - 'Raised when a timeout occurs on a stream.' + """Raised when a timeout occurs on a stream.""" -class CallError(ContextError): - 'Raised when .Call() fails' +class CallError(ContextError): + """Raised when .Call() fails""" def __init__(self, e): name = '%s.%s' % (type(e).__module__, type(e).__name__) stack = ''.join(traceback.format_stack(sys.exc_info[2])) ContextError.__init__(self, 'Call failed: %s: %s\n%s', name, e, stack) -# -# Helpers. -# - class Dead(object): def __eq__(self, other): return type(other) is Dead @@ -94,10 +87,8 @@ def write_all(fd, s): def CreateChild(*args): - """ - Create a child process whose stdin/stdout is connected to a socket, - returning `(pid, socket_obj)`. - """ + """Create a child process whose stdin/stdout is connected to a socket, + returning `(pid, socket_obj)`.""" parentfp, childfp = socket.socketpair() pid = os.fork() if not pid: @@ -114,25 +105,6 @@ def CreateChild(*args): return pid, parentfp -class Formatter(logging.Formatter): - FMT = '%(asctime)s %(levelname).1s %(name)s: %(message)s' - DATEFMT = '%H:%M:%S' - - def __init__(self, parent): - self.parent = parent - super(Formatter, self).__init__(self.FMT, self.DATEFMT) - - def format(self, record): - s = super(Formatter, self).format(record) - if 1: - p = '' - elif self.parent: - p = '\x1b[32m' - else: - p = '\x1b[36m' - return p + ('{%s} %s' % (os.getpid(), s)) - - class Channel(object): def __init__(self, context, handle): self._context = context @@ -141,47 +113,37 @@ class Channel(object): self._context.AddHandleCB(self._Receive, handle) def _Receive(self, data): - """ - Callback from the Stream; appends data to the internal queue. - """ - LOG.debug('%r._Receive(%r)', self, data) + """Callback from the Stream; appends data to the internal queue.""" + IOLOG.debug('%r._Receive(%r)', self, data) self._queue.put(data) def Close(self): - """ - Indicate this channel is closed to the remote side. - """ - LOG.debug('%r.Close()', self) + """Indicate this channel is closed to the remote side.""" + IOLOG.debug('%r.Close()', self) self._context.Enqueue(handle, _DEAD) def Send(self, data): - """ - Send `data` to the remote. - """ - LOG.debug('%r.Send(%r)', self, data) + """Send `data` to the remote.""" + IOLOG.debug('%r.Send(%r)', self, data) self._context.Enqueue(handle, data) def Receive(self, timeout=None): - """ - Receive an object from the remote, or return ``None`` if `timeout` is - reached. - """ - LOG.debug('%r.Receive(timeout=%r)', self, timeout) + """Receive an object from the remote, or return ``None`` if `timeout` + is reached.""" + IOLOG.debug('%r.Receive(timeout=%r)', self, timeout) try: data = self._queue.get(True, timeout) except Queue.Empty: return - LOG.debug('%r.Receive() got %r', self, data) + IOLOG.debug('%r.Receive() got %r', self, data) if data == _DEAD: raise ChannelError('Channel is closed.') return data def __iter__(self): - """ - Return an iterator that yields objects arriving on this channel, until - the channel dies or is closed. - """ + """Iterate objects arriving on this channel, until the channel dies or + is closed.""" while True: try: yield self.Receive() @@ -244,21 +206,37 @@ class MasterModuleResponder(object): self._context.Enqueue(reply_to, None) +class LogHandler(logging.Handler): + def __init__(self, context): + logging.Handler.__init__(self) + self.context = context + self.local = threading.local() + + def emit(self, rec): + if rec.name == 'econtext.io' or \ + getattr(self.local, 'in_commit', False): + return + + self.local.in_commit = True + try: + msg = self.format(rec) + self.context.Enqueue(FORWARD_LOG, (rec.name, rec.levelno, msg)) + finally: + self.local.in_commit = False + + class LogForwarder(object): def __init__(self, context): self._context = context self._context.AddHandleCB(self.ForwardLog, handle=FORWARD_LOG) + self._log = RLOG.getChild(self._context.name) def ForwardLog(self, data): if data == _DEAD: return - LOG.debug('%r: %s', self._context, data) - - -# -# Stream implementations. -# + name, level, s = data + self._log.log(level, '%s: %s', name, s) class Side(object): @@ -300,43 +278,24 @@ class Stream(BasicStream): def __init__(self, context): self._context = context self._lock = threading.Lock() - self._rhmac = hmac.new(context.key, digestmod=sha.new) self._whmac = self._rhmac.copy() - self._pickler_file = cStringIO.StringIO() - self._pickler = cPickle.Pickler(self._pickler_file, protocol=2) - - self._unpickler_file = cStringIO.StringIO() - self._unpickler = cPickle.Unpickler(self._unpickler_file) - - def Pickle(self, obj): - """ - Serialize `obj` into a bytestring. - """ - self._pickler_file.truncate(0) - self._pickler_file.seek(0) - self._pickler.clear_memo() - self._pickler.dump(obj) - return self._pickler_file.getvalue() + _FindGlobal = None def Unpickle(self, data): - """ - Deserialize `data` into an object. - """ - LOG.debug('%r.Unpickle(%r)', self, data) - self._unpickler_file.truncate(0) - self._unpickler_file.seek(0) - self._unpickler_file.write(data) - self._unpickler_file.seek(0) - return self._unpickler.load() + """Deserialize `data` into an object.""" + IOLOG.debug('%r.Unpickle(%r)', self, data) + fp = cStringIO.StringIO(data) + unpickler = cPickle.Unpickler(fp) + if self._FindGlobal: + unpickler.find_global = self._FindGlobal + return unpickler.load() def Receive(self): - """ - Handle the next complete message on the stream. Raise - CorruptMessageError or IOError on failure. - """ - LOG.debug('%r.Receive()', self) + """Handle the next complete message on the stream. Raise + CorruptMessageError or IOError on failure.""" + IOLOG.debug('%r.Receive()', self) buf = os.read(self.read_side.fd, 4096) if not buf: @@ -349,7 +308,7 @@ class Stream(BasicStream): msg_mac = self._input_buf[:20] msg_len = struct.unpack('>L', self._input_buf[20:24])[0] if len(self._input_buf) < msg_len-24: - LOG.debug('Input too short') + IOLOG.debug('Input too short') return self._rhmac.update(self._input_buf[20:msg_len+24]) @@ -379,10 +338,8 @@ class Stream(BasicStream): fn(data) def Transmit(self): - """ - Transmit buffered messages. - """ - LOG.debug('%r.Transmit()', self) + """Transmit buffered messages.""" + IOLOG.debug('%r.Transmit()', self) written = os.write(self.write_side.fd, self._output_buf[:4096]) self._output_buf = self._output_buf[written:] @@ -390,13 +347,11 @@ class Stream(BasicStream): return bool(self._output_buf) def Enqueue(self, handle, obj): - """ - Enqueue `obj` to `handle`, and tell the broker we have output. - """ - LOG.debug('%r.Enqueue(%r, %r)', self, handle, obj) + """Enqueue `obj` to `handle`, and tell the broker we have output.""" + IOLOG.debug('%r.Enqueue(%r, %r)', self, handle, obj) self._lock.acquire() try: - encoded = self.Pickle((handle, obj)) + encoded = cPickle.dumps((handle, obj), protocol=2) msg = struct.pack('>L', len(encoded)) + encoded self._whmac.update(msg) self._output_buf += self._whmac.digest() + msg @@ -405,10 +360,8 @@ class Stream(BasicStream): self._context.broker.UpdateStream(self) def Disconnect(self): - """ - Close our associated file descriptor and tell registered callbacks the - connection has been destroyed. - """ + """Close our associated file descriptor and tell registered callbacks + the connection has been destroyed.""" LOG.debug('%r.Disconnect()', self) if self._context.GetStream() is self: self._context.SetStream(None) @@ -438,9 +391,8 @@ class Stream(BasicStream): self._context.SetStream(self) def Connect(self): - """ - Connect to a Broker at the address specified in our associated Context. - """ + """Connect to a Broker at the address specified in our associated + Context.""" LOG.debug('%r.Connect()', self) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.read_side = Side(self, sock.fileno()) @@ -462,27 +414,22 @@ class LocalStream(Stream): def __init__(self, context): super(LocalStream, self).__init__(context) self._permitted_classes = set([('econtext.core', 'CallError')]) - self._unpickler.find_global = self._FindGlobal def _FindGlobal(self, module_name, class_name): - """ - Return the class implementing `module_name.class_name` or raise - `StreamError` if the module is not whitelisted. - """ + """Return the class implementing `module_name.class_name` or raise + `StreamError` if the module is not whitelisted.""" if (module_name, class_name) not in self._permitted_classes: - raise StreamError('context %r attempted to unpickle %r in module %r', + raise StreamError('%r attempted to unpickle %r in module %r', self._context, class_name, module_name) return getattr(sys.modules[module_name], class_name) def AllowClass(self, module_name, class_name): - """ - Add `module_name` to the list of permitted modules. - """ + """Add `module_name` to the list of permitted modules.""" self._permitted_modules.add((module_name, class_name)) - # Hexed and passed to 'python -c'. It forks, dups 0->100, creates a pipe, - # then execs a new interpreter with a custom argv. CONTEXT_NAME is replaced - # with the context name. Optimized for size. + # base64'd and passed to 'python -c'. It forks, dups 0->100, creates a + # pipe, then execs a new interpreter with a custom argv. CONTEXT_NAME is + # replaced with the context name. Optimized for size. def _FirstStage(): import os,sys,zlib R,W=os.pipe() @@ -491,7 +438,7 @@ class LocalStream(Stream): os.dup2(R,0) os.close(R) os.close(W) - os.execv(sys.executable,(CONTEXT_NAME,)) + os.execv(sys.executable,('econtext:'+CONTEXT_NAME,)) else: os.fdopen(W,'wb',0).write(zlib.decompress(sys.stdin.read(input()))) print 'OK' @@ -524,8 +471,8 @@ class LocalStream(Stream): self._context.key, self._context.broker.log_level, ) - compressed = zlib.compress(source) + compressed = zlib.compress(source) preamble = str(len(compressed)) + '\n' + compressed write_all(self.write_side.fd, preamble) assert os.read(self.read_side.fd, 3) == 'OK\n' @@ -546,11 +493,10 @@ class SSHStream(LocalStream): class Context(object): """ - Represents a remote context regardless of connection method. + Represent a remote context regardless of connection method. """ - - def __init__(self, broker, name=None, hostname=None, username=None, key=None, - parent_addr=None): + def __init__(self, broker, name=None, hostname=None, username=None, + key=None, parent_addr=None): self.broker = broker self.name = name self.hostname = hostname @@ -573,9 +519,7 @@ class Context(object): return stream def AllocHandle(self): - """ - Allocate a handle. - """ + """Allocate a handle.""" self._lock.acquire() try: self._last_handle += 1L @@ -584,23 +528,19 @@ class Context(object): self._lock.release() def AddHandleCB(self, fn, handle, persist=True): - """ - Register `fn(obj)` to run for each `obj` sent to `handle`. If `persist` - is ``False`` then unregister after one delivery. - """ - LOG.debug('%r.AddHandleCB(%r, %r, persist=%r)', - self, fn, handle, persist) + """Register `fn(obj)` to run for each `obj` sent to `handle`. If + `persist` is ``False`` then unregister after one delivery.""" + IOLOG.debug('%r.AddHandleCB(%r, %r, persist=%r)', + self, fn, handle, persist) self._handle_map[handle] = persist, fn def Enqueue(self, handle, obj): self._stream.Enqueue(handle, obj) def EnqueueAwaitReply(self, handle, deadline, data): - """ - Send `data` to `handle` and wait for a response with an optional + """Send `data` to `handle` and wait for a response with an optional timeout. The message contains `(reply_to, data)`, where `reply_to` is - the handle on which this function expects its reply. - """ + the handle on which this function expects its reply.""" reply_to = self.AllocHandle() LOG.debug('%r.EnqueueAwaitReply(%r, %r, %r) -> reply handle %d', self, handle, deadline, data, reply_to) @@ -650,13 +590,21 @@ class Context(object): return 'Context(%s)' % ', '.join(bits) +class ParentContext(Context): + def SetStream(self, stream): + super(ParentContext, self).SetStream(stream) + if stream is None: + LOG.debug('Parent stream is gone, dying.') + self.broker.Finalize(wait=False) + + class Waker(BasicStream): def __init__(self, broker): self._broker = broker rfd, wfd = os.pipe() self.read_side = Side(self, rfd) self.write_side = Side(self, wfd) - broker.AddStream(self) + broker.UpdateStream(self) def __repr__(self): return '' @@ -665,7 +613,6 @@ class Waker(BasicStream): os.write(self.write_side.fd, ' ') def Receive(self): - LOG.debug('%r: waking %r', self, self._broker) os.read(self.read_side.fd, 1) @@ -677,7 +624,7 @@ class Listener(BasicStream): self._sock.listen(backlog) self._listen_addr = self._sock.getsockname() self.read_side = Side(self, self._sock.fileno()) - broker.AddStream(self) + broker.UpdateStream(self) def Receive(self): sock, addr = self._sock.accept() @@ -691,10 +638,12 @@ class IoLogger(BasicStream): def __init__(self, broker, name): self._broker = broker self._name = name + self._log = logging.getLogger(name) rfd, wfd = os.pipe() + self.read_side = Side(self, rfd) self.write_side = Side(self, wfd) - self._broker.AddStream(self) + self._broker.UpdateStream(self) def __repr__(self): return '' % (self._name, self.read_side.fd) @@ -702,9 +651,10 @@ class IoLogger(BasicStream): def _LogLines(self): while self._buf.find('\n') != -1: line, _, self._buf = self._buf.partition('\n') - LOG.debug('%s: %s', self._name, line.rstrip('\n')) + self._log.debug('%s: %s', self._name, line.rstrip('\n')) def Receive(self): + LOG.debug('%r.Receive()', self) buf = os.read(self.read_side.fd, 4096) if not buf: return self.Disconnect() @@ -720,7 +670,7 @@ class Broker(object): """ _waker = None - def __init__(self, log_level=logging.INFO): + def __init__(self, log_level=logging.DEBUG): self.log_level = log_level self._alive = True @@ -735,13 +685,11 @@ class Broker(object): self._thread.start() def CreateListener(self, address=None, backlog=30): - """ - Listen on `address `for connections from newly spawned contexts. - """ + """Listen on `address `for connections from newly spawned contexts.""" self._listener = Listener(self, address, backlog) def _UpdateStream(self, stream): - LOG.debug('_UpdateStream(%r)', stream) + IOLOG.debug('_UpdateStream(%r)', stream) self._lock.acquire() try: if stream.ReadMore() and stream.read_side.fileno(): @@ -757,38 +705,28 @@ class Broker(object): self._lock.release() def UpdateStream(self, stream): - LOG.debug('UpdateStream(%r)', stream) self._UpdateStream(stream) if self._waker: self._waker.Wake() - def AddStream(self, stream): - self.UpdateStream(stream) - def Register(self, context): - """ - Put a context under control of this broker. - """ + """Put a context under control of this broker.""" LOG.debug('%r.Register(%r) -> r=%r w=%r', self, context, context.GetStream().read_side, context.GetStream().write_side) - self.AddStream(context.GetStream()) + self.UpdateStream(context.GetStream()) self._contexts[context.name] = context return context - def GetLocal(self, name='econtext-local'): - """ - Get the named context running on the local machine, creating it if it - does not exist. - """ + def GetLocal(self, name='default'): + """Get the named context running on the local machine, creating it if + it does not exist.""" context = Context(self, name) context.SetStream(LocalStream(context)).Connect() return self.Register(context) def GetRemote(self, hostname, username, name=None, python_path=None): - """ - Get the named remote context, creating it if it does not exist. - """ + """Get the named remote context, creating it if it does not exist.""" if name is None: name = 'econtext[%s@%s:%d]' %\ (username, socket.gethostname(), os.getpid()) @@ -810,24 +748,20 @@ class Broker(object): self._UpdateStream(stream) def _LoopOnce(self): - LOG.debug('%r.Loop()', self) - #LOG.debug('readers = %r', self._readers) - #LOG.debug('rfds = %r', [r.fileno() for r in self._readers]) - #LOG.debug('writers = %r', self._writers) - #LOG.debug('wfds = %r', [w.fileno() for w in self._writers]) + IOLOG.debug('%r.Loop()', self) + IOLOG.debug('readers = %r', [(r.fileno(), r) for r in self._readers]) + IOLOG.debug('writers = %r', [(w.fileno(), w) for w in self._writers]) rsides, wsides, _ = select.select(self._readers, self._writers, ()) for side in rsides: - LOG.debug('%r: POLLIN for %r', self, side.stream) + IOLOG.debug('%r: POLLIN for %r', self, side.stream) self._CallAndUpdate(side.stream, side.stream.Receive) for side in wsides: - LOG.debug('%r: POLLOUT for %r', self, side.stream) + IOLOG.debug('%r: POLLOUT for %r', self, side.stream) self._CallAndUpdate(side.stream, side.stream.Transmit) def _BrokerMain(self): - """ - Handle stream events until Finalize() is called. - """ + """Handle events until Finalize() is called.""" try: while self._alive: self._LoopOnce() @@ -840,18 +774,15 @@ class Broker(object): LOG.exception('Loop() crashed') def Wait(self): - """ - Wait for the broker to stop. - """ + """Wait for the broker to stop.""" self._thread.join() - def Finalize(self): - """ - Tell all active streams to disconnect. - """ + def Finalize(self, wait=True): + """Disconect all streams and wait for broker to stop.""" self._alive = False self._waker.Wake() - self.Wait() + if wait: + self.Wait() def __repr__(self): return 'Broker()' @@ -868,10 +799,6 @@ class ExternalContext(object): if hasattr(klass, '__module__'): klass.__module__ = 'econtext.core' - def _SetupLogging(self, log_level): - logging.basicConfig(level=log_level, filename='slave.txt') - logging.getLogger('').handlers[0].formatter = Formatter(False) - def _ReapFirstStage(self): os.wait() os.dup2(100, 0) @@ -879,11 +806,20 @@ class ExternalContext(object): def _SetupMaster(self, key): self.broker = Broker() - self.context = Context(self.broker, 'parent', key=key) + self.context = ParentContext(self.broker, 'parent', key=key) self.channel = Channel(self.context, CALL_FUNCTION) self.stream = Stream(self.context) self.stream.Accept(0, 1) + def _SetupLogging(self, log_level): + logging.basicConfig(level=log_level, stream=open('slave', 'w', 1)) + return + logging.basicConfig(level=log_level) + root = logging.getLogger() + root.setLevel(log_level) + root.handlers = [LogHandler(self.context)] + LOG.info('Connected to %s', self.context) + def _SetupImporter(self): self.importer = SlaveModuleImporter(self.context) sys.meta_path.append(self.importer) @@ -893,9 +829,17 @@ class ExternalContext(object): self.stderr_log = IoLogger(self.broker, 'stderr') os.dup2(self.stdout_log.write_side.fd, 1) os.dup2(self.stderr_log.write_side.fd, 2) - os.close(0) + + fp = file('/dev/null') + try: + os.dup2(fp.fileno(), 0) + finally: + fp.close() def _DispatchCalls(self): + #signal.alarm(10) + signal.signal(signal.SIGINT, lambda *_: self.broker.Finalize()) + for data in self.channel: LOG.debug('_DispatchCalls(%r)', data) reply_to, with_context, modname, klass, func, args, kwargs = data @@ -912,24 +856,16 @@ class ExternalContext(object): self.context.Enqueue(reply_to, CallError(e)) def main(self, context_name, key, log_level): - self._FixupMainModule() - self._SetupLogging(log_level) - - syslog.openlog('%s:%s' % (getpass.getuser(), context_name), syslog.LOG_PID) - syslog.syslog('initializing (parent=%s)' % (os.getenv('SSH_CLIENT'),)) - LOG.debug('ExternalContext.main(%r, %r)', context_name, key) - + import stack self._ReapFirstStage() + self._FixupMainModule() self._SetupMaster(key) + self._SetupLogging(log_level) self._SetupImporter() - #self._SetupStdio() - os.dup2(2, 1) - if 0: - fd = open('/dev/null', 'w') - os.dup2(fd.fileno(), 1) - os.dup2(fd.fileno(), 2) + self._SetupStdio() self.broker.Register(self.context) + LOG.info('entering dispatchcalls') self._DispatchCalls() self.broker.Wait() LOG.debug('ExternalContext.main() exitting')