Split into core and master modules.

pull/35/head
David Wilson 8 years ago
parent dc231847a0
commit ae3316b985

@@ -1 +1 @@
-from econtext.core import *
+from econtext.core import * # NOQA

@@ -7,11 +7,8 @@ Python external execution contexts.
 import Queue
 import cPickle
 import cStringIO
-import commands
-import getpass
 import hmac
 import imp
-import inspect
 import logging
 import os
 import random
@@ -20,7 +17,6 @@ import sha
 import socket
 import struct
 import sys
-import textwrap
 import threading
 import traceback
 import types
@@ -29,7 +25,6 @@ import zlib
 
 LOG = logging.getLogger('econtext')
 IOLOG = logging.getLogger('econtext.io')
-RLOG = logging.getLogger('econtext.ctx')
 
 GET_MODULE = 100L
 CALL_FUNCTION = 101L
@@ -91,25 +86,6 @@ def write_all(fd, s):
     return written
 
 
-def CreateChild(*args):
-    """Create a child process whose stdin/stdout is connected to a socket,
-    returning `(pid, socket_obj)`."""
-    parentfp, childfp = socket.socketpair()
-    pid = os.fork()
-    if not pid:
-        os.dup2(childfp.fileno(), 0)
-        os.dup2(childfp.fileno(), 1)
-        childfp.close()
-        parentfp.close()
-        os.execvp(args[0], args)
-        raise SystemExit
-
-    childfp.close()
-    LOG.debug('CreateChild() child %d fd %d, parent %d, args %r',
-              pid, parentfp.fileno(), os.getpid(), args)
-    return pid, parentfp
-
-
 class Channel(object):
     def __init__(self, context, handle):
         self._context = context
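Note: the removed CreateChild() helper wires a child's stdin and stdout to one end of a socketpair before exec'ing it, so the parent can drive the child over a single socket. A minimal, hypothetical usage sketch (Python 2, assuming the CreateChild helper above is in scope; /bin/cat is an arbitrary command chosen for illustration):

    import os

    pid, sock = CreateChild('/bin/cat')
    sock.sendall('ping\n')            # arrives on the child's stdin
    print repr(sock.recv(5))          # the same bytes echoed back from its stdout
    sock.close()                      # EOF on the child's stdin; cat exits
    os.waitpid(pid, 0)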
@@ -166,20 +142,29 @@ class SlaveModuleImporter(object):
 
     :param context: Context to communicate via.
     """
+    _absent = None
+
     def __init__(self, context):
         self._context = context
-        self._absent = set(['econtext.econtext', 'econtext.logging'])
+        self._lock = threading.RLock()
+        self._ignore = []
 
     def find_module(self, fullname, path=None):
         LOG.debug('SlaveModuleImporter.find_module(%r)', fullname)
-        if fullname in self._absent:
+        if fullname in self._absent or fullname in self._ignore:
             return None
 
+        self._lock.acquire()
         try:
-            imp.find_module(fullname)
-        except ImportError:
-            LOG.debug('find_module(%r) returning self', fullname)
-            return self
+            self._ignore.append(fullname)
+            try:
+                __import__(fullname, fromlist=['*'])
+            except ImportError:
+                LOG.debug('find_module(%r) returning self', fullname)
+                return self
+        finally:
+            self._ignore.pop()
+            self._lock.release()
 
     def load_module(self, fullname):
         LOG.debug('SlaveModuleImporter.load_module(%r)', fullname)
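The hunk above reworks SlaveModuleImporter, an import hook using the PEP 302 finder/loader protocol: find_module() now probes with __import__() under an RLock, using the _ignore list so the probe's own re-entry into the finder for the same name is not intercepted, and only claims the module (by returning self) when the local import fails. A minimal, self-contained sketch of that protocol, separate from the project's code (ExampleFinder and virtual_example are illustrative names):

    import imp
    import sys

    class ExampleFinder(object):
        def find_module(self, fullname, path=None):
            # Returning None defers to the normal import machinery;
            # returning self claims responsibility for loading fullname.
            if fullname != 'virtual_example':
                return None
            return self

        def load_module(self, fullname):
            mod = sys.modules.setdefault(fullname, imp.new_module(fullname))
            mod.__loader__ = self
            mod.ANSWER = 42
            return mod

    sys.meta_path.append(ExampleFinder())
    import virtual_example
    print virtual_example.ANSWER      # prints 42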
@@ -201,46 +186,6 @@ class SlaveModuleImporter(object):
         return mod
 
 
-class MasterModuleResponder(object):
-    def __init__(self, context):
-        self._context = context
-        self._context.AddHandleCB(self.GetModule, handle=GET_MODULE)
-
-    def _GetAbsent(self, module, prefix):
-        return [k for k, v in sys.modules.iteritems()
-                if v is None and k.startswith(prefix)]
-
-    def GetModule(self, data):
-        if data == _DEAD:
-            return
-
-        reply_to, fullname = data
-        LOG.debug('MasterModuleResponder.GetModule(%r, %r)', reply_to, fullname)
-        try:
-            module = __import__(fullname, fromlist=[''])
-            is_pkg = getattr(module, '__path__', None) is not None
-            path = inspect.getsourcefile(module)
-            try:
-                source = inspect.getsource(module)
-            except IOError:
-                if not is_pkg:
-                    raise
-                source = '\n'
-
-            if is_pkg:
-                prefix = module.__name__ + '.'
-                absent = self._GetAbsent(module, prefix)
-            else:
-                absent = []
-
-            compressed = zlib.compress(source)
-            reply = (is_pkg, absent, path, compressed)
-            self._context.Enqueue(reply_to, reply)
-        except Exception:
-            LOG.exception('While importing %r', fullname)
-            self._context.Enqueue(reply_to, None)
-
-
 class LogHandler(logging.Handler):
     def __init__(self, context):
         logging.Handler.__init__(self)
@@ -260,20 +205,6 @@ class LogHandler(logging.Handler):
         self.local.in_commit = False
 
 
-class LogForwarder(object):
-    def __init__(self, context):
-        self._context = context
-        self._context.AddHandleCB(self.ForwardLog, handle=FORWARD_LOG)
-        self._log = RLOG.getChild(self._context.name)
-
-    def ForwardLog(self, data):
-        if data == _DEAD:
-            return
-
-        name, level, s = data
-        self._log.log(level, '%s: %s', name, s)
-
-
 class Side(object):
     def __init__(self, stream, fd):
         self.stream = stream
@@ -442,98 +373,6 @@ class Stream(BasicStream):
         return '%s(<context=%r>)' % (self.__class__.__name__, self._context)
 
 
-class LocalStream(Stream):
-    """
-    Base for streams capable of starting new slaves.
-    """
-    #: The path to the remote Python interpreter.
-    python_path = sys.executable
-
-    def __init__(self, context):
-        super(LocalStream, self).__init__(context)
-        self._permitted_classes = set([('econtext.core', 'CallError')])
-
-    def _FindGlobal(self, module_name, class_name):
-        """Return the class implementing `module_name.class_name` or raise
-        `StreamError` if the module is not whitelisted."""
-        if (module_name, class_name) not in self._permitted_classes:
-            raise StreamError('%r attempted to unpickle %r in module %r',
-                              self._context, class_name, module_name)
-        return getattr(sys.modules[module_name], class_name)
-
-    def AllowClass(self, module_name, class_name):
-        """Add `module_name` to the list of permitted modules."""
-        self._permitted_modules.add((module_name, class_name))
-
-    # base64'd and passed to 'python -c'. It forks, dups 0->100, creates a
-    # pipe, then execs a new interpreter with a custom argv. CONTEXT_NAME is
-    # replaced with the context name. Optimized for size.
-    def _FirstStage():
-        import os,sys,zlib
-        R,W=os.pipe()
-        if os.fork():
-            os.dup2(0,100)
-            os.dup2(R,0)
-            os.close(R)
-            os.close(W)
-            os.execv(sys.executable,('econtext:'+CONTEXT_NAME,))
-        else:
-            os.fdopen(W,'wb',0).write(zlib.decompress(sys.stdin.read(input())))
-            print 'OK'
-            sys.exit(0)
-
-    def GetBootCommand(self):
-        name = self._context.remote_name
-        if name is None:
-            name = '%s@%s:%d'
-            name %= (getpass.getuser(), socket.gethostname(), os.getpid())
-
-        source = inspect.getsource(self._FirstStage)
-        source = textwrap.dedent('\n'.join(source.strip().split('\n')[1:]))
-        source = source.replace('    ', '\t')
-        source = source.replace('CONTEXT_NAME', repr(name))
-        encoded = source.encode('base64').replace('\n', '')
-        return [self.python_path, '-c',
-                'exec "%s".decode("base64")' % (encoded,)]
-
-    def __repr__(self):
-        return '%s(%s)' % (self.__class__.__name__, self._context)
-
-    def Connect(self):
-        LOG.debug('%r.Connect()', self)
-        pid, sock = CreateChild(*self.GetBootCommand())
-        self.read_side = Side(self, os.dup(sock.fileno()))
-        self.write_side = self.read_side
-        sock.close()
-        LOG.debug('%r.Connect(): child process stdin/stdout=%r',
-                  self, self.read_side.fd)
-
-        source = inspect.getsource(sys.modules[__name__])
-        source += '\nExternalContext().main(%r, %r, %r)\n' % (
-            self._context.name,
-            self._context.key,
-            LOG.level or logging.getLogger().level or logging.INFO,
-        )
-
-        compressed = zlib.compress(source)
-        preamble = str(len(compressed)) + '\n' + compressed
-        write_all(self.write_side.fd, preamble)
-        assert os.read(self.read_side.fd, 3) == 'OK\n'
-
-
-class SSHStream(LocalStream):
-    #: The path to the SSH binary.
-    ssh_path = 'ssh'
-
-    def GetBootCommand(self):
-        bits = [self.ssh_path]
-        if self._context.username:
-            bits += ['-l', self._context.username]
-        bits.append(self._context.hostname)
-        base = super(SSHStream, self).GetBootCommand()
-        return bits + map(commands.mkarg, base)
-
-
 class Context(object):
     """
     Represent a remote context regardless of connection method.
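The removed LocalStream/SSHStream code bootstraps a slave by pushing a tiny first stage through `python -c`: the source of _FirstStage() is dedented, tab-compressed, base64-encoded, and wrapped in an exec "...".decode("base64") command line so it survives shell quoting (SSHStream simply prefixes the same command with ssh). A self-contained, hypothetical sketch of that encoding trick, in the same Python 2 idiom (the snippet is a stand-in for the real first stage):

    import subprocess
    import sys

    snippet = "import sys\nsys.stdout.write('first stage alive\\n')\n"
    encoded = snippet.encode('base64').replace('\n', '')
    # Run a fresh interpreter that decodes and executes the snippet.
    subprocess.check_call([sys.executable, '-c',
                           'exec "%s".decode("base64")' % (encoded,)])
    # prints: first stage alive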
@@ -555,9 +394,6 @@ class Context(object):
         self._handle_map = {}
         self._lock = threading.Lock()
 
-        self.responder = MasterModuleResponder(self)
-        self.log_forwarder = LogForwarder(self)
-
     def Disconnect(self):
         self.stream = None
         if self.finalize_on_disconnect:
@@ -655,22 +491,6 @@ class Waker(BasicStream):
         os.read(self.read_side.fd, 1)
 
 
-class Listener(BasicStream):
-    def __init__(self, broker, address=None, backlog=30):
-        self._broker = broker
-        self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        self._sock.bind(address or ('0.0.0.0', 0))
-        self._sock.listen(backlog)
-        self._listen_addr = self._sock.getsockname()
-        self.read_side = Side(self, self._sock.fileno())
-        broker.UpdateStream(self)
-
-    def Receive(self):
-        sock, addr = self._sock.accept()
-        context = Context(self._broker, name=addr)
-        Stream(context).Accept(sock.fileno(), sock.fileno())
-
-
 class IoLogger(BasicStream):
     _buf = ''
@@ -721,10 +541,6 @@ class Broker(object):
                                         name='econtext-broker')
         self._thread.start()
 
-    def CreateListener(self, address=None, backlog=30):
-        """Listen on `address` for connections from newly spawned contexts."""
-        self._listener = Listener(self, address, backlog)
-
     def _UpdateStream(self, stream):
         IOLOG.debug('_UpdateStream(%r)', stream)
         self._lock.acquire()
@@ -755,26 +571,6 @@ class Broker(object):
         self._contexts[context.name] = context
         return context
 
-    def GetLocal(self, name='default'):
-        """Get the named context running on the local machine, creating it if
-        it does not exist."""
-        context = Context(self, name)
-        context.stream = LocalStream(context)
-        context.stream.Connect()
-        return self.Register(context)
-
-    def GetRemote(self, hostname, username, name=None, python_path=None):
-        """Get the named remote context, creating it if it does not exist."""
-        if name is None:
-            name = hostname
-
-        context = Context(self, name, hostname, username)
-        context.stream = SSHStream(context)
-        if python_path:
-            context.stream.python_path = python_path
-        context.stream.Connect()
-        return self.Register(context)
-
     def _CallAndUpdate(self, stream, func):
         try:
             func()
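GetLocal() and GetRemote() are the factory methods being moved out of core: each builds a Context, attaches a LocalStream or SSHStream, connects it, and registers the result with the broker. A hypothetical usage sketch using the pre-split spelling (after this commit the equivalents presumably live on the master-side Broker; host and user names are illustrative):

    import econtext

    broker = econtext.Broker()
    try:
        local = broker.GetLocal()                     # forked local slave
        remote = broker.GetRemote('web1', 'deploy',   # slave bootstrapped over ssh
                                  python_path='/usr/bin/python')
    finally:
        broker.Finalize()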
@@ -827,6 +623,8 @@ class ExternalContext(object):
 
     def _FixupMainModule(self):
         main = sys.modules['__main__']
         main.__path__ = []
+        main.core = main
+
         sys.modules['econtext'] = main
         sys.modules['econtext.core'] = main
@@ -890,7 +688,7 @@ class ExternalContext(object):
             except Exception, e:
                 self.context.Enqueue(reply_to, CallError(e))
 
-    def main(self, context_name, key, log_level):
+    def main(self, context_name, key, log_level, absent):
         self._ReapFirstStage()
         self._FixupMainModule()
         self._SetupMaster(key)
@@ -899,8 +697,9 @@ class ExternalContext(object):
 
         self._SetupStdio()
         # signal.signal(signal.SIGINT, lambda *_: self.broker.Finalize())
 
+        SlaveModuleImporter._absent = set(absent)
         self.broker.Register(self.context)
         self._DispatchCalls()
         self.broker.Wait()
         LOG.debug('ExternalContext.main() exitting')
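The new `absent` parameter threaded through main() above pairs with the _GetAbsent() helper removed from MasterModuleResponder earlier in this diff: Python 2 caches failed implicit relative imports as None entries in sys.modules, and shipping those names to the slave lets its importer return None for them immediately instead of round-tripping to the master. Roughly how the master derives that list (mirrors the removed helper; the package prefix is illustrative):

    import sys

    prefix = 'logging.'
    absent = [k for k, v in sys.modules.iteritems()
              if v is None and k.startswith(prefix)]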

@@ -2,6 +2,7 @@
 
 import logging
 import econtext
+import econtext.master
 
 
 def log_to_file(path, level=logging.DEBUG):
@@ -12,7 +13,7 @@ def log_to_file(path, level=logging.DEBUG):
 
 
 def run_with_broker(func, *args, **kwargs):
-    broker = econtext.Broker()
+    broker = econtext.master.Broker()
     try:
         return func(broker, *args, **kwargs)
     finally:
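A hypothetical caller tying the two helpers in this file together, assuming log_to_file and run_with_broker are imported from this module and that the master Broker keeps the GetLocal() factory shown earlier (the log path and task function are illustrative):

    import logging

    log_to_file('/tmp/econtext.log', level=logging.INFO)

    def get_name(broker):
        # Any work driven through the broker; here, just spawn a local
        # slave and report its context name.
        return broker.GetLocal().name

    print run_with_broker(get_name)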
