econtext-20070515-1322

pull/35/head
David Wilson 11 years ago
parent 94b94fb838
commit b4bba0357a

16
a.py

@ -0,0 +1,16 @@
import os, socket
def CreateChild(*args):
    '''
    Create a child process whose stdin/stdout is connected to a socketpair.

    Args:
        *args: argv for the child; args[0] is the program, looked up on $PATH.

    Returns:
        socket.socket: the parent's end of the socketpair.
    '''
    sock1, sock2 = socket.socketpair()
    # os.fork() returns 0 in the child. The CHILD must be the process that
    # redirects its fds and execs; the original test was inverted, causing
    # the parent to exec and the child to return.
    if not os.fork():
        for pair in ((0, sock1), (1, sock2)):
            # Point both stdin (0) and stdout (1) at the child's end of the
            # pair (sock2), then drop the original descriptors.
            os.dup2(sock2.fileno(), pair[0])
            os.close(pair[1].fileno())
        os.execvp(args[0], args)
        raise SystemExit  # Only reached if execvp() fails.
    return sock1

@ -9,13 +9,15 @@ import cPickle
import cStringIO
import commands
import getpass
import hmac
import imp
import inspect
import os
import select
import sha
import signal
import socket
import struct
import subprocess
import sys
import syslog
import textwrap
@ -67,6 +69,23 @@ def Log(fmt, *args):
(fmt%args).replace('econtext.', '')))
def CreateChild(*args):
    '''
    Fork and exec a child whose stdin/stdout are wired to one end of a
    socketpair.

    Args:
        *args: argv for the child; args[0] is the program, looked up on $PATH.

    Returns:
        (int, socket.socket): the child's pid and the parent's socket end.
    '''
    parent_end, child_end = socket.socketpair()
    child_pid = os.fork()
    if child_pid == 0:
        # In the child: make both stdin (0) and stdout (1) refer to our end
        # of the pair, then close the original descriptors before exec.
        os.dup2(child_end.fileno(), 0)
        os.close(parent_end.fileno())
        os.dup2(child_end.fileno(), 1)
        os.close(child_end.fileno())
        os.execvp(args[0], args)
        raise SystemExit  # Only reached if execvp() fails.
    return child_pid, parent_end
class PartialFunction(object):
def __init__(self, fn, *partial_args):
self.fn = fn
@ -97,14 +116,17 @@ class Channel(object):
Args:
# Has the Stream object lost its connection?
killed: bool
# Has the remote Channel had Close() called? / the object passed to the
# remote Send().
data: (bool, object)
data: (
# Has the remote Channel had Close() called?
bool,
# The object passed to the remote Send()
object
)
'''
Log('%r._InternalReceive(%r, %r)', self, killed, data)
self._queue_lock.acquire()
try:
self._queue.append((killed or data[0], data[1]))
self._queue.append((killed or data[0], killed or data[1]))
self._wake_event.set()
finally:
self._queue_lock.release()
@ -155,7 +177,7 @@ class Channel(object):
def __iter__(self):
'''
Return an iterator that yields objects arriving on this channel, until the
channel is closed.
channel dies or is closed.
'''
while True:
try:
@ -194,9 +216,14 @@ class SlaveModuleImporter(object):
def load_module(self, fullname):
kind, data = self._context.EnqueueAwaitReply(GET_MODULE, fullname)
def GetModule(cls, fullname):
def GetModule(cls, killed, fullname):
Log('%r.GetModule(%r, %r)', cls, killed, fullname)
if killed:
return
if fullname in sys.modules:
pass
GetModule = classmethod(GetModule)
#
# Stream implementations.
@ -205,11 +232,12 @@ class SlaveModuleImporter(object):
class Stream(object):
def __init__(self, context):
self._context = context
self._alive = True
self._input_buf = self._output_buf = ''
self._input_buf_lock = threading.Lock()
self._output_buf_lock = threading.Lock()
self._rhmac = hmac.new(context.key, digestmod=sha.new)
self._whmac = self._rhmac.copy()
self._last_handle = 0
self._handle_map = {}
@ -226,8 +254,6 @@ class Stream(object):
self._unpickler = cPickle.Unpickler(self._unpickler_file)
self._unpickler.persistent_load = self._LoadFunctionFromPerID
# Pickler/Unpickler support.
def Pickle(self, obj):
self._pickler.dump(obj)
data = self._pickler_file.getvalue()
@ -275,8 +301,6 @@ class Stream(object):
raise CorruptMessageError('unrecognized persistent ID received: %r', pid)
return PartialFunction(self._CallPersistentWhatsit, pid)
# I/O.
def AllocHandle(self):
'''
Allocate a unique handle for this stream.
@ -314,21 +338,25 @@ class Stream(object):
or IOError on failure.
'''
Log('%r.Receive()', self)
chunk = os.read(self._rfd, 4096)
if not chunk:
raise StreamError('remote side hung up.')
self._input_buf += chunk
buffer_len = len(self._input_buf)
if buffer_len < 4:
self._input_buf += os.read(self._rfd, 4096)
if len(self._input_buf) < 24:
return
msg_len = struct.unpack('>L', self._input_buf[:4])[0]
if buffer_len < msg_len-4:
msg_mac = self._input_buf[:20]
msg_len = struct.unpack('>L', self._input_buf[20:24])[0]
if len(self._input_buf) < msg_len-24:
return
self._rhmac.update(self._input_buf[20:msg_len+24])
expected_mac = self._rhmac.digest()
if msg_mac != expected_mac:
raise CorruptMessageError('%r got invalid MAC: expected %r, got %r',
self, msg_mac.encode('hex'),
expected_mac.encode('hex'))
try:
handle, data = self.Unpickle(self._input_buf[4:msg_len+4])
handle, data = self.Unpickle(self._input_buf[24:msg_len+24])
self._input_buf = self._input_buf[msg_len+4:]
handle = long(handle)
@ -345,60 +373,63 @@ class Stream(object):
def Transmit(self):
'''
Transmit pending messages. Raises IOError on failure.
Transmit pending messages. Raises IOError on failure. Return value
indicates whether there is still data buffered.
Returns:
bool
'''
Log('%r.Transmit()', self)
written = os.write(self._wfd, self._output_buf[:4096])
written = os.write(self._fd, self._output_buf[:4096])
self._output_buf = self._output_buf[written:]
if self._context and not self._output_buf:
self._context.manager.UpdateStreamIOState(self)
return bool(self._output_buf)
def Enqueue(self, handle, data):
    '''
    Serialize (handle, data) into a framed, MAC-chained message appended to
    the output buffer, then re-register the context with the broker so the
    buffered data gets flushed.

    Args:
        # Identifies the callback on the remote stream that receives data.
        handle: object
        # Payload; pickled and sent to the remote handler.
        data: object
    '''
    Log('%r.Enqueue(%r, %r)', self, handle, data)
    self._output_buf_lock.acquire()
    try:
        encoded = self.Pickle((handle, data))
        # Frame: 4-byte big-endian payload length, then the pickled payload.
        msg = struct.pack('>L', len(encoded)) + encoded
        # The write-side HMAC is rolling: it accumulates every framed
        # message in order, and its current digest prefixes each frame so
        # the receiver can authenticate the whole stream so far.
        self._whmac.update(msg)
        self._output_buf += self._whmac.digest() + msg
    finally:
        self._output_buf_lock.release()
    self._context.broker.Register(self._context)
def Disconnect(self):
'''
Called to handle disconnects.
'''
Log('%r.Disconnect()', self)
try:
os.close(self._fd)
except OSError, e:
Log('WARNING: %s', e)
for fd in (self._rfd, self._wfd):
os.close(fd)
# Invoke each registered callback to indicate the connection has been
# destroyed. This prevents pending Channels/RPCs from hanging forever.
for handle, (persist, fn) in self._handle_map.iteritems():
Log('%r.Disconnect(): killing stale callback handle=%r; fn=%r',
self, handle, fn)
fn(True, None)
self._context.manager.UpdateStreamIOState(self)
@classmethod
def Accept(cls, broker, sock):
context = Context(broker)
stream = cls(context)
context.SetStream(stream)
broker.Register(context)
def GetIOState(self):
def Connect(self):
'''
Return a 3-tuple describing the instance's I/O state.
Returns:
(alive, input_fd, output_fd, has_output_buffered)
Connect to a Broker at the address given in the Context instance.
'''
return self._alive, self._rfd, self._wfd, bool(self._output_buf)
def Enqueue(self, handle, data):
Log('%r.Enqueue(%r, %r)', self, handle, data)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._fd = sock.fileno()
sock.connect(self._context.parent_addr)
self.Enqueue(0, self._context.name)
self._output_buf_lock.acquire()
try:
encoded = self.Pickle((handle, data))
self._output_buf += struct.pack('>L', len(encoded)) + encoded
finally:
self._output_buf_lock.release()
self._context.manager.UpdateStreamIOState(self)
# Misc.
@classmethod
def FromFDs(cls, context, rfd, wfd):
Log('%r.FromFDs(%r, %r, %r)', cls, context, rfd, wfd)
self = cls(context)
self._rfd, self._wfd = rfd, wfd
return self
def fileno(self):
return self._fd
def __repr__(self):
return 'econtext.%s(<context=%r>)' %\
@ -415,15 +446,11 @@ class LocalStream(Stream):
lambda self, path: setattr(self, '_python_path', path),
doc='The path to the remote Python interpreter.')
def _GetModuleSource(self, killed, name):
if not killed:
return inspect.getsource(sys.modules[name])
def __init__(self, context):
super(LocalStream, self).__init__(context)
self._permitted_modules = {}
self._unpickler.find_global = self._FindGlobal
self.AddHandleCB(self._GetModuleSource, handle=GET_MODULE)
self.AddHandleCB(SlaveModuleImporter.GetModule, handle=GET_MODULE)
def _FindGlobal(self, module_name, class_name):
'''
@ -457,8 +484,7 @@ class LocalStream(Stream):
def _FirstStage():
import os,sys,zlib
R,W=os.pipe()
pid=os.fork()
if pid:
if os.fork():
os.dup2(0,100)
os.dup2(R,0)
os.close(R)
@ -480,34 +506,21 @@ class LocalStream(Stream):
def __repr__(self):
return '%s(%s)' % (self.__class__.__name__, self._context)
# Public.
@classmethod
def Accept(cls, fd):
raise NotImplemented
def Connect(self):
Log('%r.Connect()', self)
self._child = subprocess.Popen(self.GetBootCommand(), stdin=subprocess.PIPE,
stdout=subprocess.PIPE)
self._wfd = self._child.stdin.fileno()
self._rfd = self._child.stdout.fileno()
Log('%r.Connect(): chlid process stdin=%r, stdout=%r',
self, self._wfd, self._rfd)
pid, sock = CreateChild(*self.GetBootCommand())
self._fd = sock.fileno()
Log('%r.Connect(): chlid process stdin/stdout=%r', self, self._fd)
source = inspect.getsource(sys.modules[__name__])
source += '\nExternalContextMain(%r)\n' % (self._context.name,)
source += '\nExternalContextMain(%r, %r, %r)\n' %\
(self._context.name, self._context.broker._listen_addr,
self._context.key)
compressed = zlib.compress(source)
preamble = str(len(compressed)) + '\n' + compressed
self._child.stdin.write(preamble)
self._child.stdin.flush()
assert os.read(self._rfd, 3) == 'OK\n'
def Disconnect(self):
super(LocalStream, self).Disconnect()
os.kill(self._child.pid, signal.SIGKILL)
sock.sendall(preamble)
assert os.read(self._fd, 3) == 'OK\n'
class SSHStream(LocalStream):
@ -526,16 +539,20 @@ class SSHStream(LocalStream):
class Context(object):
'''
Represents a remote context regardless of current connection method.
Represents a remote context regardless of connection method.
'''
def __init__(self, manager, name=None, hostname=None, username=None):
self.manager = manager
def __init__(self, broker, name=None, hostname=None, username=None, key=None,
parent_addr=None):
self.broker = broker
self.name = name
self.hostname = hostname
self.username = username
self.tcp_port = None
self._stream = None
self.parent_addr = parent_addr
if key:
self.key = key
else:
self.key = file('/dev/urandom', 'rb').read(16).encode('hex')
def GetStream(self):
return self._stream
@ -597,40 +614,45 @@ class Context(object):
return 'Context(%s)' % ', '.join(bits)
class ContextManager(object):
class Broker(object):
'''
Context manager: this is responsible for keeping track of contexts, any
Context broker: this is responsible for keeping track of contexts, any
stream that is associated with them, and for I/O multiplexing.
'''
def __init__(self):
self._dead = False
self._poller = select.poll()
self._poller_fd_map = {}
self._contexts_lock = threading.Lock()
self._poller_lock = threading.Lock()
self._contexts = {}
self._poller_changes_lock = threading.Lock()
self._poller_changes = {}
self._wake_rfd, self._wake_wfd = os.pipe()
self._poller.register(self._wake_rfd)
self._thread = threading.Thread(target=self.Loop, name='ContextManager')
self._listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._listen_sock.bind(('0.0.0.0', 0)) # plz 2 allocate 4 me kthx.
self._listen_sock.listen(5)
self._listen_addr = self._listen_sock.getsockname()
self._poller.register(self._listen_sock)
self._thread = threading.Thread(target=self.Loop, name='Broker')
self._thread.setDaemon(True)
self._thread.start()
self._dead = False
def Register(self, context):
'''
Put a context under control of this manager.
Put a context under control of this broker.
'''
self._contexts_lock.acquire()
Log('%r.Register(%r)', self, context)
self._poller_lock.acquire()
os.write(self._wake_wfd, ' ')
try:
self._contexts[context.name] = context
self.UpdateStreamIOState(context.GetStream())
self._poller.register(context.GetStream())
self._poller_fd_map[context.GetStream().fileno()] = context.GetStream()
finally:
self._contexts_lock.release()
self._poller_lock.release()
return context
def GetLocal(self, name):
@ -646,132 +668,83 @@ class ContextManager(object):
context.SetStream(LocalStream(context)).Connect()
return self.Register(context)
def GetRemote(self, hostname, username=None, name=None):
def GetRemote(self, hostname, username, name=None):
'''
Return the named remote context, or create it if it doesn't exist.
'''
if username is None:
username = getpass.getuser()
if name is None:
name = 'econtext[%s@%s:%d]' %\
(getpass.getuser(), os.getenv('HOSTNAME'), os.getpid())
(username, os.getenv('HOSTNAME'), os.getpid())
context = Context(self, name, hostname, username)
context.SetStream(SSHStream(context)).Connect()
return self.Register(context)
def UpdateStreamIOState(self, stream):
'''
Update the manager's internal state regarding the specified stream by
marking its FDs for polling as appropriate.
Args:
stream: econtext.Stream
'''
Log('%r.UpdateStreamIOState(%r)', self, stream)
self._poller_changes_lock.acquire()
try:
self._poller_changes[stream] = None
finally:
self._poller_changes_lock.release()
os.write(self._wake_wfd, ' ')
def _DoChangedStreams(self):
'''
Walk the list of streams indicated as having an updated I/O state by
UpdateStreamIOState. Poller registration updates must be done in serial
with calls to its poll() method.
'''
Log('%r._DoChangedStreams()', self)
self._poller_changes_lock.acquire()
try:
changes = self._poller_changes.keys()
self._poller_changes = {}
finally:
self._poller_changes_lock.release()
for stream in changes:
alive, ifd, ofd, has_output = stream.GetIOState()
if not alive:
for fd in (ifd, ofd):
del self._poller_fd_map[fd]
del self._contexts[stream._context]
return
if has_output:
self._poller.register(ofd, select.POLLOUT)
self._poller_fd_map[ofd] = stream
elif ofd in self._poller_fd_map:
self._poller.unregister(ofd)
del self._poller_fd_map[ofd]
self._poller.register(ifd, select.POLLIN)
self._poller_fd_map[ifd] = stream
def Loop(self):
'''
Handle stream events until Finalize() is called.
'''
while not self._dead:
Log('%r.Loop(): %r', self, self._poller_fd_map)
Log('%r.Loop()', self)
self._poller_lock.acquire()
self._poller_lock.release()
for fd, event in self._poller.poll():
if fd == self._wake_rfd:
Log('%r: got event on wake_rfd=%d.', self, self._wake_rfd)
os.read(self._wake_rfd, 1)
self._DoChangedStreams()
elif event & select.POLLHUP:
Log('%r: POLLHUP on %d; calling %r', self, fd,
self._poller_fd_map[fd].Disconnect)
self._poller_fd_map[fd].Disconnect()
break
elif fd == self._listen_sock.fileno():
Stream.Accept(self, self._listen_sock.accept())
continue
obj = self._poller_fd_map[fd]
if event & select.POLLHUP:
Log('%r: POLLHUP on %r', self, obj)
obj.Disconnect()
elif event & select.POLLIN:
Log('%r: POLLIN on %d; calling %r', self, fd,
self._poller_fd_map[fd].Receive)
self._poller_fd_map[fd].Receive()
Log('%r: POLLIN on %r', self, obj)
obj.Receive()
elif event & select.POLLOUT:
Log('%r: POLLOUT on %d', self, fd)
Log('%r: POLLOUT on %d; calling %r', self, fd,
self._poller_fd_map[fd].Transmit)
self._poller_fd_map[fd].Transmit()
Log('%r: POLLOUT on %r', self, obj)
if not obj.Transmit(): # If no output buffered, unset POLLOUT.
self._poller.unregister(obj)
self._poller.register(obj, select.POLLIN)
elif event & select.POLLNVAL:
Log('%r: POLLNVAL for %d, unregistering it.', self, fd)
self._poller.unregister(fd)
Log('%r: POLLNVAL for %r', self, obj)
obj.Disconnect()
self._poller.unregister(obj)
def Finalize(self):
'''
Tell all active streams to disconnect.
'''
self._dead = True
self._contexts_lock.acquire()
self._poller_lock.acquire()
try:
for name, context in self._contexts.iteritems():
stream = context.GetStream()
if stream:
stream.Disconnect()
context.GetStream().Disconnect()
finally:
self._contexts_lock.release()
self._poller_lock.release()
def __repr__(self):
return 'econtext.ContextManager(<contexts=%s>)' % (self._contexts.keys(),)
return 'econtext.Broker(<contexts=%s>)' % (self._contexts.keys(),)
def ExternalContextMain(context_name):
Log('ExternalContextMain(%r)', context_name)
assert os.wait()[1] == 0, 'first stage did not exit cleanly.'
def ExternalContextMain(context_name, parent_addr, key):
syslog.openlog('%s:%s' % (getpass.getuser(), context_name), syslog.LOG_PID)
syslog.syslog('initializing (parent=%s)' % (os.getenv('SSH_CLIENT'),))
Log('ExternalContextMain(%r, %r, %r)', context_name, parent_addr, key)
os.wait() # Reap the first stage.
os.dup2(100, 0)
os.close(100)
manager = ContextManager()
context = Context(manager, 'parent')
broker = Broker()
context = Context(broker, 'parent', parent_addr=parent_addr, key=key)
stream = context.SetStream(Stream.FromFDs(context, rfd=0, wfd=1))
manager.Register(context)
stream = context.SetStream(Stream(context))
stream.Connect()
broker.Register(context)
for call_info in Channel(stream, CALL_FUNCTION):
Log('ExternalContextMain(): CALL_FUNCTION %r', call_info)

Loading…
Cancel
Save