|
|
|
@ -12,17 +12,14 @@ import getpass
|
|
|
|
|
import imp
|
|
|
|
|
import inspect
|
|
|
|
|
import os
|
|
|
|
|
import sched
|
|
|
|
|
import select
|
|
|
|
|
import signal
|
|
|
|
|
import socket
|
|
|
|
|
import struct
|
|
|
|
|
import subprocess
|
|
|
|
|
import sys
|
|
|
|
|
import syslog
|
|
|
|
|
import textwrap
|
|
|
|
|
import threading
|
|
|
|
|
import time
|
|
|
|
|
import traceback
|
|
|
|
|
import types
|
|
|
|
|
import zlib
|
|
|
|
@ -32,12 +29,9 @@ import zlib
|
|
|
|
|
# Module-level data.
|
|
|
|
|
#
|
|
|
|
|
|
|
|
|
|
GET_MODULE_SOURCE = 0L
|
|
|
|
|
GET_MODULE = 0L
|
|
|
|
|
CALL_FUNCTION = 1L
|
|
|
|
|
|
|
|
|
|
_manager = None
|
|
|
|
|
_manager_thread = None
|
|
|
|
|
|
|
|
|
|
DEBUG = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -50,6 +44,9 @@ class ContextError(Exception):
|
|
|
|
|
def __init__(self, fmt, *args):
|
|
|
|
|
Exception.__init__(self, fmt % args)
|
|
|
|
|
|
|
|
|
|
class ChannelError(ContextError):
|
|
|
|
|
'Raised when a channel dies or has been closed.'
|
|
|
|
|
|
|
|
|
|
class StreamError(ContextError):
|
|
|
|
|
'Raised when a stream cannot be established.'
|
|
|
|
|
|
|
|
|
@ -82,49 +79,132 @@ class PartialFunction(object):
|
|
|
|
|
return 'PartialFunction(%r, *%r)' % (self.fn, self.partial_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FunctionProxy(object):
|
|
|
|
|
__slots__ = ['_context', '_per_id']
|
|
|
|
|
class Channel(object):
|
|
|
|
|
def __init__(self, stream, handle):
|
|
|
|
|
self._stream = stream
|
|
|
|
|
self._handle = handle
|
|
|
|
|
self._wake_event = threading.Event()
|
|
|
|
|
self._queue_lock = threading.Lock()
|
|
|
|
|
self._queue = []
|
|
|
|
|
self._stream.AddHandleCB(self._InternalReceive, handle)
|
|
|
|
|
|
|
|
|
|
def __init__(self, context, per_id):
|
|
|
|
|
self._context = context
|
|
|
|
|
self._per_id = per_id
|
|
|
|
|
def _InternalReceive(self, killed, data):
|
|
|
|
|
'''
|
|
|
|
|
Callback from the stream object; appends a tuple of
|
|
|
|
|
(killed-or-closed, data) to the internal queue and wakes the internal
|
|
|
|
|
event.
|
|
|
|
|
|
|
|
|
|
def __call__(self, *args, **kwargs):
|
|
|
|
|
self._context._Call(self._per_id, args, kwargs)
|
|
|
|
|
Args:
|
|
|
|
|
# Has the Stream object lost its connection?
|
|
|
|
|
killed: bool
|
|
|
|
|
# Has the remote Channel had Close() called? / the object passed to the
|
|
|
|
|
# remote Send().
|
|
|
|
|
data: (bool, object)
|
|
|
|
|
'''
|
|
|
|
|
Log('%r._InternalReceive(%r, %r)', self, killed, data)
|
|
|
|
|
self._queue_lock.acquire()
|
|
|
|
|
try:
|
|
|
|
|
self._queue.append((killed or data[0], data[1]))
|
|
|
|
|
self._wake_event.set()
|
|
|
|
|
finally:
|
|
|
|
|
self._queue_lock.release()
|
|
|
|
|
|
|
|
|
|
def Close(self):
|
|
|
|
|
'''
|
|
|
|
|
Indicate this channel is closed to the remote side.
|
|
|
|
|
'''
|
|
|
|
|
Log('%r.Close()', self)
|
|
|
|
|
self._stream.Enqueue(handle, (True, None))
|
|
|
|
|
|
|
|
|
|
def Send(self, data):
|
|
|
|
|
'''
|
|
|
|
|
Send the given object to the remote side.
|
|
|
|
|
'''
|
|
|
|
|
Log('%r.Send(%r)', self, data)
|
|
|
|
|
self._stream.Enqueue(handle, (False, data))
|
|
|
|
|
|
|
|
|
|
def Receive(self, timeout=None):
|
|
|
|
|
'''
|
|
|
|
|
Receive the next object to arrive on this channel, or return if the
|
|
|
|
|
optional timeout is reached.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
timeout: float
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
object
|
|
|
|
|
'''
|
|
|
|
|
Log('%r.Receive(%r)', self, timeout)
|
|
|
|
|
if not self._queue:
|
|
|
|
|
self._wake_event.wait(timeout)
|
|
|
|
|
if not self._wake_event.isSet():
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
self._queue_lock.acquire()
|
|
|
|
|
try:
|
|
|
|
|
self._wake_event.clear()
|
|
|
|
|
Log('%r.Receive() queue is %r', self, self._queue)
|
|
|
|
|
closed, data = self._queue.pop(0)
|
|
|
|
|
Log('%r.Receive() got closed=%r, data=%r', self, closed, data)
|
|
|
|
|
if closed:
|
|
|
|
|
raise ChannelError('Channel is closed.')
|
|
|
|
|
return data
|
|
|
|
|
finally:
|
|
|
|
|
self._queue_lock.release()
|
|
|
|
|
|
|
|
|
|
def __iter__(self):
|
|
|
|
|
'''
|
|
|
|
|
Return an iterator that yields objects arriving on this channel, until the
|
|
|
|
|
channel is closed.
|
|
|
|
|
'''
|
|
|
|
|
while True:
|
|
|
|
|
try:
|
|
|
|
|
yield self.Receive()
|
|
|
|
|
except ChannelError:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
|
return 'econtext.Channel(%r, %r)' % (self._stream, self._handle)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SlaveModuleImporter(object):
|
|
|
|
|
'''
|
|
|
|
|
This objects implements the import hook protocol defined in
|
|
|
|
|
http://www.python.org/dev/peps/pep-0302/; the interpreter will ask it if it
|
|
|
|
|
knows how to load each module, it will in turn ask the interpreter if it
|
|
|
|
|
knows how to do the load, and if so, it will say it can't. This round about
|
|
|
|
|
crap is necessary because the module import mechanism is brutal.
|
|
|
|
|
|
|
|
|
|
When the built in importer can't load a module, we try requesting it from the
|
|
|
|
|
parent context.
|
|
|
|
|
This implements the import protocol described in PEP 302. It works like so:
|
|
|
|
|
|
|
|
|
|
- Python asks it if it can import a module.
|
|
|
|
|
- It asks Python (via imp module) if it can import the module.
|
|
|
|
|
- If Python says yes, it says no.
|
|
|
|
|
- If Python says no, it asks the parent context for the module.
|
|
|
|
|
- If the module isn't returned by the parent, asplode, otherwise ask Python
|
|
|
|
|
to load the returned module.
|
|
|
|
|
|
|
|
|
|
This roundabout crap is necessary because the built-in importer is tried only
|
|
|
|
|
after custom hooks are. A class method is provided for the parent context to
|
|
|
|
|
satisfy the module request; it will only return modules that have been loaded
|
|
|
|
|
in the parent context.
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
def __init__(self, context):
|
|
|
|
|
self._context = context
|
|
|
|
|
|
|
|
|
|
def find_module(self, fullname, path=None):
|
|
|
|
|
if imp.find_module(fullname):
|
|
|
|
|
return
|
|
|
|
|
return self
|
|
|
|
|
if not imp.find_module(fullname):
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|
def load_module(self, fullname):
|
|
|
|
|
kind, data = self._context.
|
|
|
|
|
kind, data = self._context.EnqueueAwaitReply(GET_MODULE, fullname)
|
|
|
|
|
|
|
|
|
|
def GetModule(cls, fullname):
|
|
|
|
|
if fullname in sys.modules:
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
#
|
|
|
|
|
# Stream implementations.
|
|
|
|
|
#
|
|
|
|
|
|
|
|
|
|
class Stream(object):
|
|
|
|
|
def __init__(self, context, secure_unpickler=True):
|
|
|
|
|
def __init__(self, context):
|
|
|
|
|
self._context = context
|
|
|
|
|
self._sched_id = 0.0
|
|
|
|
|
self._alive = True
|
|
|
|
|
|
|
|
|
|
self._input_buf = self._output_buf = ''
|
|
|
|
@ -146,12 +226,24 @@ class Stream(object):
|
|
|
|
|
self._unpickler = cPickle.Unpickler(self._unpickler_file)
|
|
|
|
|
self._unpickler.persistent_load = self._LoadFunctionFromPerID
|
|
|
|
|
|
|
|
|
|
if secure_unpickler:
|
|
|
|
|
self._permitted_modules = {}
|
|
|
|
|
self._unpickler.find_global = self._FindGlobal
|
|
|
|
|
|
|
|
|
|
# Pickler/Unpickler support.
|
|
|
|
|
|
|
|
|
|
def Pickle(self, obj):
|
|
|
|
|
self._pickler.dump(obj)
|
|
|
|
|
data = self._pickler_file.getvalue()
|
|
|
|
|
self._pickler_file.seek(0)
|
|
|
|
|
self._pickler_file.truncate(0)
|
|
|
|
|
return data
|
|
|
|
|
|
|
|
|
|
def Unpickle(self, data):
|
|
|
|
|
Log('%r.Unpickle(%r)', self, data)
|
|
|
|
|
self._unpickler_file.write(data)
|
|
|
|
|
self._unpickler_file.seek(0)
|
|
|
|
|
data = self._unpickler.load()
|
|
|
|
|
self._unpickler_file.seek(0)
|
|
|
|
|
self._unpickler_file.truncate(0)
|
|
|
|
|
return data
|
|
|
|
|
|
|
|
|
|
def _CheckFunctionPerID(self, obj):
|
|
|
|
|
'''
|
|
|
|
|
Please see the cPickle documentation. Given an object, return None
|
|
|
|
@ -163,7 +255,6 @@ class Stream(object):
|
|
|
|
|
Returns:
|
|
|
|
|
str
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
if isinstance(obj, (types.FunctionType, types.MethodType)):
|
|
|
|
|
pid = 'FUNC:' + repr(obj)
|
|
|
|
|
self._func_refs[per_id] = obj
|
|
|
|
@ -180,48 +271,19 @@ class Stream(object):
|
|
|
|
|
Returns:
|
|
|
|
|
object
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
if not pid.startswith('FUNC:'):
|
|
|
|
|
raise CorruptMessageError('unrecognized persistent ID received: %r', pid)
|
|
|
|
|
return FunctionProxy(self, pid)
|
|
|
|
|
|
|
|
|
|
def _FindGlobal(self, module_name, class_name):
|
|
|
|
|
'''
|
|
|
|
|
Please see the cPickle documentation. Given a module and class name,
|
|
|
|
|
determine whether class referred to is safe for unpickling.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
module_name: str
|
|
|
|
|
class_name: str
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
classobj or type
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
if module_name not in self._permitted_modules:
|
|
|
|
|
raise StreamError('context %r attempted to unpickle %r in module %r',
|
|
|
|
|
self._context, class_name, module_name)
|
|
|
|
|
return getattr(sys.modules[module_name], class_name)
|
|
|
|
|
|
|
|
|
|
def AllowModule(self, module_name):
|
|
|
|
|
'''
|
|
|
|
|
Add the given module to the list of permitted modules.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
module_name: str
|
|
|
|
|
'''
|
|
|
|
|
self._permitted_modules.add(module_name)
|
|
|
|
|
return PartialFunction(self._CallPersistentWhatsit, pid)
|
|
|
|
|
|
|
|
|
|
# I/O.
|
|
|
|
|
|
|
|
|
|
def AllocHandle(self):
|
|
|
|
|
'''
|
|
|
|
|
Allocate a unique communications handle for this stream.
|
|
|
|
|
Allocate a unique handle for this stream.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
long
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
self._handle_lock.acquire()
|
|
|
|
|
try:
|
|
|
|
|
self._last_handle += 1L
|
|
|
|
@ -239,7 +301,6 @@ class Stream(object):
|
|
|
|
|
handle: long
|
|
|
|
|
persist: bool
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
Log('%r.AddHandleCB(%r, %r, persist=%r)', self, fn, handle, persist)
|
|
|
|
|
self._handle_lock.acquire()
|
|
|
|
|
try:
|
|
|
|
@ -252,7 +313,7 @@ class Stream(object):
|
|
|
|
|
Handle the next complete message on the stream. Raise CorruptMessageError
|
|
|
|
|
or IOError on failure.
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
Log('%r.Receive()', self)
|
|
|
|
|
chunk = os.read(self._rfd, 4096)
|
|
|
|
|
if not chunk:
|
|
|
|
|
raise StreamError('remote side hung up.')
|
|
|
|
@ -266,12 +327,8 @@ class Stream(object):
|
|
|
|
|
if buffer_len < msg_len-4:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
Log('%r.Receive() -> msg_len=%d; msg=%r', self, msg_len,
|
|
|
|
|
self._input_buf[4:msg_len+4])
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
# TODO: wire in the per-instance unpickler.
|
|
|
|
|
handle, data = cPickle.loads(self._input_buf[4:msg_len+4])
|
|
|
|
|
handle, data = self.Unpickle(self._input_buf[4:msg_len+4])
|
|
|
|
|
self._input_buf = self._input_buf[msg_len+4:]
|
|
|
|
|
handle = long(handle)
|
|
|
|
|
|
|
|
|
@ -284,13 +341,13 @@ class Stream(object):
|
|
|
|
|
except (TypeError, ValueError), ex:
|
|
|
|
|
raise CorruptMessageError('%r got invalid message: %s', self, ex)
|
|
|
|
|
|
|
|
|
|
fn(handle, False, data)
|
|
|
|
|
fn(False, data)
|
|
|
|
|
|
|
|
|
|
def Transmit(self):
|
|
|
|
|
'''
|
|
|
|
|
Transmit pending messages. Raises IOError on failure.
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
Log('%r.Transmit()', self)
|
|
|
|
|
written = os.write(self._wfd, self._output_buf[:4096])
|
|
|
|
|
self._output_buf = self._output_buf[written:]
|
|
|
|
|
if self._context and not self._output_buf:
|
|
|
|
@ -300,24 +357,17 @@ class Stream(object):
|
|
|
|
|
'''
|
|
|
|
|
Called to handle disconnects.
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
Log('%r.Disconnect()', self)
|
|
|
|
|
|
|
|
|
|
for fd in (self._rfd, self._wfd):
|
|
|
|
|
try:
|
|
|
|
|
os.close(fd)
|
|
|
|
|
Log('%r.Disconnect(): closed fd %d', self, fd)
|
|
|
|
|
except OSError:
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
# Invoke each registered non-persistent handle callback to indicate the
|
|
|
|
|
# connection has been destroyed. This prevents pending RPCs from hanging
|
|
|
|
|
# infinitely.
|
|
|
|
|
os.close(fd)
|
|
|
|
|
|
|
|
|
|
# Invoke each registered callback to indicate the connection has been
|
|
|
|
|
# destroyed. This prevents pending Channels/RPCs from hanging forever.
|
|
|
|
|
for handle, (persist, fn) in self._handle_map.iteritems():
|
|
|
|
|
if not persist:
|
|
|
|
|
Log('%r.Disconnect(): killing stale callback handle=%r; fn=%r',
|
|
|
|
|
self, handle, fn)
|
|
|
|
|
fn(handle, True, None)
|
|
|
|
|
Log('%r.Disconnect(): killing stale callback handle=%r; fn=%r',
|
|
|
|
|
self, handle, fn)
|
|
|
|
|
fn(True, None)
|
|
|
|
|
|
|
|
|
|
self._context.manager.UpdateStreamIOState(self)
|
|
|
|
|
|
|
|
|
@ -328,8 +378,6 @@ class Stream(object):
|
|
|
|
|
Returns:
|
|
|
|
|
(alive, input_fd, output_fd, has_output_buffered)
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
# TODO: this alive flag is stupid.
|
|
|
|
|
return self._alive, self._rfd, self._wfd, bool(self._output_buf)
|
|
|
|
|
|
|
|
|
|
def Enqueue(self, handle, data):
|
|
|
|
@ -337,8 +385,7 @@ class Stream(object):
|
|
|
|
|
|
|
|
|
|
self._output_buf_lock.acquire()
|
|
|
|
|
try:
|
|
|
|
|
# TODO: wire in the per-instance pickler.
|
|
|
|
|
encoded = cPickle.dumps((handle, data))
|
|
|
|
|
encoded = self.Pickle((handle, data))
|
|
|
|
|
self._output_buf += struct.pack('>L', len(encoded)) + encoded
|
|
|
|
|
finally:
|
|
|
|
|
self._output_buf_lock.release()
|
|
|
|
@ -346,55 +393,63 @@ class Stream(object):
|
|
|
|
|
|
|
|
|
|
# Misc.
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def FromFDs(cls, context, rfd, wfd):
|
|
|
|
|
Log('%r.FromFDs(%r, %r, %r)', cls, context, rfd, wfd)
|
|
|
|
|
self = cls(context)
|
|
|
|
|
self._rfd, self._wfd = rfd, wfd
|
|
|
|
|
return self
|
|
|
|
|
FromFDs = classmethod(FromFDs)
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
|
return 'econtext.%s(<context=%r>)' %\
|
|
|
|
|
(self.__class__.__name__, self._context)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SlaveStream(Stream):
|
|
|
|
|
def __init__(self, context, secure_unpickler=True):
|
|
|
|
|
super(SlaveStream, self).__init__(context, secure_unpickler)
|
|
|
|
|
self.AddHandleCB(self._CallFunction, handle=CALL_FUNCTION)
|
|
|
|
|
|
|
|
|
|
def _CallFunction(self, handle, killed, data):
|
|
|
|
|
Log('%r._CallFunction(%r, %r)', self, handle, data)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
reply_handle, mod_name, func_name, args, kwargs = data
|
|
|
|
|
try:
|
|
|
|
|
module = __import__(mod_name)
|
|
|
|
|
except ImportError:
|
|
|
|
|
raise # TODO: module source callback.
|
|
|
|
|
# (success, data)
|
|
|
|
|
self.Enqueue(reply_handle,
|
|
|
|
|
(True, getattr(module, func_name)(*args, **kwargs)))
|
|
|
|
|
except Exception, e:
|
|
|
|
|
self.Enqueue(reply_handle, (False, (e, traceback.extract_stack())))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LocalStream(Stream):
|
|
|
|
|
"""
|
|
|
|
|
'''
|
|
|
|
|
Base for streams capable of starting new slaves.
|
|
|
|
|
"""
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
python_path = property(
|
|
|
|
|
lambda self: getattr(self, '_python_path', sys.executable),
|
|
|
|
|
lambda self, path: setattr(self, '_python_path', path),
|
|
|
|
|
doc='The path to the remote Python interpreter.')
|
|
|
|
|
|
|
|
|
|
def _GetModuleSource(self, name):
|
|
|
|
|
return inspect.getsource(sys.modules[name])
|
|
|
|
|
def _GetModuleSource(self, killed, name):
|
|
|
|
|
if not killed:
|
|
|
|
|
return inspect.getsource(sys.modules[name])
|
|
|
|
|
|
|
|
|
|
def __init__(self, context, secure_unpickler=True):
|
|
|
|
|
super(LocalStream, self).__init__(context, secure_unpickler)
|
|
|
|
|
self.AddHandleCB(self._GetModuleSource, handle=GET_MODULE_SOURCE)
|
|
|
|
|
def __init__(self, context):
|
|
|
|
|
super(LocalStream, self).__init__(context)
|
|
|
|
|
self._permitted_modules = {}
|
|
|
|
|
self._unpickler.find_global = self._FindGlobal
|
|
|
|
|
self.AddHandleCB(self._GetModuleSource, handle=GET_MODULE)
|
|
|
|
|
|
|
|
|
|
def _FindGlobal(self, module_name, class_name):
|
|
|
|
|
'''
|
|
|
|
|
See the cPickle documentation: given a module and class name, determine
|
|
|
|
|
whether class referred to is safe for unpickling.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
module_name: str
|
|
|
|
|
class_name: str
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
classobj or type
|
|
|
|
|
'''
|
|
|
|
|
if module_name not in self._permitted_modules:
|
|
|
|
|
raise StreamError('context %r attempted to unpickle %r in module %r',
|
|
|
|
|
self._context, class_name, module_name)
|
|
|
|
|
return getattr(sys.modules[module_name], class_name)
|
|
|
|
|
|
|
|
|
|
def AllowModule(self, module_name):
|
|
|
|
|
'''
|
|
|
|
|
Add the given module to the list of permitted modules.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
module_name: str
|
|
|
|
|
'''
|
|
|
|
|
self._permitted_modules.add(module_name)
|
|
|
|
|
|
|
|
|
|
# Hexed and passed to 'python -c'. It forks, dups 0->100, creates a pipe,
|
|
|
|
|
# then execs a new interpreter with a custom argv. CONTEXT_NAME is replaced
|
|
|
|
@ -441,7 +496,7 @@ class LocalStream(Stream):
|
|
|
|
|
self, self._wfd, self._rfd)
|
|
|
|
|
|
|
|
|
|
source = inspect.getsource(sys.modules[__name__])
|
|
|
|
|
source += '\nExternalContextImpl.Main(%r)\n' % (self._context.name,)
|
|
|
|
|
source += '\nExternalContextMain(%r)\n' % (self._context.name,)
|
|
|
|
|
compressed = zlib.compress(source)
|
|
|
|
|
|
|
|
|
|
preamble = str(len(compressed)) + '\n' + compressed
|
|
|
|
@ -470,9 +525,9 @@ class SSHStream(LocalStream):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Context(object):
|
|
|
|
|
"""
|
|
|
|
|
'''
|
|
|
|
|
Represents a remote context regardless of current connection method.
|
|
|
|
|
"""
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
def __init__(self, manager, name=None, hostname=None, username=None):
|
|
|
|
|
self.manager = manager
|
|
|
|
@ -489,34 +544,44 @@ class Context(object):
|
|
|
|
|
self._stream = stream
|
|
|
|
|
return stream
|
|
|
|
|
|
|
|
|
|
def CallWithDeadline(self, fn, deadline, *args, **kwargs):
|
|
|
|
|
Log('%r.CallWithDeadline(%r, %r, *%r, **%r)', self, fn, deadline, args,
|
|
|
|
|
kwargs)
|
|
|
|
|
handle = self._stream.AllocHandle()
|
|
|
|
|
def EnqueueAwaitReply(self, handle, deadline, data):
|
|
|
|
|
'''
|
|
|
|
|
Send a message to the given handle and wait for a response with an optional
|
|
|
|
|
timeout. The message contains (reply_handle, data), where reply_handle is
|
|
|
|
|
the handle on which this function expects its reply.
|
|
|
|
|
'''
|
|
|
|
|
Log('%r.EnqueueAwaitReply(%r, %r, %r)', self, handle, deadline, data)
|
|
|
|
|
reply_handle = self._stream.AllocHandle()
|
|
|
|
|
reply_event = threading.Event()
|
|
|
|
|
container = []
|
|
|
|
|
|
|
|
|
|
def _Receive(handle, killed, data):
|
|
|
|
|
Log('%r._Receive(%r, %r, %r)', self, handle, killed, data)
|
|
|
|
|
def _Receive(killed, data):
|
|
|
|
|
Log('%r._Receive(%r, %r)', self, killed, data)
|
|
|
|
|
container.extend([killed, data])
|
|
|
|
|
reply_event.set()
|
|
|
|
|
|
|
|
|
|
self._stream.AddHandleCB(_Receive, handle, persist=False)
|
|
|
|
|
call = (handle, fn.__module__, fn.__name__, args, kwargs)
|
|
|
|
|
self._stream.Enqueue(CALL_FUNCTION, call)
|
|
|
|
|
self._stream.AddHandleCB(_Receive, reply_handle, persist=False)
|
|
|
|
|
self._stream.Enqueue(CALL_FUNCTION, (False, (reply_handle,) + data))
|
|
|
|
|
|
|
|
|
|
reply_event.wait(deadline)
|
|
|
|
|
if not reply_event.isSet():
|
|
|
|
|
self.Disconnect()
|
|
|
|
|
raise TimeoutError('deadline exceeded.')
|
|
|
|
|
|
|
|
|
|
Log('%r._Receive(): got reply, container is %r', self, container)
|
|
|
|
|
killed, data = container
|
|
|
|
|
|
|
|
|
|
if killed:
|
|
|
|
|
raise StreamError('lost connection during call.')
|
|
|
|
|
|
|
|
|
|
success, result = data
|
|
|
|
|
Log('%r._EnqueueAwaitReply(): got reply: %r', self, data)
|
|
|
|
|
return data
|
|
|
|
|
|
|
|
|
|
def CallWithDeadline(self, fn, deadline, *args, **kwargs):
|
|
|
|
|
Log('%r.CallWithDeadline(%r, %r, *%r, **%r)', self, fn, deadline, args,
|
|
|
|
|
kwargs)
|
|
|
|
|
|
|
|
|
|
call = (fn.__module__, fn.__name__, args, kwargs)
|
|
|
|
|
success, result = self.EnqueueAwaitReply(CALL_FUNCTION, deadline, call)
|
|
|
|
|
|
|
|
|
|
if success:
|
|
|
|
|
return result
|
|
|
|
|
else:
|
|
|
|
@ -527,10 +592,6 @@ class Context(object):
|
|
|
|
|
def Call(self, fn, *args, **kwargs):
|
|
|
|
|
return self.CallWithDeadline(fn, None, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
def Kill(self, deadline=30):
|
|
|
|
|
self.CallWithDeadline(os.kill, deadline,
|
|
|
|
|
-self.Call(os.getpgrp), signal.SIGTERM)
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
|
bits = map(repr, filter(None, [self.name, self.hostname, self.username]))
|
|
|
|
|
return 'Context(%s)' % ', '.join(bits)
|
|
|
|
@ -543,11 +604,6 @@ class ContextManager(object):
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
|
self._scheduler = sched.scheduler(time.time, self.OneShot)
|
|
|
|
|
self._idle_timeout = 0
|
|
|
|
|
self._dead = False
|
|
|
|
|
self._kill_on_empty = False
|
|
|
|
|
|
|
|
|
|
self._poller = select.poll()
|
|
|
|
|
self._poller_fd_map = {}
|
|
|
|
|
|
|
|
|
@ -560,19 +616,15 @@ class ContextManager(object):
|
|
|
|
|
self._wake_rfd, self._wake_wfd = os.pipe()
|
|
|
|
|
self._poller.register(self._wake_rfd)
|
|
|
|
|
|
|
|
|
|
def SetKillOnEmpty(self, kill_on_empty=True):
|
|
|
|
|
'''
|
|
|
|
|
Indicate the main loop should exit when there are no remaining sessions
|
|
|
|
|
open.
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
self._kill_on_empty = kill_on_empty
|
|
|
|
|
self._thread = threading.Thread(target=self.Loop, name='ContextManager')
|
|
|
|
|
self._thread.setDaemon(True)
|
|
|
|
|
self._thread.start()
|
|
|
|
|
self._dead = False
|
|
|
|
|
|
|
|
|
|
def Register(self, context):
|
|
|
|
|
'''
|
|
|
|
|
Put a context under control of this manager.
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
self._contexts_lock.acquire()
|
|
|
|
|
try:
|
|
|
|
|
self._contexts[context.name] = context
|
|
|
|
@ -590,21 +642,19 @@ class ContextManager(object):
|
|
|
|
|
Returns:
|
|
|
|
|
econtext.Context
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
context = Context(self, name)
|
|
|
|
|
context.SetStream(LocalStream(context)).Connect()
|
|
|
|
|
return self.Register(context)
|
|
|
|
|
|
|
|
|
|
def GetRemote(self, hostname, name=None, username=None):
|
|
|
|
|
def GetRemote(self, hostname, username=None, name=None):
|
|
|
|
|
'''
|
|
|
|
|
Return the named remote context, or create it if it doesn't exist.
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
if username is None:
|
|
|
|
|
username = getpass.getuser()
|
|
|
|
|
if name is None:
|
|
|
|
|
name = 'econtext[%s@%s:%d]' %\
|
|
|
|
|
(getpass.getuser(), socket.gethostname(), os.getpid())
|
|
|
|
|
(getpass.getuser(), os.getenv('HOSTNAME'), os.getpid())
|
|
|
|
|
|
|
|
|
|
context = Context(self, name, hostname, username)
|
|
|
|
|
context.SetStream(SSHStream(context)).Connect()
|
|
|
|
@ -612,22 +662,17 @@ class ContextManager(object):
|
|
|
|
|
|
|
|
|
|
def UpdateStreamIOState(self, stream):
|
|
|
|
|
'''
|
|
|
|
|
Update the manager's internal state regarding the specified stream. This
|
|
|
|
|
marks its FDs for polling as appropriate, and resets its idle counter.
|
|
|
|
|
Update the manager's internal state regarding the specified stream by
|
|
|
|
|
marking its FDs for polling as appropriate.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
stream: econtext.Stream
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
Log('%r.UpdateStreamIOState(%r)', self, stream)
|
|
|
|
|
|
|
|
|
|
self._poller_changes_lock.acquire()
|
|
|
|
|
try:
|
|
|
|
|
self._poller_changes[stream] = None
|
|
|
|
|
if self._idle_timeout:
|
|
|
|
|
if stream._sched_id:
|
|
|
|
|
self._scheduler.cancel(stream._sched_id)
|
|
|
|
|
self._scheduler.enter(self._idle_timeout, 0, stream.Disconnect, ())
|
|
|
|
|
finally:
|
|
|
|
|
self._poller_changes_lock.release()
|
|
|
|
|
os.write(self._wake_wfd, ' ')
|
|
|
|
@ -638,7 +683,6 @@ class ContextManager(object):
|
|
|
|
|
UpdateStreamIOState. Poller registration updates must be done in serial
|
|
|
|
|
with calls to its poll() method.
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
Log('%r._DoChangedStreams()', self)
|
|
|
|
|
|
|
|
|
|
self._poller_changes_lock.acquire()
|
|
|
|
@ -651,20 +695,11 @@ class ContextManager(object):
|
|
|
|
|
for stream in changes:
|
|
|
|
|
alive, ifd, ofd, has_output = stream.GetIOState()
|
|
|
|
|
|
|
|
|
|
if not alive: # no fd = closed stream.
|
|
|
|
|
Log('here2')
|
|
|
|
|
if not alive:
|
|
|
|
|
for fd in (ifd, ofd):
|
|
|
|
|
try:
|
|
|
|
|
self._poller.unregister(fd)
|
|
|
|
|
Log('unregistered fd=%d from poller', fd)
|
|
|
|
|
except KeyError:
|
|
|
|
|
Log('failed to unregister fd=%d from poller', fd)
|
|
|
|
|
try:
|
|
|
|
|
del self._poller_fd_map[fd]
|
|
|
|
|
Log('unregistered fd=%d from poller map', fd)
|
|
|
|
|
except KeyError:
|
|
|
|
|
Log('failed to unregister fd=%d from poller map', fd)
|
|
|
|
|
del self._poller_fd_map[fd]
|
|
|
|
|
del self._contexts[stream._context]
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if has_output:
|
|
|
|
|
self._poller.register(ofd, select.POLLOUT)
|
|
|
|
@ -676,74 +711,38 @@ class ContextManager(object):
|
|
|
|
|
self._poller.register(ifd, select.POLLIN)
|
|
|
|
|
self._poller_fd_map[ifd] = stream
|
|
|
|
|
|
|
|
|
|
def OneShot(self, timeout=None):
|
|
|
|
|
'''
|
|
|
|
|
Poll once for I/O and return after all processing is complete, optionally
|
|
|
|
|
terminating after some number of seconds.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
timeout: int or float
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
if timeout == 0: # scheduler behaviour we don't require.
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
Log('%r.OneShot(): _poller_fd_map=%r', self, self._poller_fd_map)
|
|
|
|
|
|
|
|
|
|
for fd, event in self._poller.poll(timeout):
|
|
|
|
|
if fd == self._wake_rfd:
|
|
|
|
|
Log('%r: got event on wake_rfd=%d.', self, self._wake_rfd)
|
|
|
|
|
os.read(self._wake_rfd, 1)
|
|
|
|
|
self._DoChangedStreams()
|
|
|
|
|
break
|
|
|
|
|
elif event & select.POLLHUP:
|
|
|
|
|
Log('%r: POLLHUP on %d; calling %r', self, fd,
|
|
|
|
|
self._poller_fd_map[fd].Disconnect)
|
|
|
|
|
self._poller_fd_map[fd].Disconnect()
|
|
|
|
|
elif event & select.POLLIN:
|
|
|
|
|
Log('%r: POLLIN on %d; calling %r', self, fd,
|
|
|
|
|
self._poller_fd_map[fd].Receive)
|
|
|
|
|
self._poller_fd_map[fd].Receive()
|
|
|
|
|
elif event & select.POLLOUT:
|
|
|
|
|
Log('%r: POLLOUT on %d; calling %r', self, fd,
|
|
|
|
|
self._poller_fd_map[fd].Transmit)
|
|
|
|
|
self._poller_fd_map[fd].Transmit()
|
|
|
|
|
elif event & select.POLLNVAL:
|
|
|
|
|
# GAY
|
|
|
|
|
self._poller.unregister(fd)
|
|
|
|
|
|
|
|
|
|
def Loop(self):
|
|
|
|
|
'''
|
|
|
|
|
Handle stream events until Finalize() is called.
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
while (not self._dead) or (self._kill_on_empty and not self._contexts):
|
|
|
|
|
# TODO: why the fuck is self._scheduler.empty() returning True?!
|
|
|
|
|
if not len(self._scheduler.queue):
|
|
|
|
|
self.OneShot()
|
|
|
|
|
else:
|
|
|
|
|
Log('self._scheduler.empty() -> %r', self._scheduler.empty())
|
|
|
|
|
Log('not not self._scheduler.queue -> %r',
|
|
|
|
|
not not self._scheduler.queue)
|
|
|
|
|
Log('%r._scheduler.run() -> %r', self, self._scheduler.queue)
|
|
|
|
|
raise SystemExit
|
|
|
|
|
self._scheduler.run()
|
|
|
|
|
|
|
|
|
|
def SetIdleTimeout(self, timeout):
|
|
|
|
|
'''
|
|
|
|
|
Set the number of seconds after which an idle stream connected to a remote
|
|
|
|
|
context is eligible for disconnection.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
timeout: int or float
|
|
|
|
|
'''
|
|
|
|
|
self._idle_timeout = timeout
|
|
|
|
|
while not self._dead:
|
|
|
|
|
Log('%r.Loop(): %r', self, self._poller_fd_map)
|
|
|
|
|
for fd, event in self._poller.poll():
|
|
|
|
|
if fd == self._wake_rfd:
|
|
|
|
|
Log('%r: got event on wake_rfd=%d.', self, self._wake_rfd)
|
|
|
|
|
os.read(self._wake_rfd, 1)
|
|
|
|
|
self._DoChangedStreams()
|
|
|
|
|
elif event & select.POLLHUP:
|
|
|
|
|
Log('%r: POLLHUP on %d; calling %r', self, fd,
|
|
|
|
|
self._poller_fd_map[fd].Disconnect)
|
|
|
|
|
self._poller_fd_map[fd].Disconnect()
|
|
|
|
|
elif event & select.POLLIN:
|
|
|
|
|
Log('%r: POLLIN on %d; calling %r', self, fd,
|
|
|
|
|
self._poller_fd_map[fd].Receive)
|
|
|
|
|
self._poller_fd_map[fd].Receive()
|
|
|
|
|
elif event & select.POLLOUT:
|
|
|
|
|
Log('%r: POLLOUT on %d', self, fd)
|
|
|
|
|
Log('%r: POLLOUT on %d; calling %r', self, fd,
|
|
|
|
|
self._poller_fd_map[fd].Transmit)
|
|
|
|
|
self._poller_fd_map[fd].Transmit()
|
|
|
|
|
elif event & select.POLLNVAL:
|
|
|
|
|
Log('%r: POLLNVAL for %d, unregistering it.', self, fd)
|
|
|
|
|
self._poller.unregister(fd)
|
|
|
|
|
|
|
|
|
|
def Finalize(self):
|
|
|
|
|
'''
|
|
|
|
|
Tell all active streams to disconnect.
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
self._dead = True
|
|
|
|
|
self._contexts_lock.acquire()
|
|
|
|
|
try:
|
|
|
|
@ -758,100 +757,28 @@ class ContextManager(object):
|
|
|
|
|
return 'econtext.ContextManager(<contexts=%s>)' % (self._contexts.keys(),)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ExternalContextImpl(object):
|
|
|
|
|
def Main(cls, context_name):
|
|
|
|
|
assert os.wait()[1] == 0, 'first stage did not exit cleanly.'
|
|
|
|
|
def ExternalContextMain(context_name):
|
|
|
|
|
Log('ExternalContextMain(%r)', context_name)
|
|
|
|
|
assert os.wait()[1] == 0, 'first stage did not exit cleanly.'
|
|
|
|
|
|
|
|
|
|
syslog.openlog('%s:%s' % (getpass.getuser(), context_name), syslog.LOG_PID)
|
|
|
|
|
syslog.openlog('%s:%s' % (getpass.getuser(), context_name), syslog.LOG_PID)
|
|
|
|
|
syslog.syslog('initializing (parent=%s)' % (os.getenv('SSH_CLIENT'),))
|
|
|
|
|
|
|
|
|
|
parent_host = os.getenv('SSH_CLIENT')
|
|
|
|
|
syslog.syslog('initializing (parent_host=%s)' % (parent_host,))
|
|
|
|
|
os.dup2(100, 0)
|
|
|
|
|
os.close(100)
|
|
|
|
|
|
|
|
|
|
os.dup2(100, 0)
|
|
|
|
|
os.close(100)
|
|
|
|
|
manager = ContextManager()
|
|
|
|
|
context = Context(manager, 'parent')
|
|
|
|
|
|
|
|
|
|
manager = ContextManager()
|
|
|
|
|
manager.SetKillOnEmpty()
|
|
|
|
|
context = Context(manager, 'parent')
|
|
|
|
|
stream = context.SetStream(Stream.FromFDs(context, rfd=0, wfd=1))
|
|
|
|
|
manager.Register(context)
|
|
|
|
|
|
|
|
|
|
stream = context.SetStream(SlaveStream.FromFDs(context, rfd=0, wfd=1))
|
|
|
|
|
manager.Register(context)
|
|
|
|
|
for call_info in Channel(stream, CALL_FUNCTION):
|
|
|
|
|
Log('ExternalContextMain(): CALL_FUNCTION %r', call_info)
|
|
|
|
|
reply_handle, mod_name, func_name, args, kwargs = call_info
|
|
|
|
|
fn = getattr(__import__(mod_name), func_name)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
manager.Loop()
|
|
|
|
|
except StreamError, e:
|
|
|
|
|
syslog.syslog('exit: ' + str(e))
|
|
|
|
|
os.kill(-os.getpgrp(), signal.SIGKILL)
|
|
|
|
|
Main = classmethod(Main)
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
|
return 'ExternalContextImpl(%r)' % (self.name,)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#
|
|
|
|
|
# Simple interface.
|
|
|
|
|
#
|
|
|
|
|
|
|
|
|
|
def Init(idle_secs=60*60):
|
|
|
|
|
'''
|
|
|
|
|
Initialize the simple interface.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
# Seconds to keep an unused context alive or None for infinite.
|
|
|
|
|
idle_secs: 3600 or None
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
global _manager
|
|
|
|
|
global _manager_thread
|
|
|
|
|
|
|
|
|
|
if _manager:
|
|
|
|
|
return _manager
|
|
|
|
|
|
|
|
|
|
_manager = ContextManager()
|
|
|
|
|
_manager.SetIdleTimeout(idle_secs)
|
|
|
|
|
_manager_thread = threading.Thread(target=_manager.Loop)
|
|
|
|
|
_manager_thread.setDaemon(True)
|
|
|
|
|
_manager_thread.start()
|
|
|
|
|
atexit.register(Finalize)
|
|
|
|
|
return _manager
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def Finalize():
|
|
|
|
|
global _manager
|
|
|
|
|
global _manager_thread
|
|
|
|
|
|
|
|
|
|
if _manager is not None:
|
|
|
|
|
_manager.Finalize()
|
|
|
|
|
_manager = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def CallWithDeadline(hostname, username, fn, deadline, *args, **kwargs):
|
|
|
|
|
'''
|
|
|
|
|
Make a function call in the context of a remote host. Set a maximum deadline
|
|
|
|
|
in seconds after which it is assumed the call failed.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
# Hostname or address of remote host.
|
|
|
|
|
hostname: str
|
|
|
|
|
# Username to connect as, or None for current user.
|
|
|
|
|
username: str or None
|
|
|
|
|
# Seconds until we assume the call has failed.
|
|
|
|
|
deadline: float or None
|
|
|
|
|
# The function to execute in the remote context.
|
|
|
|
|
fn: staticmethod or classmethod or types.FunctionType
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
# Function's return value.
|
|
|
|
|
object
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
context = Init().GetRemote(hostname, username=username)
|
|
|
|
|
return context.CallWithDeadline(fn, deadline, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def Call(hostname, username, fn, *args, **kwargs):
|
|
|
|
|
'''
|
|
|
|
|
Like CallWithDeadline, but with no deadline.
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
return CallWithDeadline(hostname, username, fn, None, *args, **kwargs)
|
|
|
|
|
stream.Enqueue(reply_handle, (True, fn(*args, **kwargs)))
|
|
|
|
|
except Exception, e:
|
|
|
|
|
stram.Enqueue(reply_handle, (False, (e, traceback.extract_stack())))
|
|
|
|
|