econtext-20070515-1322

pull/35/head
David Wilson 12 years ago
parent 94b94fb838
commit b4bba0357a

16
a.py

@ -0,0 +1,16 @@
import os, socket


def CreateChild(*args):
    '''
    Create a child process whose stdin/stdout are connected to one end of a
    socketpair, then exec *args* in the child.

    Args:
        *args: argv for the child process; args[0] is the executable name.

    Returns:
        socket.socket: the parent's end of the socketpair.
    '''
    sock1, sock2 = socket.socketpair()
    pid = os.fork()
    # BUG FIX: the original tested `if os.fork():`, which is true in the
    # PARENT (fork() returns the child's PID there), so the parent exec'd
    # away and the child returned. The child (pid == 0) must do the exec.
    if not pid:
        # Child: wire sock2 up as both stdin (0) and stdout (1), close the
        # original descriptors, then replace this process image.
        for pair in ((0, sock1), (1, sock2)):
            os.dup2(sock2.fileno(), pair[0])
            os.close(pair[1].fileno())
        os.execvp(args[0], args)
        raise SystemExit  # only reached if execvp() fails
    return sock1

@ -9,13 +9,15 @@ import cPickle
import cStringIO import cStringIO
import commands import commands
import getpass import getpass
import hmac
import imp import imp
import inspect import inspect
import os import os
import select import select
import sha
import signal import signal
import socket
import struct import struct
import subprocess
import sys import sys
import syslog import syslog
import textwrap import textwrap
@ -67,6 +69,23 @@ def Log(fmt, *args):
(fmt%args).replace('econtext.', ''))) (fmt%args).replace('econtext.', '')))
def CreateChild(*args):
    '''
    Create a child process whose stdin/stdout are connected to one end of a
    socketpair, then exec *args* in the child.

    Args:
        *args: argv for the child process; args[0] is the executable name.

    Returns:
        (int, socket.socket): the child's PID and the parent's end of the
        socketpair.
    '''
    # NOTE(review): indentation was lost in this rendering of the hunk; the
    # structure below is restored from the statement sequence, which matches
    # the a.py version of this function earlier in the same commit.
    sock1, sock2 = socket.socketpair()
    pid = os.fork()
    if not pid:
        # Child: wire sock2 up as both stdin (0) and stdout (1), close the
        # original descriptors, then replace this process image.
        for pair in ((0, sock1), (1, sock2)):
            os.dup2(sock2.fileno(), pair[0])
            os.close(pair[1].fileno())
        os.execvp(args[0], args)
        raise SystemExit  # only reached if execvp() fails
    return pid, sock1
class PartialFunction(object): class PartialFunction(object):
def __init__(self, fn, *partial_args): def __init__(self, fn, *partial_args):
self.fn = fn self.fn = fn
@ -97,14 +116,17 @@ class Channel(object):
Args: Args:
# Has the Stream object lost its connection? # Has the Stream object lost its connection?
killed: bool killed: bool
# Has the remote Channel had Close() called? / the object passed to the data: (
# remote Send(). # Has the remote Channel had Close() called?
data: (bool, object) bool,
# The object passed to the remote Send()
object
)
''' '''
Log('%r._InternalReceive(%r, %r)', self, killed, data) Log('%r._InternalReceive(%r, %r)', self, killed, data)
self._queue_lock.acquire() self._queue_lock.acquire()
try: try:
self._queue.append((killed or data[0], data[1])) self._queue.append((killed or data[0], killed or data[1]))
self._wake_event.set() self._wake_event.set()
finally: finally:
self._queue_lock.release() self._queue_lock.release()
@ -155,7 +177,7 @@ class Channel(object):
def __iter__(self): def __iter__(self):
''' '''
Return an iterator that yields objects arriving on this channel, until the Return an iterator that yields objects arriving on this channel, until the
channel is closed. channel dies or is closed.
''' '''
while True: while True:
try: try:
@ -194,9 +216,14 @@ class SlaveModuleImporter(object):
def load_module(self, fullname): def load_module(self, fullname):
kind, data = self._context.EnqueueAwaitReply(GET_MODULE, fullname) kind, data = self._context.EnqueueAwaitReply(GET_MODULE, fullname)
def GetModule(cls, fullname): def GetModule(cls, killed, fullname):
Log('%r.GetModule(%r, %r)', cls, killed, fullname)
if killed:
return
if fullname in sys.modules: if fullname in sys.modules:
pass pass
GetModule = classmethod(GetModule)
# #
# Stream implementations. # Stream implementations.
@ -205,11 +232,12 @@ class SlaveModuleImporter(object):
class Stream(object): class Stream(object):
def __init__(self, context): def __init__(self, context):
self._context = context self._context = context
self._alive = True
self._input_buf = self._output_buf = '' self._input_buf = self._output_buf = ''
self._input_buf_lock = threading.Lock() self._input_buf_lock = threading.Lock()
self._output_buf_lock = threading.Lock() self._output_buf_lock = threading.Lock()
self._rhmac = hmac.new(context.key, digestmod=sha.new)
self._whmac = self._rhmac.copy()
self._last_handle = 0 self._last_handle = 0
self._handle_map = {} self._handle_map = {}
@ -226,8 +254,6 @@ class Stream(object):
self._unpickler = cPickle.Unpickler(self._unpickler_file) self._unpickler = cPickle.Unpickler(self._unpickler_file)
self._unpickler.persistent_load = self._LoadFunctionFromPerID self._unpickler.persistent_load = self._LoadFunctionFromPerID
# Pickler/Unpickler support.
def Pickle(self, obj): def Pickle(self, obj):
self._pickler.dump(obj) self._pickler.dump(obj)
data = self._pickler_file.getvalue() data = self._pickler_file.getvalue()
@ -275,8 +301,6 @@ class Stream(object):
raise CorruptMessageError('unrecognized persistent ID received: %r', pid) raise CorruptMessageError('unrecognized persistent ID received: %r', pid)
return PartialFunction(self._CallPersistentWhatsit, pid) return PartialFunction(self._CallPersistentWhatsit, pid)
# I/O.
def AllocHandle(self): def AllocHandle(self):
''' '''
Allocate a unique handle for this stream. Allocate a unique handle for this stream.
@ -314,21 +338,25 @@ class Stream(object):
or IOError on failure. or IOError on failure.
''' '''
Log('%r.Receive()', self) Log('%r.Receive()', self)
chunk = os.read(self._rfd, 4096)
if not chunk:
raise StreamError('remote side hung up.')
self._input_buf += chunk self._input_buf += os.read(self._rfd, 4096)
buffer_len = len(self._input_buf) if len(self._input_buf) < 24:
if buffer_len < 4:
return return
msg_len = struct.unpack('>L', self._input_buf[:4])[0] msg_mac = self._input_buf[:20]
if buffer_len < msg_len-4: msg_len = struct.unpack('>L', self._input_buf[20:24])[0]
if len(self._input_buf) < msg_len-24:
return return
self._rhmac.update(self._input_buf[20:msg_len+24])
expected_mac = self._rhmac.digest()
if msg_mac != expected_mac:
raise CorruptMessageError('%r got invalid MAC: expected %r, got %r',
self, msg_mac.encode('hex'),
expected_mac.encode('hex'))
try: try:
handle, data = self.Unpickle(self._input_buf[4:msg_len+4]) handle, data = self.Unpickle(self._input_buf[24:msg_len+24])
self._input_buf = self._input_buf[msg_len+4:] self._input_buf = self._input_buf[msg_len+4:]
handle = long(handle) handle = long(handle)
@ -345,60 +373,63 @@ class Stream(object):
def Transmit(self): def Transmit(self):
''' '''
Transmit pending messages. Raises IOError on failure. Transmit pending messages. Raises IOError on failure. Return value
indicates whether there is still data buffered.
Returns:
bool
''' '''
Log('%r.Transmit()', self) Log('%r.Transmit()', self)
written = os.write(self._wfd, self._output_buf[:4096]) written = os.write(self._fd, self._output_buf[:4096])
self._output_buf = self._output_buf[written:] self._output_buf = self._output_buf[written:]
if self._context and not self._output_buf: return bool(self._output_buf)
self._context.manager.UpdateStreamIOState(self)
def Enqueue(self, handle, data):
    '''
    Serialize (handle, data), frame it with a length prefix and a running
    HMAC, and append the frame to this stream's output buffer, then register
    the context with the broker so the buffer gets flushed.

    Args:
        handle: destination handle on the remote side.
        data: picklable object delivered to the remote callback.
    '''
    # NOTE(review): indentation was lost in this rendering of the hunk; the
    # structure below is restored from the statement sequence (lock/try/
    # finally). Tokens are otherwise unchanged.
    Log('%r.Enqueue(%r, %r)', self, handle, data)
    self._output_buf_lock.acquire()
    try:
        encoded = self.Pickle((handle, data))
        # Frame layout matches Receive(): HMAC digest (SHA-1, 20 bytes per
        # the stream's hmac/sha setup), 4-byte big-endian length, payload.
        msg = struct.pack('>L', len(encoded)) + encoded
        self._whmac.update(msg)
        self._output_buf += self._whmac.digest() + msg
    finally:
        self._output_buf_lock.release()
    # Wake the broker so it starts watching this stream for writability.
    self._context.broker.Register(self._context)
def Disconnect(self): def Disconnect(self):
''' '''
Called to handle disconnects. Called to handle disconnects.
''' '''
Log('%r.Disconnect()', self) Log('%r.Disconnect()', self)
try:
os.close(self._fd)
except OSError, e:
Log('WARNING: %s', e)
for fd in (self._rfd, self._wfd):
os.close(fd)
# Invoke each registered callback to indicate the connection has been
# destroyed. This prevents pending Channels/RPCs from hanging forever.
for handle, (persist, fn) in self._handle_map.iteritems(): for handle, (persist, fn) in self._handle_map.iteritems():
Log('%r.Disconnect(): killing stale callback handle=%r; fn=%r', Log('%r.Disconnect(): killing stale callback handle=%r; fn=%r',
self, handle, fn) self, handle, fn)
fn(True, None) fn(True, None)
self._context.manager.UpdateStreamIOState(self) @classmethod
def Accept(cls, broker, sock):
context = Context(broker)
stream = cls(context)
context.SetStream(stream)
broker.Register(context)
def GetIOState(self): def Connect(self):
''' '''
Return a 3-tuple describing the instance's I/O state. Connect to a Broker at the address given in the Context instance.
Returns:
(alive, input_fd, output_fd, has_output_buffered)
''' '''
return self._alive, self._rfd, self._wfd, bool(self._output_buf) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._fd = sock.fileno()
def Enqueue(self, handle, data): sock.connect(self._context.parent_addr)
Log('%r.Enqueue(%r, %r)', self, handle, data) self.Enqueue(0, self._context.name)
self._output_buf_lock.acquire() def fileno(self):
try: return self._fd
encoded = self.Pickle((handle, data))
self._output_buf += struct.pack('>L', len(encoded)) + encoded
finally:
self._output_buf_lock.release()
self._context.manager.UpdateStreamIOState(self)
# Misc.
@classmethod
def FromFDs(cls, context, rfd, wfd):
Log('%r.FromFDs(%r, %r, %r)', cls, context, rfd, wfd)
self = cls(context)
self._rfd, self._wfd = rfd, wfd
return self
def __repr__(self): def __repr__(self):
return 'econtext.%s(<context=%r>)' %\ return 'econtext.%s(<context=%r>)' %\
@ -415,15 +446,11 @@ class LocalStream(Stream):
lambda self, path: setattr(self, '_python_path', path), lambda self, path: setattr(self, '_python_path', path),
doc='The path to the remote Python interpreter.') doc='The path to the remote Python interpreter.')
def _GetModuleSource(self, killed, name):
if not killed:
return inspect.getsource(sys.modules[name])
def __init__(self, context): def __init__(self, context):
super(LocalStream, self).__init__(context) super(LocalStream, self).__init__(context)
self._permitted_modules = {} self._permitted_modules = {}
self._unpickler.find_global = self._FindGlobal self._unpickler.find_global = self._FindGlobal
self.AddHandleCB(self._GetModuleSource, handle=GET_MODULE) self.AddHandleCB(SlaveModuleImporter.GetModule, handle=GET_MODULE)
def _FindGlobal(self, module_name, class_name): def _FindGlobal(self, module_name, class_name):
''' '''
@ -457,8 +484,7 @@ class LocalStream(Stream):
def _FirstStage(): def _FirstStage():
import os,sys,zlib import os,sys,zlib
R,W=os.pipe() R,W=os.pipe()
pid=os.fork() if os.fork():
if pid:
os.dup2(0,100) os.dup2(0,100)
os.dup2(R,0) os.dup2(R,0)
os.close(R) os.close(R)
@ -480,34 +506,21 @@ class LocalStream(Stream):
def __repr__(self): def __repr__(self):
return '%s(%s)' % (self.__class__.__name__, self._context) return '%s(%s)' % (self.__class__.__name__, self._context)
# Public.
@classmethod
def Accept(cls, fd):
raise NotImplemented
def Connect(self): def Connect(self):
Log('%r.Connect()', self) Log('%r.Connect()', self)
self._child = subprocess.Popen(self.GetBootCommand(), stdin=subprocess.PIPE, pid, sock = CreateChild(*self.GetBootCommand())
stdout=subprocess.PIPE) self._fd = sock.fileno()
self._wfd = self._child.stdin.fileno() Log('%r.Connect(): chlid process stdin/stdout=%r', self, self._fd)
self._rfd = self._child.stdout.fileno()
Log('%r.Connect(): chlid process stdin=%r, stdout=%r',
self, self._wfd, self._rfd)
source = inspect.getsource(sys.modules[__name__]) source = inspect.getsource(sys.modules[__name__])
source += '\nExternalContextMain(%r)\n' % (self._context.name,) source += '\nExternalContextMain(%r, %r, %r)\n' %\
(self._context.name, self._context.broker._listen_addr,
self._context.key)
compressed = zlib.compress(source) compressed = zlib.compress(source)
preamble = str(len(compressed)) + '\n' + compressed preamble = str(len(compressed)) + '\n' + compressed
self._child.stdin.write(preamble) sock.sendall(preamble)
self._child.stdin.flush() assert os.read(self._fd, 3) == 'OK\n'
assert os.read(self._rfd, 3) == 'OK\n'
def Disconnect(self):
super(LocalStream, self).Disconnect()
os.kill(self._child.pid, signal.SIGKILL)
class SSHStream(LocalStream): class SSHStream(LocalStream):
@ -526,16 +539,20 @@ class SSHStream(LocalStream):
class Context(object): class Context(object):
''' '''
Represents a remote context regardless of current connection method. Represents a remote context regardless of connection method.
''' '''
def __init__(self, manager, name=None, hostname=None, username=None): def __init__(self, broker, name=None, hostname=None, username=None, key=None,
self.manager = manager parent_addr=None):
self.broker = broker
self.name = name self.name = name
self.hostname = hostname self.hostname = hostname
self.username = username self.username = username
self.tcp_port = None self.parent_addr = parent_addr
self._stream = None if key:
self.key = key
else:
self.key = file('/dev/urandom', 'rb').read(16).encode('hex')
def GetStream(self): def GetStream(self):
return self._stream return self._stream
@ -597,40 +614,45 @@ class Context(object):
return 'Context(%s)' % ', '.join(bits) return 'Context(%s)' % ', '.join(bits)
class ContextManager(object): class Broker(object):
''' '''
Context manager: this is responsible for keeping track of contexts, any Context broker: this is responsible for keeping track of contexts, any
stream that is associated with them, and for I/O multiplexing. stream that is associated with them, and for I/O multiplexing.
''' '''
def __init__(self): def __init__(self):
self._dead = False
self._poller = select.poll() self._poller = select.poll()
self._poller_fd_map = {} self._poller_fd_map = {}
self._poller_lock = threading.Lock()
self._contexts_lock = threading.Lock()
self._contexts = {} self._contexts = {}
self._poller_changes_lock = threading.Lock()
self._poller_changes = {}
self._wake_rfd, self._wake_wfd = os.pipe() self._wake_rfd, self._wake_wfd = os.pipe()
self._poller.register(self._wake_rfd) self._poller.register(self._wake_rfd)
self._thread = threading.Thread(target=self.Loop, name='ContextManager') self._listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._listen_sock.bind(('0.0.0.0', 0)) # plz 2 allocate 4 me kthx.
self._listen_sock.listen(5)
self._listen_addr = self._listen_sock.getsockname()
self._poller.register(self._listen_sock)
self._thread = threading.Thread(target=self.Loop, name='Broker')
self._thread.setDaemon(True) self._thread.setDaemon(True)
self._thread.start() self._thread.start()
self._dead = False
def Register(self, context): def Register(self, context):
''' '''
Put a context under control of this manager. Put a context under control of this broker.
''' '''
self._contexts_lock.acquire() Log('%r.Register(%r)', self, context)
self._poller_lock.acquire()
os.write(self._wake_wfd, ' ')
try: try:
self._contexts[context.name] = context self._contexts[context.name] = context
self.UpdateStreamIOState(context.GetStream()) self._poller.register(context.GetStream())
self._poller_fd_map[context.GetStream().fileno()] = context.GetStream()
finally: finally:
self._contexts_lock.release() self._poller_lock.release()
return context return context
def GetLocal(self, name): def GetLocal(self, name):
@ -646,132 +668,83 @@ class ContextManager(object):
context.SetStream(LocalStream(context)).Connect() context.SetStream(LocalStream(context)).Connect()
return self.Register(context) return self.Register(context)
def GetRemote(self, hostname, username=None, name=None): def GetRemote(self, hostname, username, name=None):
''' '''
Return the named remote context, or create it if it doesn't exist. Return the named remote context, or create it if it doesn't exist.
''' '''
if username is None:
username = getpass.getuser()
if name is None: if name is None:
name = 'econtext[%s@%s:%d]' %\ name = 'econtext[%s@%s:%d]' %\
(getpass.getuser(), os.getenv('HOSTNAME'), os.getpid()) (username, os.getenv('HOSTNAME'), os.getpid())
context = Context(self, name, hostname, username) context = Context(self, name, hostname, username)
context.SetStream(SSHStream(context)).Connect() context.SetStream(SSHStream(context)).Connect()
return self.Register(context) return self.Register(context)
def UpdateStreamIOState(self, stream):
'''
Update the manager's internal state regarding the specified stream by
marking its FDs for polling as appropriate.
Args:
stream: econtext.Stream
'''
Log('%r.UpdateStreamIOState(%r)', self, stream)
self._poller_changes_lock.acquire()
try:
self._poller_changes[stream] = None
finally:
self._poller_changes_lock.release()
os.write(self._wake_wfd, ' ')
def _DoChangedStreams(self):
'''
Walk the list of streams indicated as having an updated I/O state by
UpdateStreamIOState. Poller registration updates must be done in serial
with calls to its poll() method.
'''
Log('%r._DoChangedStreams()', self)
self._poller_changes_lock.acquire()
try:
changes = self._poller_changes.keys()
self._poller_changes = {}
finally:
self._poller_changes_lock.release()
for stream in changes:
alive, ifd, ofd, has_output = stream.GetIOState()
if not alive:
for fd in (ifd, ofd):
del self._poller_fd_map[fd]
del self._contexts[stream._context]
return
if has_output:
self._poller.register(ofd, select.POLLOUT)
self._poller_fd_map[ofd] = stream
elif ofd in self._poller_fd_map:
self._poller.unregister(ofd)
del self._poller_fd_map[ofd]
self._poller.register(ifd, select.POLLIN)
self._poller_fd_map[ifd] = stream
def Loop(self): def Loop(self):
''' '''
Handle stream events until Finalize() is called. Handle stream events until Finalize() is called.
''' '''
while not self._dead: while not self._dead:
Log('%r.Loop(): %r', self, self._poller_fd_map) Log('%r.Loop()', self)
self._poller_lock.acquire()
self._poller_lock.release()
for fd, event in self._poller.poll(): for fd, event in self._poller.poll():
if fd == self._wake_rfd: if fd == self._wake_rfd:
Log('%r: got event on wake_rfd=%d.', self, self._wake_rfd) Log('%r: got event on wake_rfd=%d.', self, self._wake_rfd)
os.read(self._wake_rfd, 1) os.read(self._wake_rfd, 1)
self._DoChangedStreams() break
elif event & select.POLLHUP: elif fd == self._listen_sock.fileno():
Log('%r: POLLHUP on %d; calling %r', self, fd, Stream.Accept(self, self._listen_sock.accept())
self._poller_fd_map[fd].Disconnect) continue
self._poller_fd_map[fd].Disconnect()
obj = self._poller_fd_map[fd]
if event & select.POLLHUP:
Log('%r: POLLHUP on %r', self, obj)
obj.Disconnect()
elif event & select.POLLIN: elif event & select.POLLIN:
Log('%r: POLLIN on %d; calling %r', self, fd, Log('%r: POLLIN on %r', self, obj)
self._poller_fd_map[fd].Receive) obj.Receive()
self._poller_fd_map[fd].Receive()
elif event & select.POLLOUT: elif event & select.POLLOUT:
Log('%r: POLLOUT on %d', self, fd) Log('%r: POLLOUT on %r', self, obj)
Log('%r: POLLOUT on %d; calling %r', self, fd, if not obj.Transmit(): # If no output buffered, unset POLLOUT.
self._poller_fd_map[fd].Transmit) self._poller.unregister(obj)
self._poller_fd_map[fd].Transmit() self._poller.register(obj, select.POLLIN)
elif event & select.POLLNVAL: elif event & select.POLLNVAL:
Log('%r: POLLNVAL for %d, unregistering it.', self, fd) Log('%r: POLLNVAL for %r', self, obj)
self._poller.unregister(fd) obj.Disconnect()
self._poller.unregister(obj)
def Finalize(self): def Finalize(self):
''' '''
Tell all active streams to disconnect. Tell all active streams to disconnect.
''' '''
self._dead = True self._dead = True
self._contexts_lock.acquire() self._poller_lock.acquire()
try: try:
for name, context in self._contexts.iteritems(): for name, context in self._contexts.iteritems():
stream = context.GetStream() context.GetStream().Disconnect()
if stream:
stream.Disconnect()
finally: finally:
self._contexts_lock.release() self._poller_lock.release()
def __repr__(self): def __repr__(self):
return 'econtext.ContextManager(<contexts=%s>)' % (self._contexts.keys(),) return 'econtext.Broker(<contexts=%s>)' % (self._contexts.keys(),)
def ExternalContextMain(context_name):
Log('ExternalContextMain(%r)', context_name)
assert os.wait()[1] == 0, 'first stage did not exit cleanly.'
def ExternalContextMain(context_name, parent_addr, key):
syslog.openlog('%s:%s' % (getpass.getuser(), context_name), syslog.LOG_PID) syslog.openlog('%s:%s' % (getpass.getuser(), context_name), syslog.LOG_PID)
syslog.syslog('initializing (parent=%s)' % (os.getenv('SSH_CLIENT'),)) syslog.syslog('initializing (parent=%s)' % (os.getenv('SSH_CLIENT'),))
Log('ExternalContextMain(%r, %r, %r)', context_name, parent_addr, key)
os.wait() # Reap the first stage.
os.dup2(100, 0) os.dup2(100, 0)
os.close(100) os.close(100)
manager = ContextManager() broker = Broker()
context = Context(manager, 'parent') context = Context(broker, 'parent', parent_addr=parent_addr, key=key)
stream = context.SetStream(Stream.FromFDs(context, rfd=0, wfd=1)) stream = context.SetStream(Stream(context))
manager.Register(context) stream.Connect()
broker.Register(context)
for call_info in Channel(stream, CALL_FUNCTION): for call_info in Channel(stream, CALL_FUNCTION):
Log('ExternalContextMain(): CALL_FUNCTION %r', call_info) Log('ExternalContextMain(): CALL_FUNCTION %r', call_info)

Loading…
Cancel
Save