Fixes/improvements

Get rid of syslog.
Get rid of section comments.
Move IOLOG to separate logger to avoid infinite loop.
Change function docstring style.
Delete log Formatter.
Implement LogHandler to forward logs to parent.
Delete Pickle(), simplify Unpickle().
Have slave Finalize() when parent disconnects.
Delete AddStream().
pull/35/head
David Wilson 8 years ago
parent 9e2b1d24be
commit 0b0e828c04

@ -21,7 +21,6 @@ import signal
import socket
import struct
import sys
import syslog
import textwrap
import threading
import traceback
@ -29,51 +28,45 @@ import types
import zlib
#
# Module-level data.
#
LOG = logging.getLogger('econtext')
IOLOG = logging.getLogger('econtext.io')
RLOG = logging.getLogger('econtext.ctx')
GET_MODULE = 100L
CALL_FUNCTION = 101L
FORWARD_LOG = 102L
#
# Exceptions.
#
class ContextError(Exception):
'Raised when a problem occurs with a context.'
"""Raised when a problem occurs with a context."""
def __init__(self, fmt, *args):
Exception.__init__(self, fmt % args)
class ChannelError(ContextError):
'Raised when a channel dies or has been closed.'
"""Raised when a channel dies or has been closed."""
class StreamError(ContextError):
'Raised when a stream cannot be established.'
"""Raised when a stream cannot be established."""
class CorruptMessageError(StreamError):
'Raised when a corrupt message is received on a stream.'
"""Raised when a corrupt message is received on a stream."""
class TimeoutError(StreamError):
'Raised when a timeout occurs on a stream.'
"""Raised when a timeout occurs on a stream."""
class CallError(ContextError):
'Raised when .Call() fails'
class CallError(ContextError):
"""Raised when .Call() fails"""
def __init__(self, e):
name = '%s.%s' % (type(e).__module__, type(e).__name__)
stack = ''.join(traceback.format_stack(sys.exc_info[2]))
ContextError.__init__(self, 'Call failed: %s: %s\n%s', name, e, stack)
#
# Helpers.
#
class Dead(object):
def __eq__(self, other):
return type(other) is Dead
@ -94,10 +87,8 @@ def write_all(fd, s):
def CreateChild(*args):
"""
Create a child process whose stdin/stdout is connected to a socket,
returning `(pid, socket_obj)`.
"""
"""Create a child process whose stdin/stdout is connected to a socket,
returning `(pid, socket_obj)`."""
parentfp, childfp = socket.socketpair()
pid = os.fork()
if not pid:
@ -114,25 +105,6 @@ def CreateChild(*args):
return pid, parentfp
class Formatter(logging.Formatter):
FMT = '%(asctime)s %(levelname).1s %(name)s: %(message)s'
DATEFMT = '%H:%M:%S'
def __init__(self, parent):
self.parent = parent
super(Formatter, self).__init__(self.FMT, self.DATEFMT)
def format(self, record):
s = super(Formatter, self).format(record)
if 1:
p = ''
elif self.parent:
p = '\x1b[32m'
else:
p = '\x1b[36m'
return p + ('{%s} %s' % (os.getpid(), s))
class Channel(object):
def __init__(self, context, handle):
self._context = context
@ -141,47 +113,37 @@ class Channel(object):
self._context.AddHandleCB(self._Receive, handle)
def _Receive(self, data):
"""
Callback from the Stream; appends data to the internal queue.
"""
LOG.debug('%r._Receive(%r)', self, data)
"""Callback from the Stream; appends data to the internal queue."""
IOLOG.debug('%r._Receive(%r)', self, data)
self._queue.put(data)
def Close(self):
"""
Indicate this channel is closed to the remote side.
"""
LOG.debug('%r.Close()', self)
"""Indicate this channel is closed to the remote side."""
IOLOG.debug('%r.Close()', self)
self._context.Enqueue(handle, _DEAD)
def Send(self, data):
"""
Send `data` to the remote.
"""
LOG.debug('%r.Send(%r)', self, data)
"""Send `data` to the remote."""
IOLOG.debug('%r.Send(%r)', self, data)
self._context.Enqueue(handle, data)
def Receive(self, timeout=None):
"""
Receive an object from the remote, or return ``None`` if `timeout` is
reached.
"""
LOG.debug('%r.Receive(timeout=%r)', self, timeout)
"""Receive an object from the remote, or return ``None`` if `timeout`
is reached."""
IOLOG.debug('%r.Receive(timeout=%r)', self, timeout)
try:
data = self._queue.get(True, timeout)
except Queue.Empty:
return
LOG.debug('%r.Receive() got %r', self, data)
IOLOG.debug('%r.Receive() got %r', self, data)
if data == _DEAD:
raise ChannelError('Channel is closed.')
return data
def __iter__(self):
"""
Return an iterator that yields objects arriving on this channel, until
the channel dies or is closed.
"""
"""Iterate objects arriving on this channel, until the channel dies or
is closed."""
while True:
try:
yield self.Receive()
@ -244,21 +206,37 @@ class MasterModuleResponder(object):
self._context.Enqueue(reply_to, None)
class LogHandler(logging.Handler):
def __init__(self, context):
logging.Handler.__init__(self)
self.context = context
self.local = threading.local()
def emit(self, rec):
if rec.name == 'econtext.io' or \
getattr(self.local, 'in_commit', False):
return
self.local.in_commit = True
try:
msg = self.format(rec)
self.context.Enqueue(FORWARD_LOG, (rec.name, rec.levelno, msg))
finally:
self.local.in_commit = False
class LogForwarder(object):
def __init__(self, context):
self._context = context
self._context.AddHandleCB(self.ForwardLog, handle=FORWARD_LOG)
self._log = RLOG.getChild(self._context.name)
def ForwardLog(self, data):
if data == _DEAD:
return
LOG.debug('%r: %s', self._context, data)
#
# Stream implementations.
#
name, level, s = data
self._log.log(level, '%s: %s', name, s)
class Side(object):
@ -300,43 +278,24 @@ class Stream(BasicStream):
def __init__(self, context):
self._context = context
self._lock = threading.Lock()
self._rhmac = hmac.new(context.key, digestmod=sha.new)
self._whmac = self._rhmac.copy()
self._pickler_file = cStringIO.StringIO()
self._pickler = cPickle.Pickler(self._pickler_file, protocol=2)
self._unpickler_file = cStringIO.StringIO()
self._unpickler = cPickle.Unpickler(self._unpickler_file)
def Pickle(self, obj):
"""
Serialize `obj` into a bytestring.
"""
self._pickler_file.truncate(0)
self._pickler_file.seek(0)
self._pickler.clear_memo()
self._pickler.dump(obj)
return self._pickler_file.getvalue()
_FindGlobal = None
def Unpickle(self, data):
"""
Deserialize `data` into an object.
"""
LOG.debug('%r.Unpickle(%r)', self, data)
self._unpickler_file.truncate(0)
self._unpickler_file.seek(0)
self._unpickler_file.write(data)
self._unpickler_file.seek(0)
return self._unpickler.load()
"""Deserialize `data` into an object."""
IOLOG.debug('%r.Unpickle(%r)', self, data)
fp = cStringIO.StringIO(data)
unpickler = cPickle.Unpickler(fp)
if self._FindGlobal:
unpickler.find_global = self._FindGlobal
return unpickler.load()
def Receive(self):
"""
Handle the next complete message on the stream. Raise
CorruptMessageError or IOError on failure.
"""
LOG.debug('%r.Receive()', self)
"""Handle the next complete message on the stream. Raise
CorruptMessageError or IOError on failure."""
IOLOG.debug('%r.Receive()', self)
buf = os.read(self.read_side.fd, 4096)
if not buf:
@ -349,7 +308,7 @@ class Stream(BasicStream):
msg_mac = self._input_buf[:20]
msg_len = struct.unpack('>L', self._input_buf[20:24])[0]
if len(self._input_buf) < msg_len-24:
LOG.debug('Input too short')
IOLOG.debug('Input too short')
return
self._rhmac.update(self._input_buf[20:msg_len+24])
@ -379,10 +338,8 @@ class Stream(BasicStream):
fn(data)
def Transmit(self):
"""
Transmit buffered messages.
"""
LOG.debug('%r.Transmit()', self)
"""Transmit buffered messages."""
IOLOG.debug('%r.Transmit()', self)
written = os.write(self.write_side.fd, self._output_buf[:4096])
self._output_buf = self._output_buf[written:]
@ -390,13 +347,11 @@ class Stream(BasicStream):
return bool(self._output_buf)
def Enqueue(self, handle, obj):
"""
Enqueue `obj` to `handle`, and tell the broker we have output.
"""
LOG.debug('%r.Enqueue(%r, %r)', self, handle, obj)
"""Enqueue `obj` to `handle`, and tell the broker we have output."""
IOLOG.debug('%r.Enqueue(%r, %r)', self, handle, obj)
self._lock.acquire()
try:
encoded = self.Pickle((handle, obj))
encoded = cPickle.dumps((handle, obj), protocol=2)
msg = struct.pack('>L', len(encoded)) + encoded
self._whmac.update(msg)
self._output_buf += self._whmac.digest() + msg
@ -405,10 +360,8 @@ class Stream(BasicStream):
self._context.broker.UpdateStream(self)
def Disconnect(self):
"""
Close our associated file descriptor and tell registered callbacks the
connection has been destroyed.
"""
"""Close our associated file descriptor and tell registered callbacks
the connection has been destroyed."""
LOG.debug('%r.Disconnect()', self)
if self._context.GetStream() is self:
self._context.SetStream(None)
@ -438,9 +391,8 @@ class Stream(BasicStream):
self._context.SetStream(self)
def Connect(self):
"""
Connect to a Broker at the address specified in our associated Context.
"""
"""Connect to a Broker at the address specified in our associated
Context."""
LOG.debug('%r.Connect()', self)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.read_side = Side(self, sock.fileno())
@ -462,27 +414,22 @@ class LocalStream(Stream):
def __init__(self, context):
super(LocalStream, self).__init__(context)
self._permitted_classes = set([('econtext.core', 'CallError')])
self._unpickler.find_global = self._FindGlobal
def _FindGlobal(self, module_name, class_name):
"""
Return the class implementing `module_name.class_name` or raise
`StreamError` if the module is not whitelisted.
"""
"""Return the class implementing `module_name.class_name` or raise
`StreamError` if the module is not whitelisted."""
if (module_name, class_name) not in self._permitted_classes:
raise StreamError('context %r attempted to unpickle %r in module %r',
raise StreamError('%r attempted to unpickle %r in module %r',
self._context, class_name, module_name)
return getattr(sys.modules[module_name], class_name)
def AllowClass(self, module_name, class_name):
"""
Add `module_name` to the list of permitted modules.
"""
"""Add `module_name` to the list of permitted modules."""
self._permitted_modules.add((module_name, class_name))
# Hexed and passed to 'python -c'. It forks, dups 0->100, creates a pipe,
# then execs a new interpreter with a custom argv. CONTEXT_NAME is replaced
# with the context name. Optimized for size.
# base64'd and passed to 'python -c'. It forks, dups 0->100, creates a
# pipe, then execs a new interpreter with a custom argv. CONTEXT_NAME is
# replaced with the context name. Optimized for size.
def _FirstStage():
import os,sys,zlib
R,W=os.pipe()
@ -491,7 +438,7 @@ class LocalStream(Stream):
os.dup2(R,0)
os.close(R)
os.close(W)
os.execv(sys.executable,(CONTEXT_NAME,))
os.execv(sys.executable,('econtext:'+CONTEXT_NAME,))
else:
os.fdopen(W,'wb',0).write(zlib.decompress(sys.stdin.read(input())))
print 'OK'
@ -524,8 +471,8 @@ class LocalStream(Stream):
self._context.key,
self._context.broker.log_level,
)
compressed = zlib.compress(source)
compressed = zlib.compress(source)
preamble = str(len(compressed)) + '\n' + compressed
write_all(self.write_side.fd, preamble)
assert os.read(self.read_side.fd, 3) == 'OK\n'
@ -546,11 +493,10 @@ class SSHStream(LocalStream):
class Context(object):
"""
Represents a remote context regardless of connection method.
Represent a remote context regardless of connection method.
"""
def __init__(self, broker, name=None, hostname=None, username=None, key=None,
parent_addr=None):
def __init__(self, broker, name=None, hostname=None, username=None,
key=None, parent_addr=None):
self.broker = broker
self.name = name
self.hostname = hostname
@ -573,9 +519,7 @@ class Context(object):
return stream
def AllocHandle(self):
"""
Allocate a handle.
"""
"""Allocate a handle."""
self._lock.acquire()
try:
self._last_handle += 1L
@ -584,23 +528,19 @@ class Context(object):
self._lock.release()
def AddHandleCB(self, fn, handle, persist=True):
"""
Register `fn(obj)` to run for each `obj` sent to `handle`. If `persist`
is ``False`` then unregister after one delivery.
"""
LOG.debug('%r.AddHandleCB(%r, %r, persist=%r)',
self, fn, handle, persist)
"""Register `fn(obj)` to run for each `obj` sent to `handle`. If
`persist` is ``False`` then unregister after one delivery."""
IOLOG.debug('%r.AddHandleCB(%r, %r, persist=%r)',
self, fn, handle, persist)
self._handle_map[handle] = persist, fn
def Enqueue(self, handle, obj):
self._stream.Enqueue(handle, obj)
def EnqueueAwaitReply(self, handle, deadline, data):
"""
Send `data` to `handle` and wait for a response with an optional
"""Send `data` to `handle` and wait for a response with an optional
timeout. The message contains `(reply_to, data)`, where `reply_to` is
the handle on which this function expects its reply.
"""
the handle on which this function expects its reply."""
reply_to = self.AllocHandle()
LOG.debug('%r.EnqueueAwaitReply(%r, %r, %r) -> reply handle %d',
self, handle, deadline, data, reply_to)
@ -650,13 +590,21 @@ class Context(object):
return 'Context(%s)' % ', '.join(bits)
class ParentContext(Context):
def SetStream(self, stream):
super(ParentContext, self).SetStream(stream)
if stream is None:
LOG.debug('Parent stream is gone, dying.')
self.broker.Finalize(wait=False)
class Waker(BasicStream):
def __init__(self, broker):
self._broker = broker
rfd, wfd = os.pipe()
self.read_side = Side(self, rfd)
self.write_side = Side(self, wfd)
broker.AddStream(self)
broker.UpdateStream(self)
def __repr__(self):
return '<Waker>'
@ -665,7 +613,6 @@ class Waker(BasicStream):
os.write(self.write_side.fd, ' ')
def Receive(self):
LOG.debug('%r: waking %r', self, self._broker)
os.read(self.read_side.fd, 1)
@ -677,7 +624,7 @@ class Listener(BasicStream):
self._sock.listen(backlog)
self._listen_addr = self._sock.getsockname()
self.read_side = Side(self, self._sock.fileno())
broker.AddStream(self)
broker.UpdateStream(self)
def Receive(self):
sock, addr = self._sock.accept()
@ -691,10 +638,12 @@ class IoLogger(BasicStream):
def __init__(self, broker, name):
self._broker = broker
self._name = name
self._log = logging.getLogger(name)
rfd, wfd = os.pipe()
self.read_side = Side(self, rfd)
self.write_side = Side(self, wfd)
self._broker.AddStream(self)
self._broker.UpdateStream(self)
def __repr__(self):
return '<IoLogger %s fd %d>' % (self._name, self.read_side.fd)
@ -702,9 +651,10 @@ class IoLogger(BasicStream):
def _LogLines(self):
while self._buf.find('\n') != -1:
line, _, self._buf = self._buf.partition('\n')
LOG.debug('%s: %s', self._name, line.rstrip('\n'))
self._log.debug('%s: %s', self._name, line.rstrip('\n'))
def Receive(self):
LOG.debug('%r.Receive()', self)
buf = os.read(self.read_side.fd, 4096)
if not buf:
return self.Disconnect()
@ -720,7 +670,7 @@ class Broker(object):
"""
_waker = None
def __init__(self, log_level=logging.INFO):
def __init__(self, log_level=logging.DEBUG):
self.log_level = log_level
self._alive = True
@ -735,13 +685,11 @@ class Broker(object):
self._thread.start()
def CreateListener(self, address=None, backlog=30):
"""
Listen on `address `for connections from newly spawned contexts.
"""
"""Listen on `address `for connections from newly spawned contexts."""
self._listener = Listener(self, address, backlog)
def _UpdateStream(self, stream):
LOG.debug('_UpdateStream(%r)', stream)
IOLOG.debug('_UpdateStream(%r)', stream)
self._lock.acquire()
try:
if stream.ReadMore() and stream.read_side.fileno():
@ -757,38 +705,28 @@ class Broker(object):
self._lock.release()
def UpdateStream(self, stream):
LOG.debug('UpdateStream(%r)', stream)
self._UpdateStream(stream)
if self._waker:
self._waker.Wake()
def AddStream(self, stream):
self.UpdateStream(stream)
def Register(self, context):
"""
Put a context under control of this broker.
"""
"""Put a context under control of this broker."""
LOG.debug('%r.Register(%r) -> r=%r w=%r', self, context,
context.GetStream().read_side,
context.GetStream().write_side)
self.AddStream(context.GetStream())
self.UpdateStream(context.GetStream())
self._contexts[context.name] = context
return context
def GetLocal(self, name='econtext-local'):
"""
Get the named context running on the local machine, creating it if it
does not exist.
"""
def GetLocal(self, name='default'):
"""Get the named context running on the local machine, creating it if
it does not exist."""
context = Context(self, name)
context.SetStream(LocalStream(context)).Connect()
return self.Register(context)
def GetRemote(self, hostname, username, name=None, python_path=None):
"""
Get the named remote context, creating it if it does not exist.
"""
"""Get the named remote context, creating it if it does not exist."""
if name is None:
name = 'econtext[%s@%s:%d]' %\
(username, socket.gethostname(), os.getpid())
@ -810,24 +748,20 @@ class Broker(object):
self._UpdateStream(stream)
def _LoopOnce(self):
LOG.debug('%r.Loop()', self)
#LOG.debug('readers = %r', self._readers)
#LOG.debug('rfds = %r', [r.fileno() for r in self._readers])
#LOG.debug('writers = %r', self._writers)
#LOG.debug('wfds = %r', [w.fileno() for w in self._writers])
IOLOG.debug('%r.Loop()', self)
IOLOG.debug('readers = %r', [(r.fileno(), r) for r in self._readers])
IOLOG.debug('writers = %r', [(w.fileno(), w) for w in self._writers])
rsides, wsides, _ = select.select(self._readers, self._writers, ())
for side in rsides:
LOG.debug('%r: POLLIN for %r', self, side.stream)
IOLOG.debug('%r: POLLIN for %r', self, side.stream)
self._CallAndUpdate(side.stream, side.stream.Receive)
for side in wsides:
LOG.debug('%r: POLLOUT for %r', self, side.stream)
IOLOG.debug('%r: POLLOUT for %r', self, side.stream)
self._CallAndUpdate(side.stream, side.stream.Transmit)
def _BrokerMain(self):
"""
Handle stream events until Finalize() is called.
"""
"""Handle events until Finalize() is called."""
try:
while self._alive:
self._LoopOnce()
@ -840,18 +774,15 @@ class Broker(object):
LOG.exception('Loop() crashed')
def Wait(self):
"""
Wait for the broker to stop.
"""
"""Wait for the broker to stop."""
self._thread.join()
def Finalize(self):
"""
Tell all active streams to disconnect.
"""
def Finalize(self, wait=True):
"""Disconect all streams and wait for broker to stop."""
self._alive = False
self._waker.Wake()
self.Wait()
if wait:
self.Wait()
def __repr__(self):
return 'Broker()'
@ -868,10 +799,6 @@ class ExternalContext(object):
if hasattr(klass, '__module__'):
klass.__module__ = 'econtext.core'
def _SetupLogging(self, log_level):
logging.basicConfig(level=log_level, filename='slave.txt')
logging.getLogger('').handlers[0].formatter = Formatter(False)
def _ReapFirstStage(self):
os.wait()
os.dup2(100, 0)
@ -879,11 +806,20 @@ class ExternalContext(object):
def _SetupMaster(self, key):
self.broker = Broker()
self.context = Context(self.broker, 'parent', key=key)
self.context = ParentContext(self.broker, 'parent', key=key)
self.channel = Channel(self.context, CALL_FUNCTION)
self.stream = Stream(self.context)
self.stream.Accept(0, 1)
def _SetupLogging(self, log_level):
logging.basicConfig(level=log_level, stream=open('slave', 'w', 1))
return
logging.basicConfig(level=log_level)
root = logging.getLogger()
root.setLevel(log_level)
root.handlers = [LogHandler(self.context)]
LOG.info('Connected to %s', self.context)
def _SetupImporter(self):
self.importer = SlaveModuleImporter(self.context)
sys.meta_path.append(self.importer)
@ -893,9 +829,17 @@ class ExternalContext(object):
self.stderr_log = IoLogger(self.broker, 'stderr')
os.dup2(self.stdout_log.write_side.fd, 1)
os.dup2(self.stderr_log.write_side.fd, 2)
os.close(0)
fp = file('/dev/null')
try:
os.dup2(fp.fileno(), 0)
finally:
fp.close()
def _DispatchCalls(self):
#signal.alarm(10)
signal.signal(signal.SIGINT, lambda *_: self.broker.Finalize())
for data in self.channel:
LOG.debug('_DispatchCalls(%r)', data)
reply_to, with_context, modname, klass, func, args, kwargs = data
@ -912,24 +856,16 @@ class ExternalContext(object):
self.context.Enqueue(reply_to, CallError(e))
def main(self, context_name, key, log_level):
self._FixupMainModule()
self._SetupLogging(log_level)
syslog.openlog('%s:%s' % (getpass.getuser(), context_name), syslog.LOG_PID)
syslog.syslog('initializing (parent=%s)' % (os.getenv('SSH_CLIENT'),))
LOG.debug('ExternalContext.main(%r, %r)', context_name, key)
import stack
self._ReapFirstStage()
self._FixupMainModule()
self._SetupMaster(key)
self._SetupLogging(log_level)
self._SetupImporter()
#self._SetupStdio()
os.dup2(2, 1)
if 0:
fd = open('/dev/null', 'w')
os.dup2(fd.fileno(), 1)
os.dup2(fd.fileno(), 2)
self._SetupStdio()
self.broker.Register(self.context)
LOG.info('entering dispatchcalls')
self._DispatchCalls()
self.broker.Wait()
LOG.debug('ExternalContext.main() exitting')

Loading…
Cancel
Save