Working recursive module import OS X

* Replace Log() with logging package.
* Better namespace handle ranges.
* Replace poll() with select().
* Implement module loader.
* Pickle protocol 2.
* Abstract BasicStream and generalize Loop()
* base64 rather than hex.
* Reuse stdio don't connect back (yet).
* Get rid of use_channel for now.
* Fix input_buffer arithmetic.
* Extraneous os.wait()?
* Move demos to subdir child can't access.
pull/35/head
David Wilson 8 years ago
parent 4666cbb435
commit 1a30570057

@ -12,6 +12,7 @@ import getpass
import hmac
import imp
import inspect
import logging
import os
import random
import select
@ -32,13 +33,11 @@ import zlib
# Module-level data.
#
GET_MODULE = 0L
CALL_FUNCTION = 1L
LOG = logging.getLogger('econtext')
DEBUG = True
GET_MODULE = 100L
CALL_FUNCTION = 101L
import sys
sys.stderr = open('milf1', 'w', 1)
#
# Exceptions.
@ -66,12 +65,6 @@ class TimeoutError(StreamError):
# Helpers.
#
def Log(fmt, *args):
if DEBUG:
sys.stderr.write('%d (%d): %s\n' % (os.getpid(), os.getppid(),
(fmt % args).replace('econtext.', '')))
def write_all(fd, s):
written = 0
while written < len(s):
@ -104,11 +97,30 @@ def CreateChild(*args):
raise SystemExit
childfp.close()
Log('CreateChild() child %d fd %d, parent %d, args %r',
pid, parentfp.fileno(), os.getpid(), args)
LOG.debug('CreateChild() child %d fd %d, parent %d, args %r',
pid, parentfp.fileno(), os.getpid(), args)
return pid, parentfp
class Formatter(logging.Formatter):
FMT = '%(asctime)s %(levelname).1s %(name)s: %(message)s'
DATEFMT = '%H:%M:%S'
def __init__(self, parent):
self.parent = parent
super(Formatter, self).__init__(self.FMT, self.DATEFMT)
def format(self, record):
s = super(Formatter, self).format(record)
if 1:
p = ''
elif self.parent:
p = '\x1b[32m'
else:
p = '\x1b[36m'
return p + ('{%s} %s' % (os.getpid(), s))
class PartialFunction(object):
'''
Partial function implementation.
@ -149,7 +161,7 @@ class Channel(object):
object
)
'''
Log('%r._InternalReceive(%r, %r)', self, killed, data)
LOG.debug('%r._InternalReceive(%r, %r)', self, killed, data)
self._queue_lock.acquire()
try:
self._queue.append((killed or data[0], killed or data[1]))
@ -161,14 +173,14 @@ class Channel(object):
'''
Indicate this channel is closed to the remote side.
'''
Log('%r.Close()', self)
LOG.debug('%r.Close()', self)
self._stream.Enqueue(handle, (True, None))
def Send(self, data):
'''
Send the given object to the remote side.
'''
Log('%r.Send(%r)', self, data)
LOG.debug('%r.Send(%r)', self, data)
self._stream.Enqueue(handle, (False, data))
def Receive(self, timeout=None):
@ -182,7 +194,7 @@ class Channel(object):
Returns:
object
'''
Log('%r.Receive(%r)', self, timeout)
LOG.debug('%r.Receive(%r)', self, timeout)
if not self._queue:
self._wake_event.wait(timeout)
if not self._wake_event.isSet():
@ -191,10 +203,10 @@ class Channel(object):
self._queue_lock.acquire()
try:
self._wake_event.clear()
Log('%r.Receive() queue is %r', self, self._queue)
closed, data = self._queue.pop(0)
Log('%r.Receive() got closed=%r, data=%r', self, closed, data)
if closed:
LOG.debug('%r.Receive() queue is %r', self, self._queue)
killed, data = self._queue.pop(0)
LOG.debug('%r.Receive() got killed=%r, data=%r', self, killed, data)
if killed:
raise ChannelError('Channel is closed.')
return data
finally:
@ -231,26 +243,64 @@ class SlaveModuleImporter(object):
self._context = context
def find_module(self, fullname, path=None):
if not imp.find_module(fullname):
LOG.debug('SlaveModuleImporter.find_module(%r)', fullname)
try:
imp.find_module(fullname)
except ImportError:
LOG.debug('find_module(%r) returning self', fullname)
return self
def load_module(self, fullname):
kind, data = self._context.EnqueueAwaitReply(GET_MODULE, fullname)
LOG.debug('SlaveModuleImporter.load_module(%r)', fullname)
ret = self._context.EnqueueAwaitReply(GET_MODULE, None, (fullname,))
if ret is None:
raise ImportError('Master does not have %r' % (fullname,))
kind, path, data = ret
code = compile(zlib.decompress(data), path, 'exec')
module = imp.new_module(fullname)
sys.modules[fullname] = module
eval(code, vars(module), vars(module))
return module
def GetModule(cls, killed, fullname):
Log('%r.GetModule(%r, %r)', cls, killed, fullname)
class MasterModuleResponder(object):
def __init__(self, stream):
self._stream = stream
def GetModule(self, killed, (_, (reply_handle, fullname))):
LOG.debug('SlaveModuleImporter.GetModule(%r, %r)', killed, fullname)
if killed:
return
if fullname in sys.modules:
pass
GetModule = classmethod(GetModule)
mod = sys.modules.get(fullname)
if mod:
source = zlib.compress(inspect.getsource(mod))
path = os.path.abspath(mod.__file__)
self._stream.Enqueue(reply_handle, ('source', path, source))
#
# Stream implementations.
#
class Stream(object):
class BasicStream(object):
def fileno(self):
return self._fd
def Disconnect(self):
LOG.debug('%r: disconnect on %r fd %d', self._broker, self, self._fd)
self._broker.RemoveStream(self)
def ReadMore(self):
return True
def WriteMore(self):
return False
class Stream(BasicStream):
def __init__(self, context):
'''
Initialize a new Stream instance.
@ -266,7 +316,7 @@ class Stream(object):
self._rhmac = hmac.new(context.key, digestmod=sha.new)
self._whmac = self._rhmac.copy()
self._last_handle = 0
self._last_handle = 1000L
self._handle_map = {}
self._handle_lock = threading.Lock()
@ -274,7 +324,7 @@ class Stream(object):
self._func_ref_lock = threading.Lock()
self._pickler_file = cStringIO.StringIO()
self._pickler = cPickle.Pickler(self._pickler_file)
self._pickler = cPickle.Pickler(self._pickler_file, protocol=2)
self._pickler.persistent_id = self._CheckFunctionPerID
self._unpickler_file = cStringIO.StringIO()
@ -307,7 +357,7 @@ class Stream(object):
Returns:
object
'''
Log('%r.Unpickle(%r)', self, data)
LOG.debug('%r.Unpickle(%r)', self, data)
self._unpickler_file.write(data)
self._unpickler_file.seek(0)
data = self._unpickler.load()
@ -369,7 +419,8 @@ class Stream(object):
handle: long
persist: False to only receive a single message.
'''
Log('%r.AddHandleCB(%r, %r, persist=%r)', self, fn, handle, persist)
LOG.debug('%r.AddHandleCB(%r, %r, persist=%r)',
self, fn, handle, persist)
self._handle_lock.acquire()
try:
self._handle_map[handle] = persist, fn
@ -381,30 +432,36 @@ class Stream(object):
Handle the next complete message on the stream. Raise
CorruptMessageError or IOError on failure.
'''
Log('%r.Receive()', self)
LOG.debug('%r.Receive()', self)
buf = os.read(self._fd, 4096)
if not buf:
return self.Disconnect()
self._input_buf += os.read(self._fd, 4096)
self._input_buf += buf
if len(self._input_buf) < 24:
return
msg_mac = self._input_buf[:20]
msg_len = struct.unpack('>L', self._input_buf[20:24])[0]
if len(self._input_buf) < msg_len-24:
LOG.debug('Input too short')
return
self._rhmac.update(self._input_buf[20:msg_len+24])
expected_mac = self._rhmac.digest()
if msg_mac != expected_mac:
raise CorruptMessageError('%r got invalid MAC: expected %r, got %r',
self, msg_mac.encode('hex'),
expected_mac.encode('hex'))
self, msg_mac.encode('hex'),
expected_mac.encode('hex'))
try:
handle, data = self.Unpickle(self._input_buf[24:msg_len+24])
self._input_buf = self._input_buf[msg_len+4:]
self._input_buf = self._input_buf[msg_len+24:]
handle = long(handle)
Log('%r.Receive(): decoded handle=%r; data=%r', self, handle, data)
LOG.debug('%r.Receive(): decoded handle=%r; data=%r',
self, handle, data)
persist, fn = self._handle_map[handle]
if not persist:
del self._handle_map[handle]
@ -413,6 +470,7 @@ class Stream(object):
except (TypeError, ValueError), ex:
raise CorruptMessageError('%r got invalid message: %s', self, ex)
LOG.debug('Calling %r (%r, %r)', fn, False, data)
fn(False, data)
def Transmit(self):
@ -425,9 +483,11 @@ class Stream(object):
Raises:
IOError
'''
Log('%r.Transmit()', self)
LOG.debug('%r.Transmit()', self)
written = os.write(self._fd, self._output_buf[:4096])
self._output_buf = self._output_buf[written:]
def WriteMore(self):
return bool(self._output_buf)
def Enqueue(self, handle, obj):
@ -439,7 +499,7 @@ class Stream(object):
handle: long
obj: object
'''
Log('%r.Enqueue(%r, %r)', self, handle, obj)
LOG.debug('%r.Enqueue(%r, %r)', self, handle, obj)
self._output_buf_lock.acquire()
try:
@ -449,49 +509,51 @@ class Stream(object):
self._output_buf += self._whmac.digest() + msg
finally:
self._output_buf_lock.release()
self._context.broker.Register(self._context)
self._context.broker.UpdateStream(self, wake=True)
def Disconnect(self):
'''
Close our associated file descriptor and tell any registered callbacks
that the connection has been destroyed.
'''
Log('%r.Disconnect()', self)
LOG.debug('%r.Disconnect()', self)
try:
os.close(self._fd)
except OSError, e:
Log('%r.Disconnect(): did not close fd %s: %s', self, self._fd, e)
LOG.debug('%r.Disconnect(): did not close fd %s: %s',
self, self._fd, e)
self._fd = None
if self._context.GetStream() is self:
self._context.SetStream(None)
for handle, (persist, fn) in self._handle_map.iteritems():
Log('%r.Disconnect(): killing stale callback handle=%r; fn=%r',
self, handle, fn)
LOG.debug('%r.Disconnect(): stale callback handle=%r; fn=%r',
self, handle, fn)
fn(True, None)
@classmethod
def Accept(cls, context, sock):
def Accept(cls, context, fd):
'''
'''
stream = cls(context)
stream.sock = sock
stream._fd = sock.fileno()
stream._fd = os.dup(fd)
context.SetStream(stream)
context.broker.Register(context)
return stream
def Connect(self):
'''
Connect to a Broker at the address specified in our associated Context.
'''
Log('%r.Connect()', self)
LOG.debug('%r.Connect()', self)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._fd = sock.fileno()
sock.connect(self._context.parent_addr)
self.Enqueue(0, self._context.name)
def fileno(self):
return self._fd
def __repr__(self):
return 'econtext.%s(<context=%r>)' %\
(self.__class__.__name__, self._context)
@ -509,9 +571,11 @@ class LocalStream(Stream):
def __init__(self, context):
super(LocalStream, self).__init__(context)
self._permitted_modules = {}
self._permitted_modules = set(['exceptions'])
self._unpickler.find_global = self._FindGlobal
self.AddHandleCB(SlaveModuleImporter.GetModule, handle=GET_MODULE)
self.responder = MasterModuleResponder(self)
self.AddHandleCB(self.responder.GetModule, handle=GET_MODULE)
def _FindGlobal(self, module_name, class_name):
'''
@ -527,7 +591,7 @@ class LocalStream(Stream):
'''
if module_name not in self._permitted_modules:
raise StreamError('context %r attempted to unpickle %r in module %r',
self._context, class_name, module_name)
self._context, class_name, module_name)
return getattr(sys.modules[module_name], class_name)
def AllowModule(self, module_name):
@ -561,23 +625,26 @@ class LocalStream(Stream):
source = textwrap.dedent('\n'.join(source.strip().split('\n')[1:]))
source = source.replace(' ', '\t')
source = source.replace('CONTEXT_NAME', repr(self._context.name))
encoded = source.encode('base64').replace('\n', '')
return [self.python_path, '-c',
'exec "%s".decode("hex")' % (source.encode('hex'),)]
'exec "%s".decode("base64")' % (encoded,)]
def __repr__(self):
return '%s(%s)' % (self.__class__.__name__, self._context)
def Connect(self):
Log('%r.Connect()', self)
LOG.debug('%r.Connect()', self)
pid, sock = CreateChild(*self.GetBootCommand())
self._fd = os.dup(sock.fileno())
sock.close()
Log('%r.Connect(): child process stdin/stdout=%r', self, self._fd)
LOG.debug('%r.Connect(): child process stdin/stdout=%r', self, self._fd)
source = inspect.getsource(sys.modules[__name__])
source += '\nExternalContextMain(%r, %r, %r)\n' %\
(self._context.name, self._context.broker._listen_addr,
self._context.key)
source += '\nExternalContextMain(%r, %r, %r)\n' % (
self._context.name,
self._context.broker._listener._listen_addr,
self._context.key
)
compressed = zlib.compress(source)
preamble = str(len(compressed)) + '\n' + compressed
@ -605,7 +672,7 @@ class Context(object):
'''
def __init__(self, broker, name=None, hostname=None, username=None, key=None,
parent_addr=None):
parent_addr=None):
self.broker = broker
self.name = name
self.hostname = hostname
@ -626,18 +693,20 @@ class Context(object):
optional timeout. The message contains (reply_handle, data), where
reply_handle is the handle on which this function expects its reply.
'''
Log('%r.EnqueueAwaitReply(%r, %r, %r)', self, handle, deadline, data)
reply_handle = self._stream.AllocHandle()
reply_event = threading.Event()
container = []
LOG.debug('%r.EnqueueAwaitReply(%r, %r, %r) -> reply handle %d',
self, handle, deadline, data, reply_handle)
def _Receive(killed, data):
Log('%r._Receive(%r, %r)', self, killed, data)
LOG.debug('%r._Receive(%r, %r)', self, killed, data)
container.extend([killed, data])
reply_event.set()
self._stream.AddHandleCB(_Receive, reply_handle, persist=False)
self._stream.Enqueue(CALL_FUNCTION, (False, (reply_handle,) + data))
self._stream.Enqueue(handle, (False, (reply_handle,) + data))
reply_event.wait(deadline)
if not reply_event.isSet():
@ -648,21 +717,20 @@ class Context(object):
if killed:
raise StreamError('lost connection during call.')
Log('%r._EnqueueAwaitReply(): got reply: %r', self, data)
LOG.debug('%r._EnqueueAwaitReply(): got reply: %r', self, data)
return data
def CallWithDeadline(self, fn, deadline, *args, **kwargs):
Log('%r.CallWithDeadline(%r, %r, *%r, **%r)', self, fn, deadline, args,
kwargs)
LOG.debug('%r.CallWithDeadline(%r, %r, *%r, **%r)',
self, fn, deadline, args, kwargs)
use_channel = bool(kwargs.pop('use_channel', False))
if isinstance(fn, types.MethodType) and \
isinstance(fn.im_self, (type, types.ClassType)):
fn_class = fn.im_self.__name__
else:
fn_class = None
call = (use_channel, fn.__module__, fn_class, fn.__name__, args, kwargs)
call = (fn.__module__, fn_class, fn.__name__, args, kwargs)
success, result = self.EnqueueAwaitReply(CALL_FUNCTION, deadline, call)
if success:
@ -680,6 +748,37 @@ class Context(object):
return 'Context(%s)' % ', '.join(bits)
class Waker(BasicStream):
def __init__(self, broker):
self._broker = broker
self._rfd, self._wfd = os.pipe()
self._fd = self._rfd
broker.AddStream(self)
def Wake(self):
os.write(self._wfd, ' ')
def Receive(self):
LOG.debug('%r: waking %r', self, self._broker)
os.read(self._rfd, 1)
class Listener(BasicStream):
def __init__(self, broker, address=None, backlog=30):
self._broker = broker
self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._sock.bind(address or ('0.0.0.0', 0))
self._sock.listen(backlog)
self._listen_addr = self._sock.getsockname()
self._fd = self._sock.fileno()
broker.AddStream(self)
def Receive(self):
sock, addr = self._sock.accept()
context = Context(self._broker, name=addr)
Stream.Accept(context, sock.fileno())
class Broker(object):
'''
Context broker: this is responsible for keeping track of contexts, any
@ -688,16 +787,14 @@ class Broker(object):
def __init__(self):
self._dead = False
self._poller = select.poll()
self._poller_fd_map = {}
self._poller_lock = threading.Lock()
self._lock = threading.Lock()
self._contexts = {}
self._readers = set()
self._writers = set()
self._waker = None
self._waker = Waker(self)
self._wake_rfd, self._wake_wfd = os.pipe()
self._listen_sock = None
self._poller.register(self._wake_rfd)
self._thread = threading.Thread(target=self.Loop, name='Broker')
self._thread = threading.Thread(target=self._Loop, name='Broker')
self._thread.setDaemon(True)
self._thread.start()
@ -708,25 +805,41 @@ class Broker(object):
address: The IPv4 address tuple to listen on.
backlog: Number of connections to accept while broker thread is busy.
'''
self._listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._listen_sock.bind(address or ('0.0.0.0', 0))
self._listen_sock.listen(backlog)
self._listen_addr = self._listen_sock.getsockname()
self._poller.register(self._listen_sock)
self._listener = Listener(self, address, backlog)
def UpdateStream(self, stream, wake=False):
LOG.debug('UpdateStream(%r, wake=%s)', stream, wake)
fileno = stream.fileno()
if fileno is not None and stream.ReadMore():
self._readers.add(stream)
else:
self._readers.discard(stream)
if fileno is not None and stream.WriteMore():
self._writers.add(stream)
else:
self._writers.discard(stream)
if wake:
self._waker.Wake()
def AddStream(self, stream):
self._lock.acquire()
try:
if self._waker:
self._waker.Wake()
self.UpdateStream(stream)
finally:
self._lock.release()
def Register(self, context):
'''
Put a context under control of this broker.
'''
Log('%r.Register(%r) -> fd=%r', self, context, context.GetStream().fileno())
self._poller_lock.acquire()
os.write(self._wake_wfd, ' ')
try:
self._contexts[context.name] = context
self._poller.register(context.GetStream())
self._poller_fd_map[context.GetStream().fileno()] = context.GetStream()
finally:
self._poller_lock.release()
LOG.debug('%r.Register(%r) -> fd=%r', self, context,
context.GetStream().fileno())
self.AddStream(context.GetStream())
self._contexts[context.name] = context
return context
def GetLocal(self, name):
@ -754,53 +867,48 @@ class Broker(object):
context.SetStream(SSHStream(context)).Connect()
return self.Register(context)
def _Loop(self):
try:
self.Loop()
except Exception:
LOG.exception('Loop() crashed')
def Loop(self):
'''
Handle stream events until Finalize() is called.
'''
while not self._dead:
Log('%r.Loop()', self)
self._poller_lock.acquire()
self._poller_lock.release()
for fd, event in self._poller.poll():
if fd == self._wake_rfd:
Log('%r: got event on wake_rfd=%d.', self, self._wake_rfd)
os.read(self._wake_rfd, 1)
continue
elif self._listen_sock and fd == self._listen_sock.fileno():
context = Context(self)
sock, addr = self._listen_sock.accept()
Stream.Accept(context, sock)
continue
obj = self._poller_fd_map[fd]
if event & select.POLLHUP:
Log('%r: POLLHUP for %d, %r', self, fd, obj)
obj.Disconnect()
elif event & select.POLLIN:
Log('%r: POLLIN for %d, %r', self, fd, obj)
obj.Receive()
elif event & select.POLLOUT:
Log('%r: POLLOUT for %d, %r', self, fd, obj)
if not obj.Transmit(): # If no output buffered, unset POLLOUT.
self._poller.unregister(obj)
self._poller.register(obj, select.POLLIN)
elif event & select.POLLNVAL:
Log('%r: POLLNVAL for %d, %r', self, fd, obj)
obj.Disconnect()
self._poller.unregister(obj)
LOG.debug('%r.Loop()', self)
self._lock.acquire()
self._lock.release()
#LOG.debug('readers = %r', self._readers)
#LOG.debug('rfds = %r', [r.fileno() for r in self._readers])
#LOG.debug('writers = %r', self._writers)
rstrms, wstrms, _ = select.select(self._readers, self._writers, ())
for stream in rstrms:
LOG.debug('%r: POLLIN for %r', self, stream)
stream.Receive()
self.UpdateStream(stream)
for stream in wstrms:
LOG.debug('%r: POLLOUT for %r', self, stream)
stream.Transmit()
self.UpdateStream(stream)
def Finalize(self):
'''
Tell all active streams to disconnect.
'''
self._dead = True
self._poller_lock.acquire()
self._lock.acquire()
try:
for name, context in self._contexts.iteritems():
context.GetStream().Disconnect()
stream = context.GetStream()
if stream:
stream.Disconnect()
finally:
self._poller_lock.release()
self._lock.release()
def __repr__(self):
return 'econtext.Broker(<contexts=%s>)' % (self._contexts.keys(),)
@ -809,25 +917,38 @@ class Broker(object):
def ExternalContextMain(context_name, parent_addr, key):
syslog.openlog('%s:%s' % (getpass.getuser(), context_name), syslog.LOG_PID)
syslog.syslog('initializing (parent=%s)' % (os.getenv('SSH_CLIENT'),))
Log('ExternalContextMain(%r, %r, %r)', context_name, parent_addr, key)
os.wait() # Reap the first stage.
logging.basicConfig(level=logging.DEBUG)
logging.getLogger('').handlers[0].formatter = Formatter(False)
LOG.debug('ExternalContextMain(%r, %r, %r)', context_name, parent_addr, key)
# os.wait() # Reap the first stage.
os.dup2(100, 0)
os.close(100)
broker = Broker()
context = Context(broker, 'parent', parent_addr=parent_addr, key=key)
stream = context.SetStream(Stream(context))
stream.Connect()
stream = Stream.Accept(context, 0)
os.close(0)
# stream = context.SetStream(Stream(context))
# stream.
# stream.Connect()
broker.Register(context)
importer = SlaveModuleImporter(context)
sys.meta_path.append(importer)
for call_info in Channel(stream, CALL_FUNCTION):
Log('ExternalContextMain(): CALL_FUNCTION %r', call_info)
reply_handle, mod_name, func_name, args, kwargs = call_info
fn = getattr(__import__(mod_name), func_name)
LOG.debug('ExternalContextMain(): CALL_FUNCTION %r', call_info)
(reply_handle, mod_name, class_name, func_name, args, kwargs) = call_info
try:
fn = getattr(__import__(mod_name), func_name)
stream.Enqueue(reply_handle, (True, fn(*args, **kwargs)))
except Exception, e:
stream.Enqueue(reply_handle, (False, (e, traceback.extract_stack())))
broker.Finalize()
LOG.error('ExternalContextMain exitting')

25
st.py

@ -1,25 +0,0 @@
import socket
def GetCurrentHostname():
'''
Fetch the current hostname.
'''
return socket.gethostname()
def LogCurrentUptime(hostname, pathname='/tmp/uptime.txt'):
'''
Log the current uptime along with process ID that logs it.
Args:
hostname: the string hostname.
'''
fp = file(pathname, 'a')
fp.write('%d %s %s\n' % (os.getpid(), hostname, os.popen('uptime').read()))
fp.close()
def try_something_silly(arg):
file('tty', 'w').write('ARG WAS: ' + str(arg) + '\n')
Loading…
Cancel
Save