pull/35/head
David Wilson 11 years ago
parent 6db7f23e35
commit 9d0c2139d0

@ -13,6 +13,7 @@ import hmac
import imp import imp
import inspect import inspect
import os import os
import random
import select import select
import sha import sha
import signal import signal
@ -36,6 +37,8 @@ CALL_FUNCTION = 1L
DEBUG = True DEBUG = True
import sys
sys.stderr = open('milf1', 'w', 1)
# #
# Exceptions. # Exceptions.
@ -69,6 +72,16 @@ def Log(fmt, *args):
(fmt % args).replace('econtext.', ''))) (fmt % args).replace('econtext.', '')))
def write_all(fd, s):
written = 0
while written < len(s):
rc = os.write(fd, buffer(s, written))
if not rc:
raise IOError('short write')
written += rc
return written
def CreateChild(*args): def CreateChild(*args):
''' '''
Create a child process whose stdin/stdout is connected to a socket. Create a child process whose stdin/stdout is connected to a socket.
@ -79,15 +92,21 @@ def CreateChild(*args):
Returns: Returns:
pid, sock pid, sock
''' '''
sock1, sock2 = socket.socketpair() parentfp, childfp = socket.socketpair()
pid = os.fork() pid = os.fork()
if not pid: if not pid:
for pair in ((0, sock1), (1, sock2)): os.dup2(childfp.fileno(), 0)
os.dup2(sock2.fileno(), pair[0]) os.dup2(childfp.fileno(), 1)
os.close(pair[1].fileno()) sys.stderr = open('milf2', 'w', 1)
childfp.close()
parentfp.close()
os.execvp(args[0], args) os.execvp(args[0], args)
raise SystemExit raise SystemExit
return pid, sock1
childfp.close()
Log('CreateChild() child %d fd %d, parent %d, args %r',
pid, parentfp.fileno(), os.getpid(), args)
return pid, parentfp
class PartialFunction(object): class PartialFunction(object):
@ -364,7 +383,7 @@ class Stream(object):
''' '''
Log('%r.Receive()', self) Log('%r.Receive()', self)
self._input_buf += os.read(self._rfd, 4096) self._input_buf += os.read(self._fd, 4096)
if len(self._input_buf) < 24: if len(self._input_buf) < 24:
return return
@ -454,8 +473,10 @@ class Stream(object):
''' '''
stream = cls(context) stream = cls(context)
context.SetStream() stream.sock = sock
broker.Register(context) stream._fd = sock.fileno()
context.SetStream(stream)
context.broker.Register(context)
def Connect(self): def Connect(self):
''' '''
@ -540,8 +561,8 @@ class LocalStream(Stream):
source = textwrap.dedent('\n'.join(source.strip().split('\n')[1:])) source = textwrap.dedent('\n'.join(source.strip().split('\n')[1:]))
source = source.replace(' ', '\t') source = source.replace(' ', '\t')
source = source.replace('CONTEXT_NAME', repr(self._context.name)) source = source.replace('CONTEXT_NAME', repr(self._context.name))
return [ self.python_path, '-c', return [self.python_path, '-c',
'exec "%s".decode("hex")' % (source.encode('hex'),) ] 'exec "%s".decode("hex")' % (source.encode('hex'),)]
def __repr__(self): def __repr__(self):
return '%s(%s)' % (self.__class__.__name__, self._context) return '%s(%s)' % (self.__class__.__name__, self._context)
@ -549,7 +570,8 @@ class LocalStream(Stream):
def Connect(self): def Connect(self):
Log('%r.Connect()', self) Log('%r.Connect()', self)
pid, sock = CreateChild(*self.GetBootCommand()) pid, sock = CreateChild(*self.GetBootCommand())
self._fd = sock.fileno() self._fd = os.dup(sock.fileno())
sock.close()
Log('%r.Connect(): child process stdin/stdout=%r', self, self._fd) Log('%r.Connect(): child process stdin/stdout=%r', self, self._fd)
source = inspect.getsource(sys.modules[__name__]) source = inspect.getsource(sys.modules[__name__])
@ -559,7 +581,7 @@ class LocalStream(Stream):
compressed = zlib.compress(source) compressed = zlib.compress(source)
preamble = str(len(compressed)) + '\n' + compressed preamble = str(len(compressed)) + '\n' + compressed
sock.sendall(preamble) write_all(self._fd, preamble)
assert os.read(self._fd, 3) == 'OK\n' assert os.read(self._fd, 3) == 'OK\n'
@ -589,10 +611,7 @@ class Context(object):
self.hostname = hostname self.hostname = hostname
self.username = username self.username = username
self.parent_addr = parent_addr self.parent_addr = parent_addr
if key: self.key = key or ('%016x' % random.getrandbits(128))
self.key = key
else:
self.key = file('/dev/urandom', 'rb').read(16).encode('hex')
def GetStream(self): def GetStream(self):
return self._stream return self._stream
@ -749,8 +768,9 @@ class Broker(object):
os.read(self._wake_rfd, 1) os.read(self._wake_rfd, 1)
continue continue
elif self._listen_sock and fd == self._listen_sock.fileno(): elif self._listen_sock and fd == self._listen_sock.fileno():
context = Context(broker) context = Context(self)
Stream.Accept(context, self._listen_sock.accept()) sock, addr = self._listen_sock.accept()
Stream.Accept(context, sock)
continue continue
obj = self._poller_fd_map[fd] obj = self._poller_fd_map[fd]

Loading…
Cancel
Save