diff --git a/econtext.py b/econtext.py index 9733f792..7bca7a2a 100644 --- a/econtext.py +++ b/econtext.py @@ -13,6 +13,7 @@ import hmac import imp import inspect import os +import random import select import sha import signal @@ -36,6 +37,8 @@ CALL_FUNCTION = 1L DEBUG = True +import sys +sys.stderr = open('milf1', 'w', 1) # # Exceptions. @@ -69,6 +72,16 @@ def Log(fmt, *args): (fmt % args).replace('econtext.', ''))) +def write_all(fd, s): + written = 0 + while written < len(s): + rc = os.write(fd, buffer(s, written)) + if not rc: + raise IOError('short write') + written += rc + return written + + def CreateChild(*args): ''' Create a child process whose stdin/stdout is connected to a socket. @@ -79,15 +92,21 @@ def CreateChild(*args): Returns: pid, sock ''' - sock1, sock2 = socket.socketpair() + parentfp, childfp = socket.socketpair() pid = os.fork() if not pid: - for pair in ((0, sock1), (1, sock2)): - os.dup2(sock2.fileno(), pair[0]) - os.close(pair[1].fileno()) + os.dup2(childfp.fileno(), 0) + os.dup2(childfp.fileno(), 1) + sys.stderr = open('milf2', 'w', 1) + childfp.close() + parentfp.close() os.execvp(args[0], args) raise SystemExit - return pid, sock1 + + childfp.close() + Log('CreateChild() child %d fd %d, parent %d, args %r', + pid, parentfp.fileno(), os.getpid(), args) + return pid, parentfp class PartialFunction(object): @@ -364,7 +383,7 @@ class Stream(object): ''' Log('%r.Receive()', self) - self._input_buf += os.read(self._rfd, 4096) + self._input_buf += os.read(self._fd, 4096) if len(self._input_buf) < 24: return @@ -454,8 +473,10 @@ class Stream(object): ''' stream = cls(context) - context.SetStream() - broker.Register(context) + stream.sock = sock + stream._fd = sock.fileno() + context.SetStream(stream) + context.broker.Register(context) def Connect(self): ''' @@ -540,8 +561,8 @@ class LocalStream(Stream): source = textwrap.dedent('\n'.join(source.strip().split('\n')[1:])) source = source.replace(' ', '\t') source = source.replace('CONTEXT_NAME', repr(self._context.name)) - return [ self.python_path, '-c', - 'exec "%s".decode("hex")' % (source.encode('hex'),) ] + return [self.python_path, '-c', + 'exec "%s".decode("hex")' % (source.encode('hex'),)] def __repr__(self): return '%s(%s)' % (self.__class__.__name__, self._context) @@ -549,7 +570,8 @@ class LocalStream(Stream): def Connect(self): Log('%r.Connect()', self) pid, sock = CreateChild(*self.GetBootCommand()) - self._fd = sock.fileno() + self._fd = os.dup(sock.fileno()) + sock.close() Log('%r.Connect(): child process stdin/stdout=%r', self, self._fd) source = inspect.getsource(sys.modules[__name__]) @@ -559,7 +581,7 @@ class LocalStream(Stream): compressed = zlib.compress(source) preamble = str(len(compressed)) + '\n' + compressed - sock.sendall(preamble) + write_all(self._fd, preamble) assert os.read(self._fd, 3) == 'OK\n' @@ -589,10 +611,7 @@ class Context(object): self.hostname = hostname self.username = username self.parent_addr = parent_addr - if key: - self.key = key - else: - self.key = file('/dev/urandom', 'rb').read(16).encode('hex') + self.key = key or ('%016x' % random.getrandbits(128)) def GetStream(self): return self._stream @@ -749,8 +768,9 @@ class Broker(object): os.read(self._wake_rfd, 1) continue elif self._listen_sock and fd == self._listen_sock.fileno(): - context = Context(broker) - Stream.Accept(context, self._listen_sock.accept()) + context = Context(self) + sock, addr = self._listen_sock.accept() + Stream.Accept(context, sock) continue obj = self._poller_fd_map[fd]