diff --git a/econtext/master.py b/econtext/master.py new file mode 100644 index 00000000..a2f31fab --- /dev/null +++ b/econtext/master.py @@ -0,0 +1,240 @@ + +import commands +import getpass +import inspect +import logging +import os +import pkgutil +import re +import socket +import sys +import textwrap +import zlib + +import econtext.core + + +LOG = logging.getLogger('econtext') +IOLOG = logging.getLogger('econtext.io') +RLOG = logging.getLogger('econtext.ctx') + + +def GetChildModules(module, prefix): + it = pkgutil.iter_modules(module.__path__, prefix) + return [name for _, name, _ in it] + + +def CreateChild(*args): + """Create a child process whose stdin/stdout is connected to a socket, + returning `(pid, socket_obj)`.""" + parentfp, childfp = socket.socketpair() + pid = os.fork() + if not pid: + os.dup2(childfp.fileno(), 0) + os.dup2(childfp.fileno(), 1) + childfp.close() + parentfp.close() + os.execvp(args[0], args) + raise SystemExit + + childfp.close() + LOG.debug('CreateChild() child %d fd %d, parent %d, args %r', + pid, parentfp.fileno(), os.getpid(), args) + return pid, parentfp + + +class Listener(econtext.core.BasicStream): + def __init__(self, broker, address=None, backlog=30): + self._broker = broker + self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._sock.bind(address or ('0.0.0.0', 0)) + self._sock.listen(backlog) + self._listen_addr = self._sock.getsockname() + self.read_side = econtext.core.Side(self, self._sock.fileno()) + broker.UpdateStream(self) + + def Receive(self): + sock, addr = self._sock.accept() + context = Context(self._broker, name=addr) + stream = econtext.core.Stream(context) + stream.Accept(sock.fileno(), sock.fileno()) + + +class LogForwarder(object): + def __init__(self, context): + self._context = context + self._context.AddHandleCB(self.ForwardLog, + handle=econtext.core.FORWARD_LOG) + self._log = RLOG.getChild(self._context.name) + + def ForwardLog(self, data): + if data == econtext.core._DEAD: + return + + name, level, s = data + self._log.log(level, '%s: %s', name, s) + + +class ModuleResponder(object): + def __init__(self, context): + self._context = context + self._context.AddHandleCB(self.GetModule, + handle=econtext.core.GET_MODULE) + + def GetModule(self, data): + if data == econtext.core._DEAD: + return + + reply_to, fullname = data + LOG.debug('GetModule(%r, %r)', reply_to, fullname) + try: + module = __import__(fullname, fromlist=['']) + is_pkg = getattr(module, '__path__', None) is not None + path = inspect.getsourcefile(module) + try: + source = inspect.getsource(module) + except IOError: + if not is_pkg: + raise + source = '\n' + + if is_pkg: + prefix = module.__name__ + '.' + present = GetChildModules(module, prefix) + else: + present = None + + compressed = zlib.compress(source) + reply = (is_pkg, present, path, compressed) + self._context.Enqueue(reply_to, reply) + except Exception: + LOG.exception('While importing %r', fullname) + self._context.Enqueue(reply_to, None) + + +class LocalStream(econtext.core.Stream): + """ + Base for streams capable of starting new slaves. + """ + #: The path to the remote Python interpreter. + python_path = sys.executable + + def __init__(self, context): + super(LocalStream, self).__init__(context) + self._permitted_classes = set([('econtext.core', 'CallError')]) + + def _FindGlobal(self, module_name, class_name): + """Return the class implementing `module_name.class_name` or raise + `StreamError` if the module is not whitelisted.""" + if (module_name, class_name) not in self._permitted_classes: + raise econtext.core.StreamError( + '%r attempted to unpickle %r in module %r', + self._context, class_name, module_name) + return getattr(sys.modules[module_name], class_name) + + def AllowClass(self, module_name, class_name): + """Add `module_name` to the list of permitted modules.""" + self._permitted_modules.add((module_name, class_name)) + + # base64'd and passed to 'python -c'. It forks, dups 0->100, creates a + # pipe, then execs a new interpreter with a custom argv. CONTEXT_NAME is + # replaced with the context name. Optimized for size. + def _FirstStage(): + import os,sys,zlib + R,W=os.pipe() + if os.fork(): + os.dup2(0,100) + os.dup2(R,0) + os.close(R) + os.close(W) + os.execv(sys.executable,('econtext:'+CONTEXT_NAME,)) + else: + os.fdopen(W,'wb',0).write(zlib.decompress(sys.stdin.read(input()))) + print 'OK' + sys.exit(0) + + def GetBootCommand(self): + name = self._context.remote_name + if name is None: + name = '%s@%s:%d' + name %= (getpass.getuser(), socket.gethostname(), os.getpid()) + + source = inspect.getsource(self._FirstStage) + source = textwrap.dedent('\n'.join(source.strip().split('\n')[1:])) + source = source.replace(' ', '\t') + source = source.replace('CONTEXT_NAME', repr(name)) + encoded = source.encode('base64').replace('\n', '') + return [self.python_path, '-c', + 'exec "%s".decode("base64")' % (encoded,)] + + def __repr__(self): + return '%s(%s)' % (self.__class__.__name__, self._context) + + def Connect(self): + LOG.debug('%r.Connect()', self) + pid, sock = CreateChild(*self.GetBootCommand()) + self.read_side = econtext.core.Side(self, os.dup(sock.fileno())) + self.write_side = self.read_side + sock.close() + LOG.debug('%r.Connect(): child process stdin/stdout=%r', + self, self.read_side.fd) + + source = inspect.getsource(econtext.core) + source += '\nExternalContext().main%r\n' % (( + self._context.key, + LOG.level or logging.getLogger().level or logging.INFO, + ),) + + compressed = zlib.compress(source) + preamble = str(len(compressed)) + '\n' + compressed + econtext.core.write_all(self.write_side.fd, preamble) + assert os.read(self.read_side.fd, 3) == 'OK\n' + + +class SSHStream(LocalStream): + #: The path to the SSH binary. + ssh_path = 'ssh' + + def GetBootCommand(self): + bits = [self.ssh_path] + if self._context.username: + bits += ['-l', self._context.username] + bits.append(self._context.hostname) + base = super(SSHStream, self).GetBootCommand() + return bits + map(commands.mkarg, base) + + +class Broker(econtext.core.Broker): + def CreateListener(self, address=None, backlog=30): + """Listen on `address `for connections from newly spawned contexts.""" + self._listener = Listener(self, address, backlog) + + def GetLocal(self, name='default'): + """Get the named context running on the local machine, creating it if + it does not exist.""" + context = Context(self, name) + context.stream = LocalStream(context) + context.stream.Connect() + return self.Register(context) + + def GetRemote(self, hostname, username, name=None, python_path=None): + """Get the named remote context, creating it if it does not exist.""" + if name is None: + name = hostname + + context = Context(self, name, hostname, username) + context.stream = SSHStream(context) + if python_path: + context.stream.python_path = python_path + context.stream.Connect() + return self.Register(context) + + +class Context(econtext.core.Context): + def __init__(self, *args, **kwargs): + super(Context, self).__init__(*args, **kwargs) + self.responder = ModuleResponder(self) + self.log_forwarder = LogForwarder(self) + + def Disconnect(self): + self.stream = None