diff --git a/econtext/__init__.py b/econtext/__init__.py index c5461759..014c0367 100644 --- a/econtext/__init__.py +++ b/econtext/__init__.py @@ -1 +1 @@ -from econtext.core import * +from econtext.core import * # NOQA diff --git a/econtext/core.py b/econtext/core.py index 22df90fa..19c909b1 100644 --- a/econtext/core.py +++ b/econtext/core.py @@ -7,11 +7,8 @@ Python external execution contexts. import Queue import cPickle import cStringIO -import commands -import getpass import hmac import imp -import inspect import logging import os import random @@ -20,7 +17,6 @@ import sha import socket import struct import sys -import textwrap import threading import traceback import types @@ -29,7 +25,6 @@ import zlib LOG = logging.getLogger('econtext') IOLOG = logging.getLogger('econtext.io') -RLOG = logging.getLogger('econtext.ctx') GET_MODULE = 100L CALL_FUNCTION = 101L @@ -91,25 +86,6 @@ def write_all(fd, s): return written -def CreateChild(*args): - """Create a child process whose stdin/stdout is connected to a socket, - returning `(pid, socket_obj)`.""" - parentfp, childfp = socket.socketpair() - pid = os.fork() - if not pid: - os.dup2(childfp.fileno(), 0) - os.dup2(childfp.fileno(), 1) - childfp.close() - parentfp.close() - os.execvp(args[0], args) - raise SystemExit - - childfp.close() - LOG.debug('CreateChild() child %d fd %d, parent %d, args %r', - pid, parentfp.fileno(), os.getpid(), args) - return pid, parentfp - - class Channel(object): def __init__(self, context, handle): self._context = context @@ -166,20 +142,29 @@ class SlaveModuleImporter(object): :param context: Context to communicate via. """ + _absent = None + def __init__(self, context): self._context = context - self._absent = set(['econtext.econtext', 'econtext.logging']) + self._lock = threading.RLock() + self._ignore = [] def find_module(self, fullname, path=None): LOG.debug('SlaveModuleImporter.find_module(%r)', fullname) - if fullname in self._absent: + if fullname in self._absent or fullname in self._ignore: return None + self._lock.acquire() try: - imp.find_module(fullname) - except ImportError: - LOG.debug('find_module(%r) returning self', fullname) - return self + self._ignore.append(fullname) + try: + __import__(fullname, fromlist=['*']) + except ImportError: + LOG.debug('find_module(%r) returning self', fullname) + return self + finally: + self._ignore.pop() + self._lock.release() def load_module(self, fullname): LOG.debug('SlaveModuleImporter.load_module(%r)', fullname) @@ -201,46 +186,6 @@ class SlaveModuleImporter(object): return mod -class MasterModuleResponder(object): - def __init__(self, context): - self._context = context - self._context.AddHandleCB(self.GetModule, handle=GET_MODULE) - - def _GetAbsent(self, module, prefix): - return [k for k, v in sys.modules.iteritems() - if v is None and k.startswith(prefix)] - - def GetModule(self, data): - if data == _DEAD: - return - - reply_to, fullname = data - LOG.debug('MasterModuleResponder.GetModule(%r, %r)', reply_to, fullname) - try: - module = __import__(fullname, fromlist=['']) - is_pkg = getattr(module, '__path__', None) is not None - path = inspect.getsourcefile(module) - try: - source = inspect.getsource(module) - except IOError: - if not is_pkg: - raise - source = '\n' - - if is_pkg: - prefix = module.__name__ + '.' - absent = self._GetAbsent(module, prefix) - else: - absent = [] - - compressed = zlib.compress(source) - reply = (is_pkg, absent, path, compressed) - self._context.Enqueue(reply_to, reply) - except Exception: - LOG.exception('While importing %r', fullname) - self._context.Enqueue(reply_to, None) - - class LogHandler(logging.Handler): def __init__(self, context): logging.Handler.__init__(self) @@ -260,20 +205,6 @@ class LogHandler(logging.Handler): self.local.in_commit = False -class LogForwarder(object): - def __init__(self, context): - self._context = context - self._context.AddHandleCB(self.ForwardLog, handle=FORWARD_LOG) - self._log = RLOG.getChild(self._context.name) - - def ForwardLog(self, data): - if data == _DEAD: - return - - name, level, s = data - self._log.log(level, '%s: %s', name, s) - - class Side(object): def __init__(self, stream, fd): self.stream = stream @@ -442,98 +373,6 @@ class Stream(BasicStream): return '%s()' % (self.__class__.__name__, self._context) -class LocalStream(Stream): - """ - Base for streams capable of starting new slaves. - """ - #: The path to the remote Python interpreter. - python_path = sys.executable - - def __init__(self, context): - super(LocalStream, self).__init__(context) - self._permitted_classes = set([('econtext.core', 'CallError')]) - - def _FindGlobal(self, module_name, class_name): - """Return the class implementing `module_name.class_name` or raise - `StreamError` if the module is not whitelisted.""" - if (module_name, class_name) not in self._permitted_classes: - raise StreamError('%r attempted to unpickle %r in module %r', - self._context, class_name, module_name) - return getattr(sys.modules[module_name], class_name) - - def AllowClass(self, module_name, class_name): - """Add `module_name` to the list of permitted modules.""" - self._permitted_modules.add((module_name, class_name)) - - # base64'd and passed to 'python -c'. It forks, dups 0->100, creates a - # pipe, then execs a new interpreter with a custom argv. CONTEXT_NAME is - # replaced with the context name. Optimized for size. - def _FirstStage(): - import os,sys,zlib - R,W=os.pipe() - if os.fork(): - os.dup2(0,100) - os.dup2(R,0) - os.close(R) - os.close(W) - os.execv(sys.executable,('econtext:'+CONTEXT_NAME,)) - else: - os.fdopen(W,'wb',0).write(zlib.decompress(sys.stdin.read(input()))) - print 'OK' - sys.exit(0) - - def GetBootCommand(self): - name = self._context.remote_name - if name is None: - name = '%s@%s:%d' - name %= (getpass.getuser(), socket.gethostname(), os.getpid()) - - source = inspect.getsource(self._FirstStage) - source = textwrap.dedent('\n'.join(source.strip().split('\n')[1:])) - source = source.replace(' ', '\t') - source = source.replace('CONTEXT_NAME', repr(name)) - encoded = source.encode('base64').replace('\n', '') - return [self.python_path, '-c', - 'exec "%s".decode("base64")' % (encoded,)] - - def __repr__(self): - return '%s(%s)' % (self.__class__.__name__, self._context) - - def Connect(self): - LOG.debug('%r.Connect()', self) - pid, sock = CreateChild(*self.GetBootCommand()) - self.read_side = Side(self, os.dup(sock.fileno())) - self.write_side = self.read_side - sock.close() - LOG.debug('%r.Connect(): child process stdin/stdout=%r', - self, self.read_side.fd) - - source = inspect.getsource(sys.modules[__name__]) - source += '\nExternalContext().main(%r, %r, %r)\n' % ( - self._context.name, - self._context.key, - LOG.level or logging.getLogger().level or logging.INFO, - ) - - compressed = zlib.compress(source) - preamble = str(len(compressed)) + '\n' + compressed - write_all(self.write_side.fd, preamble) - assert os.read(self.read_side.fd, 3) == 'OK\n' - - -class SSHStream(LocalStream): - #: The path to the SSH binary. - ssh_path = 'ssh' - - def GetBootCommand(self): - bits = [self.ssh_path] - if self._context.username: - bits += ['-l', self._context.username] - bits.append(self._context.hostname) - base = super(SSHStream, self).GetBootCommand() - return bits + map(commands.mkarg, base) - - class Context(object): """ Represent a remote context regardless of connection method. @@ -555,9 +394,6 @@ class Context(object): self._handle_map = {} self._lock = threading.Lock() - self.responder = MasterModuleResponder(self) - self.log_forwarder = LogForwarder(self) - def Disconnect(self): self.stream = None if self.finalize_on_disconnect: @@ -655,22 +491,6 @@ class Waker(BasicStream): os.read(self.read_side.fd, 1) -class Listener(BasicStream): - def __init__(self, broker, address=None, backlog=30): - self._broker = broker - self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._sock.bind(address or ('0.0.0.0', 0)) - self._sock.listen(backlog) - self._listen_addr = self._sock.getsockname() - self.read_side = Side(self, self._sock.fileno()) - broker.UpdateStream(self) - - def Receive(self): - sock, addr = self._sock.accept() - context = Context(self._broker, name=addr) - Stream(context).Accept(sock.fileno(), sock.fileno()) - - class IoLogger(BasicStream): _buf = '' @@ -721,10 +541,6 @@ class Broker(object): name='econtext-broker') self._thread.start() - def CreateListener(self, address=None, backlog=30): - """Listen on `address `for connections from newly spawned contexts.""" - self._listener = Listener(self, address, backlog) - def _UpdateStream(self, stream): IOLOG.debug('_UpdateStream(%r)', stream) self._lock.acquire() @@ -755,26 +571,6 @@ class Broker(object): self._contexts[context.name] = context return context - def GetLocal(self, name='default'): - """Get the named context running on the local machine, creating it if - it does not exist.""" - context = Context(self, name) - context.stream = LocalStream(context) - context.stream.Connect() - return self.Register(context) - - def GetRemote(self, hostname, username, name=None, python_path=None): - """Get the named remote context, creating it if it does not exist.""" - if name is None: - name = hostname - - context = Context(self, name, hostname, username) - context.stream = SSHStream(context) - if python_path: - context.stream.python_path = python_path - context.stream.Connect() - return self.Register(context) - def _CallAndUpdate(self, stream, func): try: func() @@ -827,6 +623,8 @@ class ExternalContext(object): def _FixupMainModule(self): main = sys.modules['__main__'] main.__path__ = [] + main.core = main + sys.modules['econtext'] = main sys.modules['econtext.core'] = main @@ -890,7 +688,7 @@ class ExternalContext(object): except Exception, e: self.context.Enqueue(reply_to, CallError(e)) - def main(self, context_name, key, log_level): + def main(self, context_name, key, log_level, absent): self._ReapFirstStage() self._FixupMainModule() self._SetupMaster(key) @@ -899,8 +697,9 @@ class ExternalContext(object): self._SetupStdio() # signal.signal(signal.SIGINT, lambda *_: self.broker.Finalize()) - + SlaveModuleImporter._absent = set(absent) self.broker.Register(self.context) + self._DispatchCalls() self.broker.Wait() LOG.debug('ExternalContext.main() exitting') diff --git a/econtext/utils.py b/econtext/utils.py index 73ecaaa2..91671794 100644 --- a/econtext/utils.py +++ b/econtext/utils.py @@ -2,6 +2,7 @@ import logging import econtext +import econtext.master def log_to_file(path, level=logging.DEBUG): @@ -12,7 +13,7 @@ def log_to_file(path, level=logging.DEBUG): def run_with_broker(func, *args, **kwargs): - broker = econtext.Broker() + broker = econtext.master.Broker() try: return func(broker, *args, **kwargs) finally: