econtext-20070514-1643
commit
b35689bfe8
@ -0,0 +1,857 @@
|
|||||||
|
#!/usr/bin/env python2.5
|
||||||
|
|
||||||
|
'''
|
||||||
|
Python External Execution Contexts.
|
||||||
|
'''
|
||||||
|
|
||||||
|
import atexit
|
||||||
|
import cPickle
|
||||||
|
import cStringIO
|
||||||
|
import commands
|
||||||
|
import getpass
|
||||||
|
import imp
|
||||||
|
import inspect
|
||||||
|
import os
|
||||||
|
import sched
|
||||||
|
import select
|
||||||
|
import signal
|
||||||
|
import socket
|
||||||
|
import struct
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import syslog
|
||||||
|
import textwrap
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
import types
|
||||||
|
import zlib
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# Module-level data.
|
||||||
|
#
|
||||||
|
|
||||||
|
GET_MODULE_SOURCE = 0L
|
||||||
|
CALL_FUNCTION = 1L
|
||||||
|
|
||||||
|
_manager = None
|
||||||
|
_manager_thread = None
|
||||||
|
|
||||||
|
DEBUG = True
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# Exceptions.
|
||||||
|
#
|
||||||
|
|
||||||
|
class ContextError(Exception):
|
||||||
|
'Raised when a problem occurs with a context.'
|
||||||
|
def __init__(self, fmt, *args):
|
||||||
|
Exception.__init__(self, fmt % args)
|
||||||
|
|
||||||
|
class StreamError(ContextError):
|
||||||
|
'Raised when a stream cannot be established.'
|
||||||
|
|
||||||
|
class CorruptMessageError(StreamError):
|
||||||
|
'Raised when a corrupt message is received on a stream.'
|
||||||
|
|
||||||
|
class TimeoutError(StreamError):
|
||||||
|
'Raised when a timeout occurs on a stream.'
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# Helpers.
|
||||||
|
#
|
||||||
|
|
||||||
|
def Log(fmt, *args):
|
||||||
|
if DEBUG:
|
||||||
|
sys.stderr.write('%d (%d): %s\n' % (os.getpid(), os.getppid(),
|
||||||
|
(fmt%args).replace('econtext.', '')))
|
||||||
|
|
||||||
|
|
||||||
|
class PartialFunction(object):
|
||||||
|
def __init__(self, fn, *partial_args):
|
||||||
|
self.fn = fn
|
||||||
|
self.partial_args = partial_args
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return self.fn(*(self.partial_args+args), **kwargs)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return 'PartialFunction(%r, *%r)' % (self.fn, self.partial_args)
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionProxy(object):
|
||||||
|
__slots__ = ['_context', '_per_id']
|
||||||
|
|
||||||
|
def __init__(self, context, per_id):
|
||||||
|
self._context = context
|
||||||
|
self._per_id = per_id
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
self._context._Call(self._per_id, args, kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class SlaveModuleImporter(object):
|
||||||
|
'''
|
||||||
|
This objects implements the import hook protocol defined in
|
||||||
|
http://www.python.org/dev/peps/pep-0302/; the interpreter will ask it if it
|
||||||
|
knows how to load each module, it will in turn ask the interpreter if it
|
||||||
|
knows how to do the load, and if so, it will say it can't. This round about
|
||||||
|
crap is necessary because the module import mechanism is brutal.
|
||||||
|
|
||||||
|
When the built in importer can't load a module, we try requesting it from the
|
||||||
|
parent context.
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, context):
|
||||||
|
self._context = context
|
||||||
|
|
||||||
|
def find_module(self, fullname, path=None):
|
||||||
|
if imp.find_module(fullname):
|
||||||
|
return
|
||||||
|
return self
|
||||||
|
|
||||||
|
def load_module(self, fullname):
|
||||||
|
kind, data = self._context.
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# Stream implementations.
|
||||||
|
#
|
||||||
|
|
||||||
|
class Stream(object):
|
||||||
|
def __init__(self, context, secure_unpickler=True):
|
||||||
|
self._context = context
|
||||||
|
self._sched_id = 0.0
|
||||||
|
self._alive = True
|
||||||
|
|
||||||
|
self._input_buf = self._output_buf = ''
|
||||||
|
self._input_buf_lock = threading.Lock()
|
||||||
|
self._output_buf_lock = threading.Lock()
|
||||||
|
|
||||||
|
self._last_handle = 0
|
||||||
|
self._handle_map = {}
|
||||||
|
self._handle_lock = threading.Lock()
|
||||||
|
|
||||||
|
self._func_refs = {}
|
||||||
|
self._func_ref_lock = threading.Lock()
|
||||||
|
|
||||||
|
self._pickler_file = cStringIO.StringIO()
|
||||||
|
self._pickler = cPickle.Pickler(self._pickler_file)
|
||||||
|
self._pickler.persistent_id = self._CheckFunctionPerID
|
||||||
|
|
||||||
|
self._unpickler_file = cStringIO.StringIO()
|
||||||
|
self._unpickler = cPickle.Unpickler(self._unpickler_file)
|
||||||
|
self._unpickler.persistent_load = self._LoadFunctionFromPerID
|
||||||
|
|
||||||
|
if secure_unpickler:
|
||||||
|
self._permitted_modules = {}
|
||||||
|
self._unpickler.find_global = self._FindGlobal
|
||||||
|
|
||||||
|
# Pickler/Unpickler support.
|
||||||
|
|
||||||
|
def _CheckFunctionPerID(self, obj):
|
||||||
|
'''
|
||||||
|
Please see the cPickle documentation. Given an object, return None
|
||||||
|
indicating normal pickle processing or a string 'persistent ID'.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj: object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str
|
||||||
|
'''
|
||||||
|
|
||||||
|
if isinstance(obj, (types.FunctionType, types.MethodType)):
|
||||||
|
pid = 'FUNC:' + repr(obj)
|
||||||
|
self._func_refs[per_id] = obj
|
||||||
|
return pid
|
||||||
|
|
||||||
|
def _LoadFunctionFromPerID(self, pid):
|
||||||
|
'''
|
||||||
|
Please see the cPickle documentation. Given a string created by
|
||||||
|
_CheckFunctionPerID, turn it into an object again.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pid: str
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
object
|
||||||
|
'''
|
||||||
|
|
||||||
|
if not pid.startswith('FUNC:'):
|
||||||
|
raise CorruptMessageError('unrecognized persistent ID received: %r', pid)
|
||||||
|
return FunctionProxy(self, pid)
|
||||||
|
|
||||||
|
def _FindGlobal(self, module_name, class_name):
|
||||||
|
'''
|
||||||
|
Please see the cPickle documentation. Given a module and class name,
|
||||||
|
determine whether class referred to is safe for unpickling.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module_name: str
|
||||||
|
class_name: str
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
classobj or type
|
||||||
|
'''
|
||||||
|
|
||||||
|
if module_name not in self._permitted_modules:
|
||||||
|
raise StreamError('context %r attempted to unpickle %r in module %r',
|
||||||
|
self._context, class_name, module_name)
|
||||||
|
return getattr(sys.modules[module_name], class_name)
|
||||||
|
|
||||||
|
def AllowModule(self, module_name):
|
||||||
|
'''
|
||||||
|
Add the given module to the list of permitted modules.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module_name: str
|
||||||
|
'''
|
||||||
|
self._permitted_modules.add(module_name)
|
||||||
|
|
||||||
|
# I/O.
|
||||||
|
|
||||||
|
def AllocHandle(self):
|
||||||
|
'''
|
||||||
|
Allocate a unique communications handle for this stream.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
long
|
||||||
|
'''
|
||||||
|
|
||||||
|
self._handle_lock.acquire()
|
||||||
|
try:
|
||||||
|
self._last_handle += 1L
|
||||||
|
finally:
|
||||||
|
self._handle_lock.release()
|
||||||
|
return self._last_handle
|
||||||
|
|
||||||
|
def AddHandleCB(self, fn, handle, persist=True):
|
||||||
|
'''
|
||||||
|
Arrange to invoke the given function for all messages tagged with the given
|
||||||
|
handle. By default, process one message and discard this arrangement.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn: callable
|
||||||
|
handle: long
|
||||||
|
persist: bool
|
||||||
|
'''
|
||||||
|
|
||||||
|
Log('%r.AddHandleCB(%r, %r, persist=%r)', self, fn, handle, persist)
|
||||||
|
self._handle_lock.acquire()
|
||||||
|
try:
|
||||||
|
self._handle_map[handle] = persist, fn
|
||||||
|
finally:
|
||||||
|
self._handle_lock.release()
|
||||||
|
|
||||||
|
def Receive(self):
|
||||||
|
'''
|
||||||
|
Handle the next complete message on the stream. Raise CorruptMessageError
|
||||||
|
or IOError on failure.
|
||||||
|
'''
|
||||||
|
|
||||||
|
chunk = os.read(self._rfd, 4096)
|
||||||
|
if not chunk:
|
||||||
|
raise StreamError('remote side hung up.')
|
||||||
|
|
||||||
|
self._input_buf += chunk
|
||||||
|
buffer_len = len(self._input_buf)
|
||||||
|
if buffer_len < 4:
|
||||||
|
return
|
||||||
|
|
||||||
|
msg_len = struct.unpack('>L', self._input_buf[:4])[0]
|
||||||
|
if buffer_len < msg_len-4:
|
||||||
|
return
|
||||||
|
|
||||||
|
Log('%r.Receive() -> msg_len=%d; msg=%r', self, msg_len,
|
||||||
|
self._input_buf[4:msg_len+4])
|
||||||
|
|
||||||
|
try:
|
||||||
|
# TODO: wire in the per-instance unpickler.
|
||||||
|
handle, data = cPickle.loads(self._input_buf[4:msg_len+4])
|
||||||
|
self._input_buf = self._input_buf[msg_len+4:]
|
||||||
|
handle = long(handle)
|
||||||
|
|
||||||
|
Log('%r.Receive(): decoded handle=%r; data=%r', self, handle, data)
|
||||||
|
persist, fn = self._handle_map[handle]
|
||||||
|
if not persist:
|
||||||
|
del self._handle_map[handle]
|
||||||
|
except KeyError, ex:
|
||||||
|
raise CorruptMessageError('%r got invalid handle: %r', self, handle)
|
||||||
|
except (TypeError, ValueError), ex:
|
||||||
|
raise CorruptMessageError('%r got invalid message: %s', self, ex)
|
||||||
|
|
||||||
|
fn(handle, False, data)
|
||||||
|
|
||||||
|
def Transmit(self):
|
||||||
|
'''
|
||||||
|
Transmit pending messages. Raises IOError on failure.
|
||||||
|
'''
|
||||||
|
|
||||||
|
written = os.write(self._wfd, self._output_buf[:4096])
|
||||||
|
self._output_buf = self._output_buf[written:]
|
||||||
|
if self._context and not self._output_buf:
|
||||||
|
self._context.manager.UpdateStreamIOState(self)
|
||||||
|
|
||||||
|
def Disconnect(self):
|
||||||
|
'''
|
||||||
|
Called to handle disconnects.
|
||||||
|
'''
|
||||||
|
|
||||||
|
Log('%r.Disconnect()', self)
|
||||||
|
|
||||||
|
for fd in (self._rfd, self._wfd):
|
||||||
|
try:
|
||||||
|
os.close(fd)
|
||||||
|
Log('%r.Disconnect(): closed fd %d', self, fd)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Invoke each registered non-persistent handle callback to indicate the
|
||||||
|
# connection has been destroyed. This prevents pending RPCs from hanging
|
||||||
|
# infinitely.
|
||||||
|
for handle, (persist, fn) in self._handle_map.iteritems():
|
||||||
|
if not persist:
|
||||||
|
Log('%r.Disconnect(): killing stale callback handle=%r; fn=%r',
|
||||||
|
self, handle, fn)
|
||||||
|
fn(handle, True, None)
|
||||||
|
|
||||||
|
self._context.manager.UpdateStreamIOState(self)
|
||||||
|
|
||||||
|
def GetIOState(self):
|
||||||
|
'''
|
||||||
|
Return a 3-tuple describing the instance's I/O state.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(alive, input_fd, output_fd, has_output_buffered)
|
||||||
|
'''
|
||||||
|
|
||||||
|
# TODO: this alive flag is stupid.
|
||||||
|
return self._alive, self._rfd, self._wfd, bool(self._output_buf)
|
||||||
|
|
||||||
|
def Enqueue(self, handle, data):
|
||||||
|
Log('%r.Enqueue(%r, %r)', self, handle, data)
|
||||||
|
|
||||||
|
self._output_buf_lock.acquire()
|
||||||
|
try:
|
||||||
|
# TODO: wire in the per-instance pickler.
|
||||||
|
encoded = cPickle.dumps((handle, data))
|
||||||
|
self._output_buf += struct.pack('>L', len(encoded)) + encoded
|
||||||
|
finally:
|
||||||
|
self._output_buf_lock.release()
|
||||||
|
self._context.manager.UpdateStreamIOState(self)
|
||||||
|
|
||||||
|
# Misc.
|
||||||
|
|
||||||
|
def FromFDs(cls, context, rfd, wfd):
|
||||||
|
Log('%r.FromFDs(%r, %r, %r)', cls, context, rfd, wfd)
|
||||||
|
self = cls(context)
|
||||||
|
self._rfd, self._wfd = rfd, wfd
|
||||||
|
return self
|
||||||
|
FromFDs = classmethod(FromFDs)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return 'econtext.%s(<context=%r>)' %\
|
||||||
|
(self.__class__.__name__, self._context)
|
||||||
|
|
||||||
|
|
||||||
|
class SlaveStream(Stream):
|
||||||
|
def __init__(self, context, secure_unpickler=True):
|
||||||
|
super(SlaveStream, self).__init__(context, secure_unpickler)
|
||||||
|
self.AddHandleCB(self._CallFunction, handle=CALL_FUNCTION)
|
||||||
|
|
||||||
|
def _CallFunction(self, handle, killed, data):
|
||||||
|
Log('%r._CallFunction(%r, %r)', self, handle, data)
|
||||||
|
|
||||||
|
try:
|
||||||
|
reply_handle, mod_name, func_name, args, kwargs = data
|
||||||
|
try:
|
||||||
|
module = __import__(mod_name)
|
||||||
|
except ImportError:
|
||||||
|
raise # TODO: module source callback.
|
||||||
|
# (success, data)
|
||||||
|
self.Enqueue(reply_handle,
|
||||||
|
(True, getattr(module, func_name)(*args, **kwargs)))
|
||||||
|
except Exception, e:
|
||||||
|
self.Enqueue(reply_handle, (False, (e, traceback.extract_stack())))
|
||||||
|
|
||||||
|
|
||||||
|
class LocalStream(Stream):
|
||||||
|
"""
|
||||||
|
Base for streams capable of starting new slaves.
|
||||||
|
"""
|
||||||
|
|
||||||
|
python_path = property(
|
||||||
|
lambda self: getattr(self, '_python_path', sys.executable),
|
||||||
|
lambda self, path: setattr(self, '_python_path', path),
|
||||||
|
doc='The path to the remote Python interpreter.')
|
||||||
|
|
||||||
|
def _GetModuleSource(self, name):
|
||||||
|
return inspect.getsource(sys.modules[name])
|
||||||
|
|
||||||
|
def __init__(self, context, secure_unpickler=True):
|
||||||
|
super(LocalStream, self).__init__(context, secure_unpickler)
|
||||||
|
self.AddHandleCB(self._GetModuleSource, handle=GET_MODULE_SOURCE)
|
||||||
|
|
||||||
|
# Hexed and passed to 'python -c'. It forks, dups 0->100, creates a pipe,
|
||||||
|
# then execs a new interpreter with a custom argv. CONTEXT_NAME is replaced
|
||||||
|
# with the context name. Optimized for source size.
|
||||||
|
def _FirstStage():
|
||||||
|
import os,sys,zlib
|
||||||
|
R,W=os.pipe()
|
||||||
|
pid=os.fork()
|
||||||
|
if pid:
|
||||||
|
os.dup2(0,100)
|
||||||
|
os.dup2(R,0)
|
||||||
|
os.close(R)
|
||||||
|
os.close(W)
|
||||||
|
os.execv(sys.executable,(CONTEXT_NAME,))
|
||||||
|
else:
|
||||||
|
os.fdopen(W,'wb',0).write(zlib.decompress(sys.stdin.read(input())))
|
||||||
|
print 'OK'
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
def GetBootCommand(self):
|
||||||
|
source = inspect.getsource(self._FirstStage)
|
||||||
|
source = textwrap.dedent('\n'.join(source.strip().split('\n')[1:]))
|
||||||
|
source = source.replace(' ', '\t')
|
||||||
|
source = source.replace('CONTEXT_NAME', repr(self._context.name))
|
||||||
|
return [ self.python_path, '-c',
|
||||||
|
'exec "%s".decode("hex")' % (source.encode('hex'),) ]
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return '%s(%s)' % (self.__class__.__name__, self._context)
|
||||||
|
|
||||||
|
# Public.
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def Accept(cls, fd):
|
||||||
|
raise NotImplemented
|
||||||
|
|
||||||
|
def Connect(self):
|
||||||
|
Log('%r.Connect()', self)
|
||||||
|
self._child = subprocess.Popen(self.GetBootCommand(), stdin=subprocess.PIPE,
|
||||||
|
stdout=subprocess.PIPE)
|
||||||
|
self._wfd = self._child.stdin.fileno()
|
||||||
|
self._rfd = self._child.stdout.fileno()
|
||||||
|
Log('%r.Connect(): chlid process stdin=%r, stdout=%r',
|
||||||
|
self, self._wfd, self._rfd)
|
||||||
|
|
||||||
|
source = inspect.getsource(sys.modules[__name__])
|
||||||
|
source += '\nExternalContextImpl.Main(%r)\n' % (self._context.name,)
|
||||||
|
compressed = zlib.compress(source)
|
||||||
|
|
||||||
|
preamble = str(len(compressed)) + '\n' + compressed
|
||||||
|
self._child.stdin.write(preamble)
|
||||||
|
self._child.stdin.flush()
|
||||||
|
|
||||||
|
assert os.read(self._rfd, 3) == 'OK\n'
|
||||||
|
|
||||||
|
def Disconnect(self):
|
||||||
|
super(LocalStream, self).Disconnect()
|
||||||
|
os.kill(self._child.pid, signal.SIGKILL)
|
||||||
|
|
||||||
|
|
||||||
|
class SSHStream(LocalStream):
|
||||||
|
ssh_path = property(
|
||||||
|
lambda self: getattr(self, '_ssh_path', 'ssh'),
|
||||||
|
lambda self, path: setattr(self, '_ssh_path', path),
|
||||||
|
doc='The path to the SSH binary.')
|
||||||
|
|
||||||
|
def GetBootCommand(self):
|
||||||
|
bits = [self.ssh_path]
|
||||||
|
if self._context.username:
|
||||||
|
bits += ['-l', self._context.username]
|
||||||
|
bits.append(self._context.hostname)
|
||||||
|
return bits + map(commands.mkarg, super(SSHStream, self).GetBootCommand())
|
||||||
|
|
||||||
|
|
||||||
|
class Context(object):
|
||||||
|
"""
|
||||||
|
Represents a remote context regardless of current connection method.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, manager, name=None, hostname=None, username=None):
|
||||||
|
self.manager = manager
|
||||||
|
self.name = name
|
||||||
|
self.hostname = hostname
|
||||||
|
self.username = username
|
||||||
|
self.tcp_port = None
|
||||||
|
self._stream = None
|
||||||
|
|
||||||
|
def GetStream(self):
|
||||||
|
return self._stream
|
||||||
|
|
||||||
|
def SetStream(self, stream):
|
||||||
|
self._stream = stream
|
||||||
|
return stream
|
||||||
|
|
||||||
|
def CallWithDeadline(self, fn, deadline, *args, **kwargs):
|
||||||
|
Log('%r.CallWithDeadline(%r, %r, *%r, **%r)', self, fn, deadline, args,
|
||||||
|
kwargs)
|
||||||
|
handle = self._stream.AllocHandle()
|
||||||
|
reply_event = threading.Event()
|
||||||
|
container = []
|
||||||
|
|
||||||
|
def _Receive(handle, killed, data):
|
||||||
|
Log('%r._Receive(%r, %r, %r)', self, handle, killed, data)
|
||||||
|
container.extend([killed, data])
|
||||||
|
reply_event.set()
|
||||||
|
|
||||||
|
self._stream.AddHandleCB(_Receive, handle, persist=False)
|
||||||
|
call = (handle, fn.__module__, fn.__name__, args, kwargs)
|
||||||
|
self._stream.Enqueue(CALL_FUNCTION, call)
|
||||||
|
|
||||||
|
reply_event.wait(deadline)
|
||||||
|
if not reply_event.isSet():
|
||||||
|
self.Disconnect()
|
||||||
|
raise TimeoutError('deadline exceeded.')
|
||||||
|
|
||||||
|
Log('%r._Receive(): got reply, container is %r', self, container)
|
||||||
|
killed, data = container
|
||||||
|
|
||||||
|
if killed:
|
||||||
|
raise StreamError('lost connection during call.')
|
||||||
|
|
||||||
|
success, result = data
|
||||||
|
if success:
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
exc_obj, traceback = result
|
||||||
|
exc_obj.real_traceback = traceback
|
||||||
|
raise exc_obj
|
||||||
|
|
||||||
|
def Call(self, fn, *args, **kwargs):
|
||||||
|
return self.CallWithDeadline(fn, None, *args, **kwargs)
|
||||||
|
|
||||||
|
def Kill(self, deadline=30):
|
||||||
|
self.CallWithDeadline(os.kill, deadline,
|
||||||
|
-self.Call(os.getpgrp), signal.SIGTERM)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
bits = map(repr, filter(None, [self.name, self.hostname, self.username]))
|
||||||
|
return 'Context(%s)' % ', '.join(bits)
|
||||||
|
|
||||||
|
|
||||||
|
class ContextManager(object):
|
||||||
|
'''
|
||||||
|
Context manager: this is responsible for keeping track of contexts, any
|
||||||
|
stream that is associated with them, and for I/O multiplexing.
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._scheduler = sched.scheduler(time.time, self.OneShot)
|
||||||
|
self._idle_timeout = 0
|
||||||
|
self._dead = False
|
||||||
|
self._kill_on_empty = False
|
||||||
|
|
||||||
|
self._poller = select.poll()
|
||||||
|
self._poller_fd_map = {}
|
||||||
|
|
||||||
|
self._contexts_lock = threading.Lock()
|
||||||
|
self._contexts = {}
|
||||||
|
|
||||||
|
self._poller_changes_lock = threading.Lock()
|
||||||
|
self._poller_changes = {}
|
||||||
|
|
||||||
|
self._wake_rfd, self._wake_wfd = os.pipe()
|
||||||
|
self._poller.register(self._wake_rfd)
|
||||||
|
|
||||||
|
def SetKillOnEmpty(self, kill_on_empty=True):
|
||||||
|
'''
|
||||||
|
Indicate the main loop should exit when there are no remaining sessions
|
||||||
|
open.
|
||||||
|
'''
|
||||||
|
|
||||||
|
self._kill_on_empty = kill_on_empty
|
||||||
|
|
||||||
|
def Register(self, context):
|
||||||
|
'''
|
||||||
|
Put a context under control of this manager.
|
||||||
|
'''
|
||||||
|
|
||||||
|
self._contexts_lock.acquire()
|
||||||
|
try:
|
||||||
|
self._contexts[context.name] = context
|
||||||
|
self.UpdateStreamIOState(context.GetStream())
|
||||||
|
finally:
|
||||||
|
self._contexts_lock.release()
|
||||||
|
return context
|
||||||
|
|
||||||
|
def GetLocal(self, name):
|
||||||
|
'''
|
||||||
|
Return the named local context, or create it if it doesn't exist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 'my-local-context'
|
||||||
|
Returns:
|
||||||
|
econtext.Context
|
||||||
|
'''
|
||||||
|
|
||||||
|
context = Context(self, name)
|
||||||
|
context.SetStream(LocalStream(context)).Connect()
|
||||||
|
return self.Register(context)
|
||||||
|
|
||||||
|
def GetRemote(self, hostname, name=None, username=None):
|
||||||
|
'''
|
||||||
|
Return the named remote context, or create it if it doesn't exist.
|
||||||
|
'''
|
||||||
|
|
||||||
|
if username is None:
|
||||||
|
username = getpass.getuser()
|
||||||
|
if name is None:
|
||||||
|
name = 'econtext[%s@%s:%d]' %\
|
||||||
|
(getpass.getuser(), socket.gethostname(), os.getpid())
|
||||||
|
|
||||||
|
context = Context(self, name, hostname, username)
|
||||||
|
context.SetStream(SSHStream(context)).Connect()
|
||||||
|
return self.Register(context)
|
||||||
|
|
||||||
|
def UpdateStreamIOState(self, stream):
|
||||||
|
'''
|
||||||
|
Update the manager's internal state regarding the specified stream. This
|
||||||
|
marks its FDs for polling as appropriate, and resets its idle counter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream: econtext.Stream
|
||||||
|
'''
|
||||||
|
|
||||||
|
Log('%r.UpdateStreamIOState(%r)', self, stream)
|
||||||
|
|
||||||
|
self._poller_changes_lock.acquire()
|
||||||
|
try:
|
||||||
|
self._poller_changes[stream] = None
|
||||||
|
if self._idle_timeout:
|
||||||
|
if stream._sched_id:
|
||||||
|
self._scheduler.cancel(stream._sched_id)
|
||||||
|
self._scheduler.enter(self._idle_timeout, 0, stream.Disconnect, ())
|
||||||
|
finally:
|
||||||
|
self._poller_changes_lock.release()
|
||||||
|
os.write(self._wake_wfd, ' ')
|
||||||
|
|
||||||
|
def _DoChangedStreams(self):
|
||||||
|
'''
|
||||||
|
Walk the list of streams indicated as having an updated I/O state by
|
||||||
|
UpdateStreamIOState. Poller registration updates must be done in serial
|
||||||
|
with calls to its poll() method.
|
||||||
|
'''
|
||||||
|
|
||||||
|
Log('%r._DoChangedStreams()', self)
|
||||||
|
|
||||||
|
self._poller_changes_lock.acquire()
|
||||||
|
try:
|
||||||
|
changes = self._poller_changes.keys()
|
||||||
|
self._poller_changes = {}
|
||||||
|
finally:
|
||||||
|
self._poller_changes_lock.release()
|
||||||
|
|
||||||
|
for stream in changes:
|
||||||
|
alive, ifd, ofd, has_output = stream.GetIOState()
|
||||||
|
|
||||||
|
if not alive: # no fd = closed stream.
|
||||||
|
Log('here2')
|
||||||
|
for fd in (ifd, ofd):
|
||||||
|
try:
|
||||||
|
self._poller.unregister(fd)
|
||||||
|
Log('unregistered fd=%d from poller', fd)
|
||||||
|
except KeyError:
|
||||||
|
Log('failed to unregister fd=%d from poller', fd)
|
||||||
|
try:
|
||||||
|
del self._poller_fd_map[fd]
|
||||||
|
Log('unregistered fd=%d from poller map', fd)
|
||||||
|
except KeyError:
|
||||||
|
Log('failed to unregister fd=%d from poller map', fd)
|
||||||
|
del self._contexts[stream._context]
|
||||||
|
|
||||||
|
if has_output:
|
||||||
|
self._poller.register(ofd, select.POLLOUT)
|
||||||
|
self._poller_fd_map[ofd] = stream
|
||||||
|
elif ofd in self._poller_fd_map:
|
||||||
|
self._poller.unregister(ofd)
|
||||||
|
del self._poller_fd_map[ofd]
|
||||||
|
|
||||||
|
self._poller.register(ifd, select.POLLIN)
|
||||||
|
self._poller_fd_map[ifd] = stream
|
||||||
|
|
||||||
|
def OneShot(self, timeout=None):
|
||||||
|
'''
|
||||||
|
Poll once for I/O and return after all processing is complete, optionally
|
||||||
|
terminating after some number of seconds.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout: int or float
|
||||||
|
'''
|
||||||
|
|
||||||
|
if timeout == 0: # scheduler behaviour we don't require.
|
||||||
|
return
|
||||||
|
|
||||||
|
Log('%r.OneShot(): _poller_fd_map=%r', self, self._poller_fd_map)
|
||||||
|
|
||||||
|
for fd, event in self._poller.poll(timeout):
|
||||||
|
if fd == self._wake_rfd:
|
||||||
|
Log('%r: got event on wake_rfd=%d.', self, self._wake_rfd)
|
||||||
|
os.read(self._wake_rfd, 1)
|
||||||
|
self._DoChangedStreams()
|
||||||
|
break
|
||||||
|
elif event & select.POLLHUP:
|
||||||
|
Log('%r: POLLHUP on %d; calling %r', self, fd,
|
||||||
|
self._poller_fd_map[fd].Disconnect)
|
||||||
|
self._poller_fd_map[fd].Disconnect()
|
||||||
|
elif event & select.POLLIN:
|
||||||
|
Log('%r: POLLIN on %d; calling %r', self, fd,
|
||||||
|
self._poller_fd_map[fd].Receive)
|
||||||
|
self._poller_fd_map[fd].Receive()
|
||||||
|
elif event & select.POLLOUT:
|
||||||
|
Log('%r: POLLOUT on %d; calling %r', self, fd,
|
||||||
|
self._poller_fd_map[fd].Transmit)
|
||||||
|
self._poller_fd_map[fd].Transmit()
|
||||||
|
elif event & select.POLLNVAL:
|
||||||
|
# GAY
|
||||||
|
self._poller.unregister(fd)
|
||||||
|
|
||||||
|
def Loop(self):
|
||||||
|
'''
|
||||||
|
Handle stream events until Finalize() is called.
|
||||||
|
'''
|
||||||
|
|
||||||
|
while (not self._dead) or (self._kill_on_empty and not self._contexts):
|
||||||
|
# TODO: why the fuck is self._scheduler.empty() returning True?!
|
||||||
|
if not len(self._scheduler.queue):
|
||||||
|
self.OneShot()
|
||||||
|
else:
|
||||||
|
Log('self._scheduler.empty() -> %r', self._scheduler.empty())
|
||||||
|
Log('not not self._scheduler.queue -> %r',
|
||||||
|
not not self._scheduler.queue)
|
||||||
|
Log('%r._scheduler.run() -> %r', self, self._scheduler.queue)
|
||||||
|
raise SystemExit
|
||||||
|
self._scheduler.run()
|
||||||
|
|
||||||
|
def SetIdleTimeout(self, timeout):
|
||||||
|
'''
|
||||||
|
Set the number of seconds after which an idle stream connected to a remote
|
||||||
|
context is eligible for disconnection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout: int or float
|
||||||
|
'''
|
||||||
|
self._idle_timeout = timeout
|
||||||
|
|
||||||
|
def Finalize(self):
|
||||||
|
'''
|
||||||
|
Tell all active streams to disconnect.
|
||||||
|
'''
|
||||||
|
|
||||||
|
self._dead = True
|
||||||
|
self._contexts_lock.acquire()
|
||||||
|
try:
|
||||||
|
for name, context in self._contexts.iteritems():
|
||||||
|
stream = context.GetStream()
|
||||||
|
if stream:
|
||||||
|
stream.Disconnect()
|
||||||
|
finally:
|
||||||
|
self._contexts_lock.release()
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return 'econtext.ContextManager(<contexts=%s>)' % (self._contexts.keys(),)
|
||||||
|
|
||||||
|
|
||||||
|
class ExternalContextImpl(object):
|
||||||
|
def Main(cls, context_name):
|
||||||
|
assert os.wait()[1] == 0, 'first stage did not exit cleanly.'
|
||||||
|
|
||||||
|
syslog.openlog('%s:%s' % (getpass.getuser(), context_name), syslog.LOG_PID)
|
||||||
|
|
||||||
|
parent_host = os.getenv('SSH_CLIENT')
|
||||||
|
syslog.syslog('initializing (parent_host=%s)' % (parent_host,))
|
||||||
|
|
||||||
|
os.dup2(100, 0)
|
||||||
|
os.close(100)
|
||||||
|
|
||||||
|
manager = ContextManager()
|
||||||
|
manager.SetKillOnEmpty()
|
||||||
|
context = Context(manager, 'parent')
|
||||||
|
|
||||||
|
stream = context.SetStream(SlaveStream.FromFDs(context, rfd=0, wfd=1))
|
||||||
|
manager.Register(context)
|
||||||
|
|
||||||
|
try:
|
||||||
|
manager.Loop()
|
||||||
|
except StreamError, e:
|
||||||
|
syslog.syslog('exit: ' + str(e))
|
||||||
|
os.kill(-os.getpgrp(), signal.SIGKILL)
|
||||||
|
Main = classmethod(Main)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return 'ExternalContextImpl(%r)' % (self.name,)
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# Simple interface.
|
||||||
|
#
|
||||||
|
|
||||||
|
def Init(idle_secs=60*60):
|
||||||
|
'''
|
||||||
|
Initialize the simple interface.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
# Seconds to keep an unused context alive or None for infinite.
|
||||||
|
idle_secs: 3600 or None
|
||||||
|
'''
|
||||||
|
|
||||||
|
global _manager
|
||||||
|
global _manager_thread
|
||||||
|
|
||||||
|
if _manager:
|
||||||
|
return _manager
|
||||||
|
|
||||||
|
_manager = ContextManager()
|
||||||
|
_manager.SetIdleTimeout(idle_secs)
|
||||||
|
_manager_thread = threading.Thread(target=_manager.Loop)
|
||||||
|
_manager_thread.setDaemon(True)
|
||||||
|
_manager_thread.start()
|
||||||
|
atexit.register(Finalize)
|
||||||
|
return _manager
|
||||||
|
|
||||||
|
|
||||||
|
def Finalize():
|
||||||
|
global _manager
|
||||||
|
global _manager_thread
|
||||||
|
|
||||||
|
if _manager is not None:
|
||||||
|
_manager.Finalize()
|
||||||
|
_manager = None
|
||||||
|
|
||||||
|
|
||||||
|
def CallWithDeadline(hostname, username, fn, deadline, *args, **kwargs):
|
||||||
|
'''
|
||||||
|
Make a function call in the context of a remote host. Set a maximum deadline
|
||||||
|
in seconds after which it is assumed the call failed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
# Hostname or address of remote host.
|
||||||
|
hostname: str
|
||||||
|
# Username to connect as, or None for current user.
|
||||||
|
username: str or None
|
||||||
|
# Seconds until we assume the call has failed.
|
||||||
|
deadline: float or None
|
||||||
|
# The function to execute in the remote context.
|
||||||
|
fn: staticmethod or classmethod or types.FunctionType
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
# Function's return value.
|
||||||
|
object
|
||||||
|
'''
|
||||||
|
|
||||||
|
context = Init().GetRemote(hostname, username=username)
|
||||||
|
return context.CallWithDeadline(fn, deadline, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def Call(hostname, username, fn, *args, **kwargs):
|
||||||
|
'''
|
||||||
|
Like CallWithDeadline, but with no deadline.
|
||||||
|
'''
|
||||||
|
|
||||||
|
return CallWithDeadline(hostname, username, fn, None, *args, **kwargs)
|
||||||
@ -0,0 +1,117 @@
|
|||||||
|
#!/usr/bin/env python2.5
|
||||||
|
|
||||||
|
"""
|
||||||
|
def DoStuff():
|
||||||
|
import time
|
||||||
|
file('/tmp/foobar', 'w').write(time.ctime())
|
||||||
|
|
||||||
|
|
||||||
|
localhost = pyrpc.SSHConnection('localhost')
|
||||||
|
localhost.Connect()
|
||||||
|
try:
|
||||||
|
ret = localhost.Evaluate(DoStuff)
|
||||||
|
except OSError, e:
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
import econtext
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# Helper functions.
|
||||||
|
#
|
||||||
|
|
||||||
|
class GetModuleImportsTestCase(unittest.TestCase):
|
||||||
|
# This must be kept in sync with our actual imports.
|
||||||
|
IMPORTS = [
|
||||||
|
('econtext', 'econtext'),
|
||||||
|
('sys', 'PythonSystemModule'),
|
||||||
|
('sys', 'sys'),
|
||||||
|
('unittest', 'unittest')
|
||||||
|
]
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
global PythonSystemModule
|
||||||
|
import sys as PythonSystemModule
|
||||||
|
|
||||||
|
def tearDown(Self):
|
||||||
|
global PythonSystemModule
|
||||||
|
del PythonSystemModule
|
||||||
|
|
||||||
|
def testImports(self):
|
||||||
|
self.assertEqual(set(self.IMPORTS),
|
||||||
|
set(econtext.GetModuleImports(sys.modules[__name__])))
|
||||||
|
|
||||||
|
|
||||||
|
class BuildPartialModuleTestCase(unittest.TestCase):
|
||||||
|
def testNullModule(self):
|
||||||
|
"""Pass empty sequences; result should contain nothing but a hash bang line
|
||||||
|
and whitespace."""
|
||||||
|
|
||||||
|
lines = econtext.BuildPartialModule([], []).strip().split('\n')
|
||||||
|
|
||||||
|
self.assert_(lines[0].startswith('#!'))
|
||||||
|
self.assert_('import' not in lines[1:])
|
||||||
|
|
||||||
|
def testPassingMethodTypeFails(self):
|
||||||
|
"""Pass an instance method and ensure we refuse it."""
|
||||||
|
|
||||||
|
self.assertRaises(TypeError, econtext.BuildPartialModule,
|
||||||
|
[self.testPassingMethodTypeFails], [])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def exampleStaticMethod():
|
||||||
|
pass
|
||||||
|
|
||||||
|
def testStaticMethodGetsUnwrapped(self):
|
||||||
|
"Ensure that @staticmethod decorators are stripped."
|
||||||
|
|
||||||
|
dct = {}
|
||||||
|
exec econtext.BuildPartialModule([self.exampleStaticMethod], []) in dct
|
||||||
|
self.assertFalse(isinstance(dct['exampleStaticMethod'], staticmethod))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# Streams
|
||||||
|
#
|
||||||
|
|
||||||
|
class StreamTestBase:
|
||||||
|
"""This defines rules that should remain true for all Stream subclasses. We
|
||||||
|
test in this manner to guard against a subclass breaking Stream's
|
||||||
|
polymorphism (e.g. overriding a method with the wrong prototype).
|
||||||
|
|
||||||
|
def testCommandLine(self):
|
||||||
|
print self.driver.command_line
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class SSHStreamTestCase(unittest.TestCase, StreamTestBase):
|
||||||
|
DRIVER_CLASS = econtext.SSHStream
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
# Stubs.
|
||||||
|
|
||||||
|
# Instance initialization.
|
||||||
|
self.stream = econtext.SSHStream('localhost', 'test-agent')
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def testConstructor(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TCPStreamTestCase(unittest.TestCase, StreamTestBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# Run the tests.
|
||||||
|
#
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
Loading…
Reference in New Issue