diff --git a/a.py b/a.py new file mode 100644 index 00000000..004354be --- /dev/null +++ b/a.py @@ -0,0 +1,16 @@ +import os, socket + +def CreateChild(*args): + ''' + Create a child process whos stdin/stdout is connected to a socketpair. + Returns: + fd + ''' + sock1, sock2 = socket.socketpair() + if os.fork(): + for pair in ((0, sock1), (1, sock2)): + os.dup2(sock2.fileno(), pair[0]) + os.close(pair[1].fileno()) + os.execvp(args[0], args) + raise SystemExit + return sock1 diff --git a/econtext.py b/econtext.py index a6789453..20547629 100755 --- a/econtext.py +++ b/econtext.py @@ -9,13 +9,15 @@ import cPickle import cStringIO import commands import getpass +import hmac import imp import inspect import os import select +import sha import signal +import socket import struct -import subprocess import sys import syslog import textwrap @@ -67,6 +69,23 @@ def Log(fmt, *args): (fmt%args).replace('econtext.', ''))) +def CreateChild(*args): + ''' + Create a child process whos stdin/stdout is connected to a socket. + Returns: + pid, socket + ''' + sock1, sock2 = socket.socketpair() + pid = os.fork() + if not pid: + for pair in ((0, sock1), (1, sock2)): + os.dup2(sock2.fileno(), pair[0]) + os.close(pair[1].fileno()) + os.execvp(args[0], args) + raise SystemExit + return pid, sock1 + + class PartialFunction(object): def __init__(self, fn, *partial_args): self.fn = fn @@ -97,14 +116,17 @@ class Channel(object): Args: # Has the Stream object lost its connection? killed: bool - # Has the remote Channel had Close() called? / the object passed to the - # remote Send(). - data: (bool, object) + data: ( + # Has the remote Channel had Close() called? + bool, + # The object passed to the remote Send() + object + ) ''' Log('%r._InternalReceive(%r, %r)', self, killed, data) self._queue_lock.acquire() try: - self._queue.append((killed or data[0], data[1])) + self._queue.append((killed or data[0], killed or data[1])) self._wake_event.set() finally: self._queue_lock.release() @@ -155,7 +177,7 @@ class Channel(object): def __iter__(self): ''' Return an iterator that yields objects arriving on this channel, until the - channel is closed. + channel dies or is closed. ''' while True: try: @@ -194,9 +216,14 @@ class SlaveModuleImporter(object): def load_module(self, fullname): kind, data = self._context.EnqueueAwaitReply(GET_MODULE, fullname) - def GetModule(cls, fullname): + def GetModule(cls, killed, fullname): + Log('%r.GetModule(%r, %r)', cls, killed, fullname) + if killed: + return if fullname in sys.modules: pass + GetModule = classmethod(GetModule) + # # Stream implementations. @@ -205,11 +232,12 @@ class SlaveModuleImporter(object): class Stream(object): def __init__(self, context): self._context = context - self._alive = True self._input_buf = self._output_buf = '' self._input_buf_lock = threading.Lock() self._output_buf_lock = threading.Lock() + self._rhmac = hmac.new(context.key, digestmod=sha.new) + self._whmac = self._rhmac.copy() self._last_handle = 0 self._handle_map = {} @@ -226,8 +254,6 @@ class Stream(object): self._unpickler = cPickle.Unpickler(self._unpickler_file) self._unpickler.persistent_load = self._LoadFunctionFromPerID - # Pickler/Unpickler support. - def Pickle(self, obj): self._pickler.dump(obj) data = self._pickler_file.getvalue() @@ -275,8 +301,6 @@ class Stream(object): raise CorruptMessageError('unrecognized persistent ID received: %r', pid) return PartialFunction(self._CallPersistentWhatsit, pid) - # I/O. - def AllocHandle(self): ''' Allocate a unique handle for this stream. @@ -314,21 +338,25 @@ class Stream(object): or IOError on failure. ''' Log('%r.Receive()', self) - chunk = os.read(self._rfd, 4096) - if not chunk: - raise StreamError('remote side hung up.') - self._input_buf += chunk - buffer_len = len(self._input_buf) - if buffer_len < 4: + self._input_buf += os.read(self._rfd, 4096) + if len(self._input_buf) < 24: return - msg_len = struct.unpack('>L', self._input_buf[:4])[0] - if buffer_len < msg_len-4: + msg_mac = self._input_buf[:20] + msg_len = struct.unpack('>L', self._input_buf[20:24])[0] + if len(self._input_buf) < msg_len-24: return + self._rhmac.update(self._input_buf[20:msg_len+24]) + expected_mac = self._rhmac.digest() + if msg_mac != expected_mac: + raise CorruptMessageError('%r got invalid MAC: expected %r, got %r', + self, msg_mac.encode('hex'), + expected_mac.encode('hex')) + try: - handle, data = self.Unpickle(self._input_buf[4:msg_len+4]) + handle, data = self.Unpickle(self._input_buf[24:msg_len+24]) self._input_buf = self._input_buf[msg_len+4:] handle = long(handle) @@ -345,60 +373,63 @@ class Stream(object): def Transmit(self): ''' - Transmit pending messages. Raises IOError on failure. + Transmit pending messages. Raises IOError on failure. Return value + indicates whether there is still data buffered. + + Returns: + bool ''' Log('%r.Transmit()', self) - written = os.write(self._wfd, self._output_buf[:4096]) + written = os.write(self._fd, self._output_buf[:4096]) self._output_buf = self._output_buf[written:] - if self._context and not self._output_buf: - self._context.manager.UpdateStreamIOState(self) + return bool(self._output_buf) + + def Enqueue(self, handle, data): + Log('%r.Enqueue(%r, %r)', self, handle, data) + + self._output_buf_lock.acquire() + try: + encoded = self.Pickle((handle, data)) + msg = struct.pack('>L', len(encoded)) + encoded + self._whmac.update(msg) + self._output_buf += self._whmac.digest() + msg + finally: + self._output_buf_lock.release() + self._context.broker.Register(self._context) def Disconnect(self): ''' Called to handle disconnects. ''' Log('%r.Disconnect()', self) + try: + os.close(self._fd) + except OSError, e: + Log('WARNING: %s', e) - for fd in (self._rfd, self._wfd): - os.close(fd) - - # Invoke each registered callback to indicate the connection has been - # destroyed. This prevents pending Channels/RPCs from hanging forever. for handle, (persist, fn) in self._handle_map.iteritems(): Log('%r.Disconnect(): killing stale callback handle=%r; fn=%r', self, handle, fn) fn(True, None) - self._context.manager.UpdateStreamIOState(self) + @classmethod + def Accept(cls, broker, sock): + context = Context(broker) + stream = cls(context) + context.SetStream(stream) + broker.Register(context) - def GetIOState(self): + def Connect(self): ''' - Return a 3-tuple describing the instance's I/O state. - - Returns: - (alive, input_fd, output_fd, has_output_buffered) + Connect to a Broker at the address given in the Context instance. ''' - return self._alive, self._rfd, self._wfd, bool(self._output_buf) - - def Enqueue(self, handle, data): - Log('%r.Enqueue(%r, %r)', self, handle, data) + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._fd = sock.fileno() + sock.connect(self._context.parent_addr) + self.Enqueue(0, self._context.name) - self._output_buf_lock.acquire() - try: - encoded = self.Pickle((handle, data)) - self._output_buf += struct.pack('>L', len(encoded)) + encoded - finally: - self._output_buf_lock.release() - self._context.manager.UpdateStreamIOState(self) - - # Misc. - - @classmethod - def FromFDs(cls, context, rfd, wfd): - Log('%r.FromFDs(%r, %r, %r)', cls, context, rfd, wfd) - self = cls(context) - self._rfd, self._wfd = rfd, wfd - return self + def fileno(self): + return self._fd def __repr__(self): return 'econtext.%s()' %\ @@ -415,15 +446,11 @@ class LocalStream(Stream): lambda self, path: setattr(self, '_python_path', path), doc='The path to the remote Python interpreter.') - def _GetModuleSource(self, killed, name): - if not killed: - return inspect.getsource(sys.modules[name]) - def __init__(self, context): super(LocalStream, self).__init__(context) self._permitted_modules = {} self._unpickler.find_global = self._FindGlobal - self.AddHandleCB(self._GetModuleSource, handle=GET_MODULE) + self.AddHandleCB(SlaveModuleImporter.GetModule, handle=GET_MODULE) def _FindGlobal(self, module_name, class_name): ''' @@ -457,8 +484,7 @@ class LocalStream(Stream): def _FirstStage(): import os,sys,zlib R,W=os.pipe() - pid=os.fork() - if pid: + if os.fork(): os.dup2(0,100) os.dup2(R,0) os.close(R) @@ -480,34 +506,21 @@ class LocalStream(Stream): def __repr__(self): return '%s(%s)' % (self.__class__.__name__, self._context) - # Public. - - @classmethod - def Accept(cls, fd): - raise NotImplemented - def Connect(self): Log('%r.Connect()', self) - self._child = subprocess.Popen(self.GetBootCommand(), stdin=subprocess.PIPE, - stdout=subprocess.PIPE) - self._wfd = self._child.stdin.fileno() - self._rfd = self._child.stdout.fileno() - Log('%r.Connect(): chlid process stdin=%r, stdout=%r', - self, self._wfd, self._rfd) + pid, sock = CreateChild(*self.GetBootCommand()) + self._fd = sock.fileno() + Log('%r.Connect(): chlid process stdin/stdout=%r', self, self._fd) source = inspect.getsource(sys.modules[__name__]) - source += '\nExternalContextMain(%r)\n' % (self._context.name,) + source += '\nExternalContextMain(%r, %r, %r)\n' %\ + (self._context.name, self._context.broker._listen_addr, + self._context.key) compressed = zlib.compress(source) preamble = str(len(compressed)) + '\n' + compressed - self._child.stdin.write(preamble) - self._child.stdin.flush() - - assert os.read(self._rfd, 3) == 'OK\n' - - def Disconnect(self): - super(LocalStream, self).Disconnect() - os.kill(self._child.pid, signal.SIGKILL) + sock.sendall(preamble) + assert os.read(self._fd, 3) == 'OK\n' class SSHStream(LocalStream): @@ -526,16 +539,20 @@ class SSHStream(LocalStream): class Context(object): ''' - Represents a remote context regardless of current connection method. + Represents a remote context regardless of connection method. ''' - def __init__(self, manager, name=None, hostname=None, username=None): - self.manager = manager + def __init__(self, broker, name=None, hostname=None, username=None, key=None, + parent_addr=None): + self.broker = broker self.name = name self.hostname = hostname self.username = username - self.tcp_port = None - self._stream = None + self.parent_addr = parent_addr + if key: + self.key = key + else: + self.key = file('/dev/urandom', 'rb').read(16).encode('hex') def GetStream(self): return self._stream @@ -597,40 +614,45 @@ class Context(object): return 'Context(%s)' % ', '.join(bits) -class ContextManager(object): +class Broker(object): ''' - Context manager: this is responsible for keeping track of contexts, any + Context broker: this is responsible for keeping track of contexts, any stream that is associated with them, and for I/O multiplexing. ''' def __init__(self): + self._dead = False self._poller = select.poll() self._poller_fd_map = {} - - self._contexts_lock = threading.Lock() + self._poller_lock = threading.Lock() self._contexts = {} - self._poller_changes_lock = threading.Lock() - self._poller_changes = {} - self._wake_rfd, self._wake_wfd = os.pipe() self._poller.register(self._wake_rfd) - self._thread = threading.Thread(target=self.Loop, name='ContextManager') + self._listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._listen_sock.bind(('0.0.0.0', 0)) # plz 2 allocate 4 me kthx. + self._listen_sock.listen(5) + self._listen_addr = self._listen_sock.getsockname() + self._poller.register(self._listen_sock) + + self._thread = threading.Thread(target=self.Loop, name='Broker') self._thread.setDaemon(True) self._thread.start() - self._dead = False def Register(self, context): ''' - Put a context under control of this manager. + Put a context under control of this broker. ''' - self._contexts_lock.acquire() + Log('%r.Register(%r)', self, context) + self._poller_lock.acquire() + os.write(self._wake_wfd, ' ') try: self._contexts[context.name] = context - self.UpdateStreamIOState(context.GetStream()) + self._poller.register(context.GetStream()) + self._poller_fd_map[context.GetStream().fileno()] = context.GetStream() finally: - self._contexts_lock.release() + self._poller_lock.release() return context def GetLocal(self, name): @@ -646,132 +668,83 @@ class ContextManager(object): context.SetStream(LocalStream(context)).Connect() return self.Register(context) - def GetRemote(self, hostname, username=None, name=None): + def GetRemote(self, hostname, username, name=None): ''' Return the named remote context, or create it if it doesn't exist. ''' - if username is None: - username = getpass.getuser() if name is None: name = 'econtext[%s@%s:%d]' %\ - (getpass.getuser(), os.getenv('HOSTNAME'), os.getpid()) + (username, os.getenv('HOSTNAME'), os.getpid()) context = Context(self, name, hostname, username) context.SetStream(SSHStream(context)).Connect() return self.Register(context) - def UpdateStreamIOState(self, stream): - ''' - Update the manager's internal state regarding the specified stream by - marking its FDs for polling as appropriate. - - Args: - stream: econtext.Stream - ''' - Log('%r.UpdateStreamIOState(%r)', self, stream) - - self._poller_changes_lock.acquire() - try: - self._poller_changes[stream] = None - finally: - self._poller_changes_lock.release() - os.write(self._wake_wfd, ' ') - - def _DoChangedStreams(self): - ''' - Walk the list of streams indicated as having an updated I/O state by - UpdateStreamIOState. Poller registration updates must be done in serial - with calls to its poll() method. - ''' - Log('%r._DoChangedStreams()', self) - - self._poller_changes_lock.acquire() - try: - changes = self._poller_changes.keys() - self._poller_changes = {} - finally: - self._poller_changes_lock.release() - - for stream in changes: - alive, ifd, ofd, has_output = stream.GetIOState() - - if not alive: - for fd in (ifd, ofd): - del self._poller_fd_map[fd] - del self._contexts[stream._context] - return - - if has_output: - self._poller.register(ofd, select.POLLOUT) - self._poller_fd_map[ofd] = stream - elif ofd in self._poller_fd_map: - self._poller.unregister(ofd) - del self._poller_fd_map[ofd] - - self._poller.register(ifd, select.POLLIN) - self._poller_fd_map[ifd] = stream - def Loop(self): ''' Handle stream events until Finalize() is called. ''' while not self._dead: - Log('%r.Loop(): %r', self, self._poller_fd_map) + Log('%r.Loop()', self) + self._poller_lock.acquire() + self._poller_lock.release() for fd, event in self._poller.poll(): if fd == self._wake_rfd: Log('%r: got event on wake_rfd=%d.', self, self._wake_rfd) os.read(self._wake_rfd, 1) - self._DoChangedStreams() - elif event & select.POLLHUP: - Log('%r: POLLHUP on %d; calling %r', self, fd, - self._poller_fd_map[fd].Disconnect) - self._poller_fd_map[fd].Disconnect() + break + elif fd == self._listen_sock.fileno(): + Stream.Accept(self, self._listen_sock.accept()) + continue + + obj = self._poller_fd_map[fd] + if event & select.POLLHUP: + Log('%r: POLLHUP on %r', self, obj) + obj.Disconnect() elif event & select.POLLIN: - Log('%r: POLLIN on %d; calling %r', self, fd, - self._poller_fd_map[fd].Receive) - self._poller_fd_map[fd].Receive() + Log('%r: POLLIN on %r', self, obj) + obj.Receive() elif event & select.POLLOUT: - Log('%r: POLLOUT on %d', self, fd) - Log('%r: POLLOUT on %d; calling %r', self, fd, - self._poller_fd_map[fd].Transmit) - self._poller_fd_map[fd].Transmit() + Log('%r: POLLOUT on %r', self, obj) + if not obj.Transmit(): # If no output buffered, unset POLLOUT. + self._poller.unregister(obj) + self._poller.register(obj, select.POLLIN) elif event & select.POLLNVAL: - Log('%r: POLLNVAL for %d, unregistering it.', self, fd) - self._poller.unregister(fd) + Log('%r: POLLNVAL for %r', self, obj) + obj.Disconnect() + self._poller.unregister(obj) def Finalize(self): ''' Tell all active streams to disconnect. ''' self._dead = True - self._contexts_lock.acquire() + self._poller_lock.acquire() try: for name, context in self._contexts.iteritems(): - stream = context.GetStream() - if stream: - stream.Disconnect() + context.GetStream().Disconnect() finally: - self._contexts_lock.release() + self._poller_lock.release() def __repr__(self): - return 'econtext.ContextManager()' % (self._contexts.keys(),) - + return 'econtext.Broker()' % (self._contexts.keys(),) -def ExternalContextMain(context_name): - Log('ExternalContextMain(%r)', context_name) - assert os.wait()[1] == 0, 'first stage did not exit cleanly.' +def ExternalContextMain(context_name, parent_addr, key): syslog.openlog('%s:%s' % (getpass.getuser(), context_name), syslog.LOG_PID) syslog.syslog('initializing (parent=%s)' % (os.getenv('SSH_CLIENT'),)) + Log('ExternalContextMain(%r, %r, %r)', context_name, parent_addr, key) + os.wait() # Reap the first stage. os.dup2(100, 0) os.close(100) - manager = ContextManager() - context = Context(manager, 'parent') + broker = Broker() + context = Context(broker, 'parent', parent_addr=parent_addr, key=key) - stream = context.SetStream(Stream.FromFDs(context, rfd=0, wfd=1)) - manager.Register(context) + stream = context.SetStream(Stream(context)) + stream.Connect() + broker.Register(context) for call_info in Channel(stream, CALL_FUNCTION): Log('ExternalContextMain(): CALL_FUNCTION %r', call_info)