diff --git a/econtext.py b/econtext.py index 7bca7a2a..04f0f719 100644 --- a/econtext.py +++ b/econtext.py @@ -12,6 +12,7 @@ import getpass import hmac import imp import inspect +import logging import os import random import select @@ -32,13 +33,11 @@ import zlib # Module-level data. # -GET_MODULE = 0L -CALL_FUNCTION = 1L +LOG = logging.getLogger('econtext') -DEBUG = True +GET_MODULE = 100L +CALL_FUNCTION = 101L -import sys -sys.stderr = open('milf1', 'w', 1) # # Exceptions. @@ -66,12 +65,6 @@ class TimeoutError(StreamError): # Helpers. # -def Log(fmt, *args): - if DEBUG: - sys.stderr.write('%d (%d): %s\n' % (os.getpid(), os.getppid(), - (fmt % args).replace('econtext.', ''))) - - def write_all(fd, s): written = 0 while written < len(s): @@ -104,11 +97,30 @@ def CreateChild(*args): raise SystemExit childfp.close() - Log('CreateChild() child %d fd %d, parent %d, args %r', - pid, parentfp.fileno(), os.getpid(), args) + LOG.debug('CreateChild() child %d fd %d, parent %d, args %r', + pid, parentfp.fileno(), os.getpid(), args) return pid, parentfp +class Formatter(logging.Formatter): + FMT = '%(asctime)s %(levelname).1s %(name)s: %(message)s' + DATEFMT = '%H:%M:%S' + + def __init__(self, parent): + self.parent = parent + super(Formatter, self).__init__(self.FMT, self.DATEFMT) + + def format(self, record): + s = super(Formatter, self).format(record) + if 1: + p = '' + elif self.parent: + p = '\x1b[32m' + else: + p = '\x1b[36m' + return p + ('{%s} %s' % (os.getpid(), s)) + + class PartialFunction(object): ''' Partial function implementation. @@ -149,7 +161,7 @@ class Channel(object): object ) ''' - Log('%r._InternalReceive(%r, %r)', self, killed, data) + LOG.debug('%r._InternalReceive(%r, %r)', self, killed, data) self._queue_lock.acquire() try: self._queue.append((killed or data[0], killed or data[1])) @@ -161,14 +173,14 @@ class Channel(object): ''' Indicate this channel is closed to the remote side. ''' - Log('%r.Close()', self) + LOG.debug('%r.Close()', self) self._stream.Enqueue(handle, (True, None)) def Send(self, data): ''' Send the given object to the remote side. ''' - Log('%r.Send(%r)', self, data) + LOG.debug('%r.Send(%r)', self, data) self._stream.Enqueue(handle, (False, data)) def Receive(self, timeout=None): @@ -182,7 +194,7 @@ class Channel(object): Returns: object ''' - Log('%r.Receive(%r)', self, timeout) + LOG.debug('%r.Receive(%r)', self, timeout) if not self._queue: self._wake_event.wait(timeout) if not self._wake_event.isSet(): @@ -191,10 +203,10 @@ class Channel(object): self._queue_lock.acquire() try: self._wake_event.clear() - Log('%r.Receive() queue is %r', self, self._queue) - closed, data = self._queue.pop(0) - Log('%r.Receive() got closed=%r, data=%r', self, closed, data) - if closed: + LOG.debug('%r.Receive() queue is %r', self, self._queue) + killed, data = self._queue.pop(0) + LOG.debug('%r.Receive() got killed=%r, data=%r', self, killed, data) + if killed: raise ChannelError('Channel is closed.') return data finally: @@ -231,26 +243,64 @@ class SlaveModuleImporter(object): self._context = context def find_module(self, fullname, path=None): - if not imp.find_module(fullname): + LOG.debug('SlaveModuleImporter.find_module(%r)', fullname) + try: + imp.find_module(fullname) + except ImportError: + LOG.debug('find_module(%r) returning self', fullname) return self def load_module(self, fullname): - kind, data = self._context.EnqueueAwaitReply(GET_MODULE, fullname) + LOG.debug('SlaveModuleImporter.load_module(%r)', fullname) + ret = self._context.EnqueueAwaitReply(GET_MODULE, None, (fullname,)) + if ret is None: + raise ImportError('Master does not have %r' % (fullname,)) + + kind, path, data = ret + code = compile(zlib.decompress(data), path, 'exec') + module = imp.new_module(fullname) + sys.modules[fullname] = module + eval(code, vars(module), vars(module)) + return module - def GetModule(cls, killed, fullname): - Log('%r.GetModule(%r, %r)', cls, killed, fullname) + +class MasterModuleResponder(object): + def __init__(self, stream): + self._stream = stream + + def GetModule(self, killed, (_, (reply_handle, fullname))): + LOG.debug('SlaveModuleImporter.GetModule(%r, %r)', killed, fullname) if killed: return - if fullname in sys.modules: - pass - GetModule = classmethod(GetModule) + + mod = sys.modules.get(fullname) + if mod: + source = zlib.compress(inspect.getsource(mod)) + path = os.path.abspath(mod.__file__) + self._stream.Enqueue(reply_handle, ('source', path, source)) # # Stream implementations. # -class Stream(object): + +class BasicStream(object): + def fileno(self): + return self._fd + + def Disconnect(self): + LOG.debug('%r: disconnect on %r fd %d', self._broker, self, self._fd) + self._broker.RemoveStream(self) + + def ReadMore(self): + return True + + def WriteMore(self): + return False + + +class Stream(BasicStream): def __init__(self, context): ''' Initialize a new Stream instance. @@ -266,7 +316,7 @@ class Stream(object): self._rhmac = hmac.new(context.key, digestmod=sha.new) self._whmac = self._rhmac.copy() - self._last_handle = 0 + self._last_handle = 1000L self._handle_map = {} self._handle_lock = threading.Lock() @@ -274,7 +324,7 @@ class Stream(object): self._func_ref_lock = threading.Lock() self._pickler_file = cStringIO.StringIO() - self._pickler = cPickle.Pickler(self._pickler_file) + self._pickler = cPickle.Pickler(self._pickler_file, protocol=2) self._pickler.persistent_id = self._CheckFunctionPerID self._unpickler_file = cStringIO.StringIO() @@ -307,7 +357,7 @@ class Stream(object): Returns: object ''' - Log('%r.Unpickle(%r)', self, data) + LOG.debug('%r.Unpickle(%r)', self, data) self._unpickler_file.write(data) self._unpickler_file.seek(0) data = self._unpickler.load() @@ -369,7 +419,8 @@ class Stream(object): handle: long persist: False to only receive a single message. ''' - Log('%r.AddHandleCB(%r, %r, persist=%r)', self, fn, handle, persist) + LOG.debug('%r.AddHandleCB(%r, %r, persist=%r)', + self, fn, handle, persist) self._handle_lock.acquire() try: self._handle_map[handle] = persist, fn @@ -381,30 +432,36 @@ class Stream(object): Handle the next complete message on the stream. Raise CorruptMessageError or IOError on failure. ''' - Log('%r.Receive()', self) + LOG.debug('%r.Receive()', self) + + buf = os.read(self._fd, 4096) + if not buf: + return self.Disconnect() - self._input_buf += os.read(self._fd, 4096) + self._input_buf += buf if len(self._input_buf) < 24: return msg_mac = self._input_buf[:20] msg_len = struct.unpack('>L', self._input_buf[20:24])[0] if len(self._input_buf) < msg_len-24: + LOG.debug('Input too short') return self._rhmac.update(self._input_buf[20:msg_len+24]) expected_mac = self._rhmac.digest() if msg_mac != expected_mac: raise CorruptMessageError('%r got invalid MAC: expected %r, got %r', - self, msg_mac.encode('hex'), - expected_mac.encode('hex')) + self, msg_mac.encode('hex'), + expected_mac.encode('hex')) try: handle, data = self.Unpickle(self._input_buf[24:msg_len+24]) - self._input_buf = self._input_buf[msg_len+4:] + self._input_buf = self._input_buf[msg_len+24:] handle = long(handle) - Log('%r.Receive(): decoded handle=%r; data=%r', self, handle, data) + LOG.debug('%r.Receive(): decoded handle=%r; data=%r', + self, handle, data) persist, fn = self._handle_map[handle] if not persist: del self._handle_map[handle] @@ -413,6 +470,7 @@ class Stream(object): except (TypeError, ValueError), ex: raise CorruptMessageError('%r got invalid message: %s', self, ex) + LOG.debug('Calling %r (%r, %r)', fn, False, data) fn(False, data) def Transmit(self): @@ -425,9 +483,11 @@ class Stream(object): Raises: IOError ''' - Log('%r.Transmit()', self) + LOG.debug('%r.Transmit()', self) written = os.write(self._fd, self._output_buf[:4096]) self._output_buf = self._output_buf[written:] + + def WriteMore(self): return bool(self._output_buf) def Enqueue(self, handle, obj): @@ -439,7 +499,7 @@ class Stream(object): handle: long obj: object ''' - Log('%r.Enqueue(%r, %r)', self, handle, obj) + LOG.debug('%r.Enqueue(%r, %r)', self, handle, obj) self._output_buf_lock.acquire() try: @@ -449,49 +509,51 @@ class Stream(object): self._output_buf += self._whmac.digest() + msg finally: self._output_buf_lock.release() - self._context.broker.Register(self._context) + self._context.broker.UpdateStream(self, wake=True) def Disconnect(self): ''' Close our associated file descriptor and tell any registered callbacks that the connection has been destroyed. ''' - Log('%r.Disconnect()', self) + LOG.debug('%r.Disconnect()', self) try: os.close(self._fd) except OSError, e: - Log('%r.Disconnect(): did not close fd %s: %s', self, self._fd, e) + LOG.debug('%r.Disconnect(): did not close fd %s: %s', + self, self._fd, e) + + self._fd = None + if self._context.GetStream() is self: + self._context.SetStream(None) for handle, (persist, fn) in self._handle_map.iteritems(): - Log('%r.Disconnect(): killing stale callback handle=%r; fn=%r', - self, handle, fn) + LOG.debug('%r.Disconnect(): stale callback handle=%r; fn=%r', + self, handle, fn) fn(True, None) @classmethod - def Accept(cls, context, sock): + def Accept(cls, context, fd): ''' ''' stream = cls(context) - stream.sock = sock - stream._fd = sock.fileno() + stream._fd = os.dup(fd) context.SetStream(stream) context.broker.Register(context) + return stream def Connect(self): ''' Connect to a Broker at the address specified in our associated Context. ''' - Log('%r.Connect()', self) + LOG.debug('%r.Connect()', self) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._fd = sock.fileno() sock.connect(self._context.parent_addr) self.Enqueue(0, self._context.name) - def fileno(self): - return self._fd - def __repr__(self): return 'econtext.%s()' %\ (self.__class__.__name__, self._context) @@ -509,9 +571,11 @@ class LocalStream(Stream): def __init__(self, context): super(LocalStream, self).__init__(context) - self._permitted_modules = {} + self._permitted_modules = set(['exceptions']) self._unpickler.find_global = self._FindGlobal - self.AddHandleCB(SlaveModuleImporter.GetModule, handle=GET_MODULE) + + self.responder = MasterModuleResponder(self) + self.AddHandleCB(self.responder.GetModule, handle=GET_MODULE) def _FindGlobal(self, module_name, class_name): ''' @@ -527,7 +591,7 @@ class LocalStream(Stream): ''' if module_name not in self._permitted_modules: raise StreamError('context %r attempted to unpickle %r in module %r', - self._context, class_name, module_name) + self._context, class_name, module_name) return getattr(sys.modules[module_name], class_name) def AllowModule(self, module_name): @@ -561,23 +625,26 @@ class LocalStream(Stream): source = textwrap.dedent('\n'.join(source.strip().split('\n')[1:])) source = source.replace(' ', '\t') source = source.replace('CONTEXT_NAME', repr(self._context.name)) + encoded = source.encode('base64').replace('\n', '') return [self.python_path, '-c', - 'exec "%s".decode("hex")' % (source.encode('hex'),)] + 'exec "%s".decode("base64")' % (encoded,)] def __repr__(self): return '%s(%s)' % (self.__class__.__name__, self._context) def Connect(self): - Log('%r.Connect()', self) + LOG.debug('%r.Connect()', self) pid, sock = CreateChild(*self.GetBootCommand()) self._fd = os.dup(sock.fileno()) sock.close() - Log('%r.Connect(): child process stdin/stdout=%r', self, self._fd) + LOG.debug('%r.Connect(): child process stdin/stdout=%r', self, self._fd) source = inspect.getsource(sys.modules[__name__]) - source += '\nExternalContextMain(%r, %r, %r)\n' %\ - (self._context.name, self._context.broker._listen_addr, - self._context.key) + source += '\nExternalContextMain(%r, %r, %r)\n' % ( + self._context.name, + self._context.broker._listener._listen_addr, + self._context.key + ) compressed = zlib.compress(source) preamble = str(len(compressed)) + '\n' + compressed @@ -605,7 +672,7 @@ class Context(object): ''' def __init__(self, broker, name=None, hostname=None, username=None, key=None, - parent_addr=None): + parent_addr=None): self.broker = broker self.name = name self.hostname = hostname @@ -626,18 +693,20 @@ class Context(object): optional timeout. The message contains (reply_handle, data), where reply_handle is the handle on which this function expects its reply. ''' - Log('%r.EnqueueAwaitReply(%r, %r, %r)', self, handle, deadline, data) reply_handle = self._stream.AllocHandle() reply_event = threading.Event() container = [] + LOG.debug('%r.EnqueueAwaitReply(%r, %r, %r) -> reply handle %d', + self, handle, deadline, data, reply_handle) + def _Receive(killed, data): - Log('%r._Receive(%r, %r)', self, killed, data) + LOG.debug('%r._Receive(%r, %r)', self, killed, data) container.extend([killed, data]) reply_event.set() self._stream.AddHandleCB(_Receive, reply_handle, persist=False) - self._stream.Enqueue(CALL_FUNCTION, (False, (reply_handle,) + data)) + self._stream.Enqueue(handle, (False, (reply_handle,) + data)) reply_event.wait(deadline) if not reply_event.isSet(): @@ -648,21 +717,20 @@ class Context(object): if killed: raise StreamError('lost connection during call.') - Log('%r._EnqueueAwaitReply(): got reply: %r', self, data) + LOG.debug('%r._EnqueueAwaitReply(): got reply: %r', self, data) return data def CallWithDeadline(self, fn, deadline, *args, **kwargs): - Log('%r.CallWithDeadline(%r, %r, *%r, **%r)', self, fn, deadline, args, - kwargs) + LOG.debug('%r.CallWithDeadline(%r, %r, *%r, **%r)', + self, fn, deadline, args, kwargs) - use_channel = bool(kwargs.pop('use_channel', False)) if isinstance(fn, types.MethodType) and \ isinstance(fn.im_self, (type, types.ClassType)): fn_class = fn.im_self.__name__ else: fn_class = None - call = (use_channel, fn.__module__, fn_class, fn.__name__, args, kwargs) + call = (fn.__module__, fn_class, fn.__name__, args, kwargs) success, result = self.EnqueueAwaitReply(CALL_FUNCTION, deadline, call) if success: @@ -680,6 +748,37 @@ class Context(object): return 'Context(%s)' % ', '.join(bits) +class Waker(BasicStream): + def __init__(self, broker): + self._broker = broker + self._rfd, self._wfd = os.pipe() + self._fd = self._rfd + broker.AddStream(self) + + def Wake(self): + os.write(self._wfd, ' ') + + def Receive(self): + LOG.debug('%r: waking %r', self, self._broker) + os.read(self._rfd, 1) + + +class Listener(BasicStream): + def __init__(self, broker, address=None, backlog=30): + self._broker = broker + self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._sock.bind(address or ('0.0.0.0', 0)) + self._sock.listen(backlog) + self._listen_addr = self._sock.getsockname() + self._fd = self._sock.fileno() + broker.AddStream(self) + + def Receive(self): + sock, addr = self._sock.accept() + context = Context(self._broker, name=addr) + Stream.Accept(context, sock.fileno()) + + class Broker(object): ''' Context broker: this is responsible for keeping track of contexts, any @@ -688,16 +787,14 @@ class Broker(object): def __init__(self): self._dead = False - self._poller = select.poll() - self._poller_fd_map = {} - self._poller_lock = threading.Lock() + self._lock = threading.Lock() self._contexts = {} + self._readers = set() + self._writers = set() + self._waker = None + self._waker = Waker(self) - self._wake_rfd, self._wake_wfd = os.pipe() - self._listen_sock = None - self._poller.register(self._wake_rfd) - - self._thread = threading.Thread(target=self.Loop, name='Broker') + self._thread = threading.Thread(target=self._Loop, name='Broker') self._thread.setDaemon(True) self._thread.start() @@ -708,25 +805,41 @@ class Broker(object): address: The IPv4 address tuple to listen on. backlog: Number of connections to accept while broker thread is busy. ''' - self._listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._listen_sock.bind(address or ('0.0.0.0', 0)) - self._listen_sock.listen(backlog) - self._listen_addr = self._listen_sock.getsockname() - self._poller.register(self._listen_sock) + self._listener = Listener(self, address, backlog) + + def UpdateStream(self, stream, wake=False): + LOG.debug('UpdateStream(%r, wake=%s)', stream, wake) + fileno = stream.fileno() + if fileno is not None and stream.ReadMore(): + self._readers.add(stream) + else: + self._readers.discard(stream) + + if fileno is not None and stream.WriteMore(): + self._writers.add(stream) + else: + self._writers.discard(stream) + + if wake: + self._waker.Wake() + + def AddStream(self, stream): + self._lock.acquire() + try: + if self._waker: + self._waker.Wake() + self.UpdateStream(stream) + finally: + self._lock.release() def Register(self, context): ''' Put a context under control of this broker. ''' - Log('%r.Register(%r) -> fd=%r', self, context, context.GetStream().fileno()) - self._poller_lock.acquire() - os.write(self._wake_wfd, ' ') - try: - self._contexts[context.name] = context - self._poller.register(context.GetStream()) - self._poller_fd_map[context.GetStream().fileno()] = context.GetStream() - finally: - self._poller_lock.release() + LOG.debug('%r.Register(%r) -> fd=%r', self, context, + context.GetStream().fileno()) + self.AddStream(context.GetStream()) + self._contexts[context.name] = context return context def GetLocal(self, name): @@ -754,53 +867,48 @@ class Broker(object): context.SetStream(SSHStream(context)).Connect() return self.Register(context) + def _Loop(self): + try: + self.Loop() + except Exception: + LOG.exception('Loop() crashed') + def Loop(self): ''' Handle stream events until Finalize() is called. ''' while not self._dead: - Log('%r.Loop()', self) - self._poller_lock.acquire() - self._poller_lock.release() - for fd, event in self._poller.poll(): - if fd == self._wake_rfd: - Log('%r: got event on wake_rfd=%d.', self, self._wake_rfd) - os.read(self._wake_rfd, 1) - continue - elif self._listen_sock and fd == self._listen_sock.fileno(): - context = Context(self) - sock, addr = self._listen_sock.accept() - Stream.Accept(context, sock) - continue - - obj = self._poller_fd_map[fd] - if event & select.POLLHUP: - Log('%r: POLLHUP for %d, %r', self, fd, obj) - obj.Disconnect() - elif event & select.POLLIN: - Log('%r: POLLIN for %d, %r', self, fd, obj) - obj.Receive() - elif event & select.POLLOUT: - Log('%r: POLLOUT for %d, %r', self, fd, obj) - if not obj.Transmit(): # If no output buffered, unset POLLOUT. - self._poller.unregister(obj) - self._poller.register(obj, select.POLLIN) - elif event & select.POLLNVAL: - Log('%r: POLLNVAL for %d, %r', self, fd, obj) - obj.Disconnect() - self._poller.unregister(obj) + LOG.debug('%r.Loop()', self) + self._lock.acquire() + self._lock.release() + + #LOG.debug('readers = %r', self._readers) + #LOG.debug('rfds = %r', [r.fileno() for r in self._readers]) + #LOG.debug('writers = %r', self._writers) + rstrms, wstrms, _ = select.select(self._readers, self._writers, ()) + for stream in rstrms: + LOG.debug('%r: POLLIN for %r', self, stream) + stream.Receive() + self.UpdateStream(stream) + + for stream in wstrms: + LOG.debug('%r: POLLOUT for %r', self, stream) + stream.Transmit() + self.UpdateStream(stream) def Finalize(self): ''' Tell all active streams to disconnect. ''' self._dead = True - self._poller_lock.acquire() + self._lock.acquire() try: for name, context in self._contexts.iteritems(): - context.GetStream().Disconnect() + stream = context.GetStream() + if stream: + stream.Disconnect() finally: - self._poller_lock.release() + self._lock.release() def __repr__(self): return 'econtext.Broker()' % (self._contexts.keys(),) @@ -809,25 +917,38 @@ class Broker(object): def ExternalContextMain(context_name, parent_addr, key): syslog.openlog('%s:%s' % (getpass.getuser(), context_name), syslog.LOG_PID) syslog.syslog('initializing (parent=%s)' % (os.getenv('SSH_CLIENT'),)) - Log('ExternalContextMain(%r, %r, %r)', context_name, parent_addr, key) - os.wait() # Reap the first stage. + logging.basicConfig(level=logging.DEBUG) + logging.getLogger('').handlers[0].formatter = Formatter(False) + LOG.debug('ExternalContextMain(%r, %r, %r)', context_name, parent_addr, key) + + # os.wait() # Reap the first stage. os.dup2(100, 0) os.close(100) broker = Broker() context = Context(broker, 'parent', parent_addr=parent_addr, key=key) - stream = context.SetStream(Stream(context)) - stream.Connect() + stream = Stream.Accept(context, 0) + os.close(0) + + # stream = context.SetStream(Stream(context)) + # stream. + # stream.Connect() broker.Register(context) + importer = SlaveModuleImporter(context) + sys.meta_path.append(importer) + for call_info in Channel(stream, CALL_FUNCTION): - Log('ExternalContextMain(): CALL_FUNCTION %r', call_info) - reply_handle, mod_name, func_name, args, kwargs = call_info - fn = getattr(__import__(mod_name), func_name) + LOG.debug('ExternalContextMain(): CALL_FUNCTION %r', call_info) + (reply_handle, mod_name, class_name, func_name, args, kwargs) = call_info try: + fn = getattr(__import__(mod_name), func_name) stream.Enqueue(reply_handle, (True, fn(*args, **kwargs))) except Exception, e: stream.Enqueue(reply_handle, (False, (e, traceback.extract_stack()))) + + broker.Finalize() + LOG.error('ExternalContextMain exitting') diff --git a/st.py b/st.py deleted file mode 100644 index 50b238ca..00000000 --- a/st.py +++ /dev/null @@ -1,25 +0,0 @@ - -import socket - -def GetCurrentHostname(): - ''' - Fetch the current hostname. - ''' - return socket.gethostname() - - -def LogCurrentUptime(hostname, pathname='/tmp/uptime.txt'): - ''' - Log the current uptime along with process ID that logs it. - - Args: - hostname: the string hostname. - ''' - - fp = file(pathname, 'a') - fp.write('%d %s %s\n' % (os.getpid(), hostname, os.popen('uptime').read())) - fp.close() - - -def try_something_silly(arg): - file('tty', 'w').write('ARG WAS: ' + str(arg) + '\n')