econtext-20070515-0811

pull/35/head
David Wilson 11 years ago
parent b35689bfe8
commit 94b94fb838

@ -12,17 +12,14 @@ import getpass
import imp
import inspect
import os
import sched
import select
import signal
import socket
import struct
import subprocess
import sys
import syslog
import textwrap
import threading
import time
import traceback
import types
import zlib
@ -32,12 +29,9 @@ import zlib
# Module-level data.
#
GET_MODULE_SOURCE = 0L
GET_MODULE = 0L
CALL_FUNCTION = 1L
_manager = None
_manager_thread = None
DEBUG = True
@ -50,6 +44,9 @@ class ContextError(Exception):
def __init__(self, fmt, *args):
  '''
  Build the exception message by applying printf-style arguments to fmt.

  Args:
    # Message, optionally containing % format specifiers.
    fmt: str
    # Values substituted into fmt.
    *args: object
  '''
  if args:
    message = fmt % args
  else:
    # With no arguments a message containing a literal '%' used to raise from
    # inside the constructor, hiding the error being reported. Keep the old
    # result where formatting works (e.g. '%%' -> '%') and otherwise fall back
    # to the raw text.
    try:
      message = fmt % ()
    except (TypeError, ValueError):
      message = fmt
  Exception.__init__(self, message)
class ChannelError(ContextError):
  '''
  Raised when a channel dies or has been closed; Channel.Receive() raises it
  when the entry it dequeues is flagged as killed or closed.
  '''
class StreamError(ContextError):
  '''
  Raised when a stream cannot be established, and for later stream failures
  such as the remote side hanging up.
  '''
@ -82,49 +79,132 @@ class PartialFunction(object):
return 'PartialFunction(%r, *%r)' % (self.fn, self.partial_args)
class FunctionProxy(object):
__slots__ = ['_context', '_per_id']
class Channel(object):
def __init__(self, stream, handle):
self._stream = stream
self._handle = handle
self._wake_event = threading.Event()
self._queue_lock = threading.Lock()
self._queue = []
self._stream.AddHandleCB(self._InternalReceive, handle)
def __init__(self, context, per_id):
self._context = context
self._per_id = per_id
def _InternalReceive(self, killed, data):
'''
Callback from the stream object; appends a tuple of
(killed-or-closed, data) to the internal queue and wakes the internal
event.
def __call__(self, *args, **kwargs):
self._context._Call(self._per_id, args, kwargs)
Args:
# Has the Stream object lost its connection?
killed: bool
# Has the remote Channel had Close() called? / the object passed to the
# remote Send().
data: (bool, object)
'''
Log('%r._InternalReceive(%r, %r)', self, killed, data)
self._queue_lock.acquire()
try:
self._queue.append((killed or data[0], data[1]))
self._wake_event.set()
finally:
self._queue_lock.release()
def Close(self):
  '''
  Indicate this channel is closed to the remote side.

  Enqueues (True, None) on this channel's handle; the leading True is the
  "closed" flag that _InternalReceive() folds into its killed-or-closed state.
  '''
  Log('%r.Close()', self)
  # The handle is stored on the instance; the bare name `handle` is undefined
  # in this scope and raised NameError on every call.
  self._stream.Enqueue(self._handle, (True, None))
def Send(self, data):
  '''
  Send the given object to the remote side.

  The object is enqueued on this channel's handle as (False, data); the
  leading False marks the message as ordinary data rather than a close.

  Args:
    data: object
  '''
  Log('%r.Send(%r)', self, data)
  # The handle is stored on the instance; the bare name `handle` is undefined
  # in this scope and raised NameError on every call.
  self._stream.Enqueue(self._handle, (False, data))
def Receive(self, timeout=None):
  '''
  Receive the next object to arrive on this channel, or return if the
  optional timeout is reached.

  Args:
    # Seconds to wait when nothing is queued; None blocks until woken.
    timeout: float

  Returns:
    # The received object, or None when the wait timed out.
    object

  Raises:
    ChannelError: the dequeued entry is flagged as killed or closed.
  '''
  Log('%r.Receive(%r)', self, timeout)
  # Only block when the queue is empty; entries left from an earlier wake-up
  # are served without waiting.
  if not self._queue:
    self._wake_event.wait(timeout)
    # Event still unset after the wait means the timeout expired.
    if not self._wake_event.isSet():
      return

  self._queue_lock.acquire()
  try:
    # Cleared under the same lock _InternalReceive() holds while appending
    # and setting the event.
    # NOTE(review): pop(0) assumes the queue is non-empty here, which looks
    # like it relies on a single consumer per channel -- confirm.
    self._wake_event.clear()
    Log('%r.Receive() queue is %r', self, self._queue)
    closed, data = self._queue.pop(0)
    Log('%r.Receive() got closed=%r, data=%r', self, closed, data)
    if closed:
      raise ChannelError('Channel is closed.')
    return data
  finally:
    self._queue_lock.release()
def __iter__(self):
  '''
  Iterate over objects as they arrive on this channel, finishing once
  Receive() reports the channel closed.
  '''
  try:
    while True:
      yield self.Receive()
  except ChannelError:
    # A closed channel ends the iteration instead of propagating.
    return
def __repr__(self):
  # Identify the channel by the stream and handle it is bound to.
  fields = (self._stream, self._handle)
  return 'econtext.Channel(%r, %r)' % fields
class SlaveModuleImporter(object):
'''
This objects implements the import hook protocol defined in
http://www.python.org/dev/peps/pep-0302/; the interpreter will ask it if it
knows how to load each module, it will in turn ask the interpreter if it
knows how to do the load, and if so, it will say it can't. This round about
crap is necessary because the module import mechanism is brutal.
This implements the import protocol described in PEP 302. It works like so:
- Python asks it if it can import a module.
- It asks Python (via imp module) if it can import the module.
- If Python says yes, it says no.
- If Python says no, it asks the parent context for the module.
- If the module isn't returned by the parent, asplode, otherwise ask Python
to load the returned module.
When the built in importer can't load a module, we try requesting it from the
parent context.
This roundabout crap is necessary because the built-in importer is tried only
after custom hooks are. A class method is provided for the parent context to
satisfy the module request; it will only return modules that have been loaded
in the parent context.
'''
def __init__(self, context):
self._context = context
def find_module(self, fullname, path=None):
if imp.find_module(fullname):
return
if not imp.find_module(fullname):
return self
def load_module(self, fullname):
kind, data = self._context.
kind, data = self._context.EnqueueAwaitReply(GET_MODULE, fullname)
def GetModule(cls, fullname):
if fullname in sys.modules:
pass
#
# Stream implementations.
#
class Stream(object):
def __init__(self, context, secure_unpickler=True):
def __init__(self, context):
self._context = context
self._sched_id = 0.0
self._alive = True
self._input_buf = self._output_buf = ''
@ -146,12 +226,24 @@ class Stream(object):
self._unpickler = cPickle.Unpickler(self._unpickler_file)
self._unpickler.persistent_load = self._LoadFunctionFromPerID
if secure_unpickler:
self._permitted_modules = {}
self._unpickler.find_global = self._FindGlobal
# Pickler/Unpickler support.
def Pickle(self, obj):
  '''
  Serialize obj with this stream's pickler and return the pickled data,
  leaving the pickler's backing file empty for the next call.
  '''
  scratch = self._pickler_file
  self._pickler.dump(obj)
  pickled = scratch.getvalue()
  # Rewind and empty the scratch file so the next dump starts at offset 0.
  scratch.seek(0)
  scratch.truncate(0)
  return pickled
def Unpickle(self, data):
  '''
  Decode one pickled message with this stream's unpickler and return the
  resulting object, emptying the unpickler's backing file afterwards.
  '''
  Log('%r.Unpickle(%r)', self, data)
  scratch = self._unpickler_file
  # Stage the message at the start of the scratch file for the unpickler.
  scratch.write(data)
  scratch.seek(0)
  decoded = self._unpickler.load()
  # Rewind and empty the scratch file ready for the next message.
  scratch.seek(0)
  scratch.truncate(0)
  return decoded
def _CheckFunctionPerID(self, obj):
'''
Please see the cPickle documentation. Given an object, return None
@ -163,7 +255,6 @@ class Stream(object):
Returns:
str
'''
if isinstance(obj, (types.FunctionType, types.MethodType)):
pid = 'FUNC:' + repr(obj)
self._func_refs[per_id] = obj
@ -180,48 +271,19 @@ class Stream(object):
Returns:
object
'''
if not pid.startswith('FUNC:'):
raise CorruptMessageError('unrecognized persistent ID received: %r', pid)
return FunctionProxy(self, pid)
def _FindGlobal(self, module_name, class_name):
  '''
  find_global hook for the unpickler (see the cPickle documentation): resolve
  a class only if its module has been permitted on this stream.

  Args:
    module_name: str
    class_name: str

  Returns:
    classobj or type

  Raises:
    StreamError: module_name is not among the permitted modules.
  '''
  if module_name in self._permitted_modules:
    return getattr(sys.modules[module_name], class_name)

  raise StreamError('context %r attempted to unpickle %r in module %r',
                    self._context, class_name, module_name)
def AllowModule(self, module_name):
  '''
  Add the given module to the modules whose classes may be unpickled.

  Args:
    module_name: str
  '''
  # _permitted_modules is created as a dict, which has no add() method; store
  # the name as a key so the membership test in _FindGlobal() finds it.
  self._permitted_modules[module_name] = True
return PartialFunction(self._CallPersistentWhatsit, pid)
# I/O.
def AllocHandle(self):
'''
Allocate a unique communications handle for this stream.
Allocate a unique handle for this stream.
Returns:
long
'''
self._handle_lock.acquire()
try:
self._last_handle += 1L
@ -239,7 +301,6 @@ class Stream(object):
handle: long
persist: bool
'''
Log('%r.AddHandleCB(%r, %r, persist=%r)', self, fn, handle, persist)
self._handle_lock.acquire()
try:
@ -252,7 +313,7 @@ class Stream(object):
Handle the next complete message on the stream. Raise CorruptMessageError
or IOError on failure.
'''
Log('%r.Receive()', self)
chunk = os.read(self._rfd, 4096)
if not chunk:
raise StreamError('remote side hung up.')
@ -266,12 +327,8 @@ class Stream(object):
if buffer_len < msg_len-4:
return
Log('%r.Receive() -> msg_len=%d; msg=%r', self, msg_len,
self._input_buf[4:msg_len+4])
try:
# TODO: wire in the per-instance unpickler.
handle, data = cPickle.loads(self._input_buf[4:msg_len+4])
handle, data = self.Unpickle(self._input_buf[4:msg_len+4])
self._input_buf = self._input_buf[msg_len+4:]
handle = long(handle)
@ -284,13 +341,13 @@ class Stream(object):
except (TypeError, ValueError), ex:
raise CorruptMessageError('%r got invalid message: %s', self, ex)
fn(handle, False, data)
fn(False, data)
def Transmit(self):
'''
Transmit pending messages. Raises IOError on failure.
'''
Log('%r.Transmit()', self)
written = os.write(self._wfd, self._output_buf[:4096])
self._output_buf = self._output_buf[written:]
if self._context and not self._output_buf:
@ -300,24 +357,17 @@ class Stream(object):
'''
Called to handle disconnects.
'''
Log('%r.Disconnect()', self)
for fd in (self._rfd, self._wfd):
try:
os.close(fd)
Log('%r.Disconnect(): closed fd %d', self, fd)
except OSError:
pass
# Invoke each registered non-persistent handle callback to indicate the
# connection has been destroyed. This prevents pending RPCs from hanging
# infinitely.
# Invoke each registered callback to indicate the connection has been
# destroyed. This prevents pending Channels/RPCs from hanging forever.
for handle, (persist, fn) in self._handle_map.iteritems():
if not persist:
Log('%r.Disconnect(): killing stale callback handle=%r; fn=%r',
self, handle, fn)
fn(handle, True, None)
fn(True, None)
self._context.manager.UpdateStreamIOState(self)
@ -328,8 +378,6 @@ class Stream(object):
Returns:
(alive, input_fd, output_fd, has_output_buffered)
'''
# TODO: this alive flag is stupid.
return self._alive, self._rfd, self._wfd, bool(self._output_buf)
def Enqueue(self, handle, data):
@ -337,8 +385,7 @@ class Stream(object):
self._output_buf_lock.acquire()
try:
# TODO: wire in the per-instance pickler.
encoded = cPickle.dumps((handle, data))
encoded = self.Pickle((handle, data))
self._output_buf += struct.pack('>L', len(encoded)) + encoded
finally:
self._output_buf_lock.release()
@ -346,55 +393,63 @@ class Stream(object):
# Misc.
@classmethod
def FromFDs(cls, context, rfd, wfd):
Log('%r.FromFDs(%r, %r, %r)', cls, context, rfd, wfd)
self = cls(context)
self._rfd, self._wfd = rfd, wfd
return self
FromFDs = classmethod(FromFDs)
def __repr__(self):
  # Show the concrete stream class together with the context it serves.
  class_name = self.__class__.__name__
  return 'econtext.%s(<context=%r>)' % (class_name, self._context)
class SlaveStream(Stream):
def __init__(self, context, secure_unpickler=True):
super(SlaveStream, self).__init__(context, secure_unpickler)
self.AddHandleCB(self._CallFunction, handle=CALL_FUNCTION)
def _CallFunction(self, handle, killed, data):
Log('%r._CallFunction(%r, %r)', self, handle, data)
try:
reply_handle, mod_name, func_name, args, kwargs = data
try:
module = __import__(mod_name)
except ImportError:
raise # TODO: module source callback.
# (success, data)
self.Enqueue(reply_handle,
(True, getattr(module, func_name)(*args, **kwargs)))
except Exception, e:
self.Enqueue(reply_handle, (False, (e, traceback.extract_stack())))
class LocalStream(Stream):
"""
'''
Base for streams capable of starting new slaves.
"""
'''
python_path = property(
lambda self: getattr(self, '_python_path', sys.executable),
lambda self, path: setattr(self, '_python_path', path),
doc='The path to the remote Python interpreter.')
def _GetModuleSource(self, name):
def _GetModuleSource(self, killed, name):
if not killed:
return inspect.getsource(sys.modules[name])
def __init__(self, context, secure_unpickler=True):
super(LocalStream, self).__init__(context, secure_unpickler)
self.AddHandleCB(self._GetModuleSource, handle=GET_MODULE_SOURCE)
def __init__(self, context):
super(LocalStream, self).__init__(context)
self._permitted_modules = {}
self._unpickler.find_global = self._FindGlobal
self.AddHandleCB(self._GetModuleSource, handle=GET_MODULE)
def _FindGlobal(self, module_name, class_name):
  '''
  Unpickler find_global hook (see the cPickle documentation): look up the
  named class, refusing any module that has not been permitted.

  Args:
    module_name: str
    class_name: str

  Returns:
    classobj or type

  Raises:
    StreamError: module_name is not among the permitted modules.
  '''
  permitted = self._permitted_modules
  if module_name not in permitted:
    raise StreamError('context %r attempted to unpickle %r in module %r',
                      self._context, class_name, module_name)

  module = sys.modules[module_name]
  return getattr(module, class_name)
def AllowModule(self, module_name):
  '''
  Add the given module to the modules whose classes may be unpickled.

  Args:
    module_name: str
  '''
  # __init__ creates _permitted_modules as a dict, which has no add() method;
  # store the name as a key so the membership test in _FindGlobal() finds it.
  self._permitted_modules[module_name] = True
# Hexed and passed to 'python -c'. It forks, dups 0->100, creates a pipe,
# then execs a new interpreter with a custom argv. CONTEXT_NAME is replaced
@ -441,7 +496,7 @@ class LocalStream(Stream):
self, self._wfd, self._rfd)
source = inspect.getsource(sys.modules[__name__])
source += '\nExternalContextImpl.Main(%r)\n' % (self._context.name,)
source += '\nExternalContextMain(%r)\n' % (self._context.name,)
compressed = zlib.compress(source)
preamble = str(len(compressed)) + '\n' + compressed
@ -470,9 +525,9 @@ class SSHStream(LocalStream):
class Context(object):
"""
'''
Represents a remote context regardless of current connection method.
"""
'''
def __init__(self, manager, name=None, hostname=None, username=None):
self.manager = manager
@ -489,34 +544,44 @@ class Context(object):
self._stream = stream
return stream
def CallWithDeadline(self, fn, deadline, *args, **kwargs):
Log('%r.CallWithDeadline(%r, %r, *%r, **%r)', self, fn, deadline, args,
kwargs)
handle = self._stream.AllocHandle()
def EnqueueAwaitReply(self, handle, deadline, data):
'''
Send a message to the given handle and wait for a response with an optional
timeout. The message contains (reply_handle, data), where reply_handle is
the handle on which this function expects its reply.
'''
Log('%r.EnqueueAwaitReply(%r, %r, %r)', self, handle, deadline, data)
reply_handle = self._stream.AllocHandle()
reply_event = threading.Event()
container = []
def _Receive(handle, killed, data):
Log('%r._Receive(%r, %r, %r)', self, handle, killed, data)
def _Receive(killed, data):
Log('%r._Receive(%r, %r)', self, killed, data)
container.extend([killed, data])
reply_event.set()
self._stream.AddHandleCB(_Receive, handle, persist=False)
call = (handle, fn.__module__, fn.__name__, args, kwargs)
self._stream.Enqueue(CALL_FUNCTION, call)
self._stream.AddHandleCB(_Receive, reply_handle, persist=False)
self._stream.Enqueue(CALL_FUNCTION, (False, (reply_handle,) + data))
reply_event.wait(deadline)
if not reply_event.isSet():
self.Disconnect()
raise TimeoutError('deadline exceeded.')
Log('%r._Receive(): got reply, container is %r', self, container)
killed, data = container
if killed:
raise StreamError('lost connection during call.')
success, result = data
Log('%r._EnqueueAwaitReply(): got reply: %r', self, data)
return data
def CallWithDeadline(self, fn, deadline, *args, **kwargs):
Log('%r.CallWithDeadline(%r, %r, *%r, **%r)', self, fn, deadline, args,
kwargs)
call = (fn.__module__, fn.__name__, args, kwargs)
success, result = self.EnqueueAwaitReply(CALL_FUNCTION, deadline, call)
if success:
return result
else:
@ -527,10 +592,6 @@ class Context(object):
def Call(self, fn, *args, **kwargs):
  # Same as CallWithDeadline(), but with no deadline (None).
  deadline = None
  return self.CallWithDeadline(fn, deadline, *args, **kwargs)
def Kill(self, deadline=30):
  '''
  Ask the context for its process group via os.getpgrp, then have it send
  SIGTERM to that group with os.kill, allowing `deadline` seconds for the
  second call.
  '''
  pgrp = self.Call(os.getpgrp)
  # A negative pid addresses the whole process group.
  self.CallWithDeadline(os.kill, deadline, -pgrp, signal.SIGTERM)
def __repr__(self):
  # List only the identifying fields that are set (truthy).
  fields = (self.name, self.hostname, self.username)
  shown = [repr(field) for field in fields if field]
  return 'Context(%s)' % ', '.join(shown)
@ -543,11 +604,6 @@ class ContextManager(object):
'''
def __init__(self):
self._scheduler = sched.scheduler(time.time, self.OneShot)
self._idle_timeout = 0
self._dead = False
self._kill_on_empty = False
self._poller = select.poll()
self._poller_fd_map = {}
@ -560,19 +616,15 @@ class ContextManager(object):
self._wake_rfd, self._wake_wfd = os.pipe()
self._poller.register(self._wake_rfd)
def SetKillOnEmpty(self, kill_on_empty=True):
  '''
  Indicate the main loop should exit when there are no remaining sessions
  open.

  Args:
    # Stored as-is in _kill_on_empty.
    kill_on_empty: bool
  '''
  self._kill_on_empty = kill_on_empty
self._thread = threading.Thread(target=self.Loop, name='ContextManager')
self._thread.setDaemon(True)
self._thread.start()
self._dead = False
def Register(self, context):
'''
Put a context under control of this manager.
'''
self._contexts_lock.acquire()
try:
self._contexts[context.name] = context
@ -590,21 +642,19 @@ class ContextManager(object):
Returns:
econtext.Context
'''
context = Context(self, name)
context.SetStream(LocalStream(context)).Connect()
return self.Register(context)
def GetRemote(self, hostname, name=None, username=None):
def GetRemote(self, hostname, username=None, name=None):
'''
Return the named remote context, or create it if it doesn't exist.
'''
if username is None:
username = getpass.getuser()
if name is None:
name = 'econtext[%s@%s:%d]' %\
(getpass.getuser(), socket.gethostname(), os.getpid())
(getpass.getuser(), os.getenv('HOSTNAME'), os.getpid())
context = Context(self, name, hostname, username)
context.SetStream(SSHStream(context)).Connect()
@ -612,22 +662,17 @@ class ContextManager(object):
def UpdateStreamIOState(self, stream):
'''
Update the manager's internal state regarding the specified stream. This
marks its FDs for polling as appropriate, and resets its idle counter.
Update the manager's internal state regarding the specified stream by
marking its FDs for polling as appropriate.
Args:
stream: econtext.Stream
'''
Log('%r.UpdateStreamIOState(%r)', self, stream)
self._poller_changes_lock.acquire()
try:
self._poller_changes[stream] = None
if self._idle_timeout:
if stream._sched_id:
self._scheduler.cancel(stream._sched_id)
self._scheduler.enter(self._idle_timeout, 0, stream.Disconnect, ())
finally:
self._poller_changes_lock.release()
os.write(self._wake_wfd, ' ')
@ -638,7 +683,6 @@ class ContextManager(object):
UpdateStreamIOState. Poller registration updates must be done in serial
with calls to its poll() method.
'''
Log('%r._DoChangedStreams()', self)
self._poller_changes_lock.acquire()
@ -651,20 +695,11 @@ class ContextManager(object):
for stream in changes:
alive, ifd, ofd, has_output = stream.GetIOState()
if not alive: # no fd = closed stream.
Log('here2')
if not alive:
for fd in (ifd, ofd):
try:
self._poller.unregister(fd)
Log('unregistered fd=%d from poller', fd)
except KeyError:
Log('failed to unregister fd=%d from poller', fd)
try:
del self._poller_fd_map[fd]
Log('unregistered fd=%d from poller map', fd)
except KeyError:
Log('failed to unregister fd=%d from poller map', fd)
del self._contexts[stream._context]
return
if has_output:
self._poller.register(ofd, select.POLLOUT)
@ -676,26 +711,17 @@ class ContextManager(object):
self._poller.register(ifd, select.POLLIN)
self._poller_fd_map[ifd] = stream
def OneShot(self, timeout=None):
def Loop(self):
'''
Poll once for I/O and return after all processing is complete, optionally
terminating after some number of seconds.
Args:
timeout: int or float
Handle stream events until Finalize() is called.
'''
if timeout == 0: # scheduler behaviour we don't require.
return
Log('%r.OneShot(): _poller_fd_map=%r', self, self._poller_fd_map)
for fd, event in self._poller.poll(timeout):
while not self._dead:
Log('%r.Loop(): %r', self, self._poller_fd_map)
for fd, event in self._poller.poll():
if fd == self._wake_rfd:
Log('%r: got event on wake_rfd=%d.', self, self._wake_rfd)
os.read(self._wake_rfd, 1)
self._DoChangedStreams()
break
elif event & select.POLLHUP:
Log('%r: POLLHUP on %d; calling %r', self, fd,
self._poller_fd_map[fd].Disconnect)
@ -705,45 +731,18 @@ class ContextManager(object):
self._poller_fd_map[fd].Receive)
self._poller_fd_map[fd].Receive()
elif event & select.POLLOUT:
Log('%r: POLLOUT on %d', self, fd)
Log('%r: POLLOUT on %d; calling %r', self, fd,
self._poller_fd_map[fd].Transmit)
self._poller_fd_map[fd].Transmit()
elif event & select.POLLNVAL:
# GAY
Log('%r: POLLNVAL for %d, unregistering it.', self, fd)
self._poller.unregister(fd)
def Loop(self):
  '''
  Handle stream events until Finalize() is called.
  '''
  # NOTE(review): as written the loop also keeps running after _dead is set
  # whenever _kill_on_empty is true and _contexts is empty -- confirm whether
  # that clause was meant to stop the loop instead.
  while (not self._dead) or (self._kill_on_empty and not self._contexts):
    # TODO: work out why self._scheduler.empty() returns True here; the queue
    # length is checked directly instead.
    if not len(self._scheduler.queue):
      self.OneShot()
    else:
      Log('self._scheduler.empty() -> %r', self._scheduler.empty())
      Log('not not self._scheduler.queue -> %r',
          not not self._scheduler.queue)
      Log('%r._scheduler.run() -> %r', self, self._scheduler.queue)
      # NOTE(review): SystemExit is raised before the run() call below, so
      # run() never executes on this path -- looks like leftover debugging;
      # confirm.
      raise SystemExit
      self._scheduler.run()
def SetIdleTimeout(self, timeout):
  '''
  Set the number of seconds after which an idle stream connected to a remote
  context is eligible for disconnection.

  Args:
    # A false value (such as the initial 0) means UpdateStreamIOState()
    # schedules no idle disconnect.
    timeout: int or float
  '''
  self._idle_timeout = timeout
def Finalize(self):
'''
Tell all active streams to disconnect.
'''
self._dead = True
self._contexts_lock.acquire()
try:
@ -758,100 +757,28 @@ class ContextManager(object):
return 'econtext.ContextManager(<contexts=%s>)' % (self._contexts.keys(),)
class ExternalContextImpl(object):
def Main(cls, context_name):
def ExternalContextMain(context_name):
Log('ExternalContextMain(%r)', context_name)
assert os.wait()[1] == 0, 'first stage did not exit cleanly.'
syslog.openlog('%s:%s' % (getpass.getuser(), context_name), syslog.LOG_PID)
parent_host = os.getenv('SSH_CLIENT')
syslog.syslog('initializing (parent_host=%s)' % (parent_host,))
syslog.syslog('initializing (parent=%s)' % (os.getenv('SSH_CLIENT'),))
os.dup2(100, 0)
os.close(100)
manager = ContextManager()
manager.SetKillOnEmpty()
context = Context(manager, 'parent')
stream = context.SetStream(SlaveStream.FromFDs(context, rfd=0, wfd=1))
stream = context.SetStream(Stream.FromFDs(context, rfd=0, wfd=1))
manager.Register(context)
try:
manager.Loop()
except StreamError, e:
syslog.syslog('exit: ' + str(e))
os.kill(-os.getpgrp(), signal.SIGKILL)
Main = classmethod(Main)
def __repr__(self):
return 'ExternalContextImpl(%r)' % (self.name,)
#
# Simple interface.
#
def Init(idle_secs=60*60):
  '''
  Set up the simple interface on first use and return the shared manager;
  later calls return the manager created by the first.

  Args:
    # Seconds to keep an unused context alive or None for infinite.
    idle_secs: 3600 or None

  Returns:
    ContextManager
  '''
  global _manager
  global _manager_thread

  if not _manager:
    _manager = ContextManager()
    _manager.SetIdleTimeout(idle_secs)

    # Run the manager's loop on a daemon thread so it never blocks exit.
    loop_thread = threading.Thread(target=_manager.Loop)
    _manager_thread = loop_thread
    loop_thread.setDaemon(True)
    loop_thread.start()
    atexit.register(Finalize)

  return _manager
for call_info in Channel(stream, CALL_FUNCTION):
Log('ExternalContextMain(): CALL_FUNCTION %r', call_info)
reply_handle, mod_name, func_name, args, kwargs = call_info
fn = getattr(__import__(mod_name), func_name)
def Finalize():
  '''
  Finalize the shared manager, if one exists, and forget it.
  '''
  global _manager
  global _manager_thread

  manager = _manager
  if manager is None:
    return
  manager.Finalize()
  _manager = None
def CallWithDeadline(hostname, username, fn, deadline, *args, **kwargs):
  '''
  Call a function in the context of a remote host, treating the call as
  failed once the deadline (in seconds) has passed.

  Args:
    # Hostname or address of remote host.
    hostname: str

    # Username to connect as, or None for current user.
    username: str or None

    # The function to execute in the remote context.
    fn: staticmethod or classmethod or types.FunctionType

    # Seconds until we assume the call has failed.
    deadline: float or None

  Returns:
    # Function's return value.
    object
  '''
  manager = Init()
  remote = manager.GetRemote(hostname, username=username)
  return remote.CallWithDeadline(fn, deadline, *args, **kwargs)
def Call(hostname, username, fn, *args, **kwargs):
  '''
  Make a remote function call exactly as CallWithDeadline does, passing None
  as the deadline.
  '''
  no_deadline = None
  return CallWithDeadline(hostname, username, fn, no_deadline, *args, **kwargs)
try:
stream.Enqueue(reply_handle, (True, fn(*args, **kwargs)))
except Exception, e:
stram.Enqueue(reply_handle, (False, (e, traceback.extract_stack())))

Loading…
Cancel
Save