diff --git a/docs/howitworks.rst b/docs/howitworks.rst
index fe812e41..4875c17d 100644
--- a/docs/howitworks.rst
+++ b/docs/howitworks.rst
@@ -97,6 +97,17 @@
 comments, while preserving line numbers. This reduces the compressed payload
 by around 20%.
 
+Preserving The `econtext.core` Source
+#####################################
+
+One final trick is implemented in the first stage: after bootstrapping the new
+slave, it writes a duplicate copy of the `econtext.core` source it just used
+to bootstrap the slave back into another pipe connected to that slave. The
+slave's module importer cache is initialized with this copy of the source, so
+that subsequent bootstraps of slaves-of-slaves do not require the source to be
+fetched from the master a second time.
+
+
 Signalling Success
 ##################
 
@@ -231,19 +242,23 @@ Stream Protocol
 Once connected, a basic framing protocol is used to communicate between
 master and slave:
 
-+----------------+------+------------------------------------------------------+
-| Field          | Size | Description                                          |
-+================+======+======================================================+
-| ``hmac``       | 20   | SHA-1 over (``context_id || handle || len || data``) |
-+----------------+------+------------------------------------------------------+
-| ``handle``     | 4    | Integer target handle in recipient.                  |
-+----------------+------+------------------------------------------------------+
-| ``reply_to``   | 4    | Integer response target ID.                          |
-+----------------+------+------------------------------------------------------+
-| ``length``     | 4    | Message length                                       |
-+----------------+------+------------------------------------------------------+
-| ``data``       | n/a  | Pickled message data.                                |
-+----------------+------+------------------------------------------------------+
++--------------------+------+------------------------------------------------+
+| Field              | Size | Description                                    |
++====================+======+================================================+
+| ``hmac``           | 20   | HMAC-SHA-1 over all remaining fields.          |
++--------------------+------+------------------------------------------------+
+| ``dst_id``         | 4    | Integer destination context ID.                |
++--------------------+------+------------------------------------------------+
+| ``src_id``         | 4    | Integer source context ID.                     |
++--------------------+------+------------------------------------------------+
+| ``handle``         | 4    | Integer target handle in recipient.            |
++--------------------+------+------------------------------------------------+
+| ``reply_to``       | 4    | Integer response target ID.                    |
++--------------------+------+------------------------------------------------+
+| ``length``         | 4    | Message length.                                |
++--------------------+------+------------------------------------------------+
+| ``data``           | n/a  | Pickled message data.                          |
++--------------------+------+------------------------------------------------+
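+
+For illustration, this sketch mirrors how ``Stream.send()`` below frames an
+outgoing message. ``key`` is the per-context shared secret exchanged during
+bootstrap, and because the HMAC is a running digest, it also covers every
+byte previously sent on the stream:
+
+.. code-block:: python
+
+    import hmac
+    import sha
+    import struct
+
+    def make_frame(whmac, dst_id, src_id, handle, reply_to, data):
+        # Five big-endian 32-bit header fields, then the pickled payload.
+        pkt = struct.pack('>LLLLL', dst_id, src_id, handle,
+                          reply_to or 0, len(data)) + data
+        whmac.update(pkt)
+        return whmac.digest() + pkt
+
+    whmac = hmac.new(key, digestmod=sha)  # One per stream, per direction.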
 
 Masters listen on the following handles:
 
@@ -268,6 +283,29 @@ Slaves listen on the following handles:
     imports ``mod_name``, then attempts to execute
     `class_name.func_name(\*args, \**kwargs)`.
 
+.. data:: econtext.core.ADD_ROUTE
+
+    Receives `(target_id, via_id)` integer pairs, describing how messages
+    arriving at this context on any Stream should be forwarded on the stream
+    associated with the Context `via_id` such that they are eventually
+    delivered to the target Context.
+
+    This message is necessary to inform intermediary contexts of the
+    existence of a downstream Context, since they do not otherwise parse the
+    traffic they are forwarding to their downstream contexts, and so cannot
+    learn when that traffic causes new contexts to be established.
+
+    Given a chain `master -> ssh1 -> sudo1`, no `ADD_ROUTE` message is
+    necessary, since :py:class:`econtext.core.Router` in the `ssh1` context
+    can arrange to update its routes while setting up the new slave during
+    `proxy_connect()`.
+
+    However, given a chain like `master -> ssh1 -> sudo1 -> ssh2 -> sudo2`,
+    `ssh1` requires an `ADD_ROUTE` for `ssh2`, and both `ssh1` and `sudo1`
+    require an `ADD_ROUTE` for `sudo2`, as neither directly dealt with its
+    establishment.
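+
+    In terms of the routing table maintained by
+    :py:class:`econtext.core.Router`, the handler reduces to a single
+    dictionary update, as in this sketch of the implementation below:
+
+    .. code-block:: python
+
+        def _on_add_route(self, msg):
+            # msg.data is '<target_id>\x00<via_id>'.
+            target_id, via_id = map(int, msg.data.split('\x00'))
+            # Messages for target_id now leave on the stream that
+            # already carries traffic for via_id.
+            self._stream_by_id[target_id] = self._stream_by_id[via_id]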
+
+
 Additional handles are created to receive the result of every function call
 triggered by :py:meth:`call_with_deadline()
 <econtext.master.Context.call_with_deadline>`.
 
diff --git a/econtext/__init__.py b/econtext/__init__.py
index 1eff1525..e25aca91 100644
--- a/econtext/__init__.py
+++ b/econtext/__init__.py
@@ -21,3 +21,8 @@ be expected. On the slave, it is built dynamically during startup.
 #:     econtext.utils.run_with_broker(main)
 #:
 slave = False
+
+
+#: This is ``0`` in a master, otherwise it is a master-generated ID unique to
+#: the slave context.
+context_id = 0
diff --git a/econtext/core.py b/econtext/core.py
index ec9682e2..4407e470 100644
--- a/econtext/core.py
+++ b/econtext/core.py
@@ -25,18 +25,30 @@ import time
 import traceback
 import zlib
 
+#import linetracer
+#linetracer.start()
+
 LOG = logging.getLogger('econtext')
 IOLOG = logging.getLogger('econtext.io')
-IOLOG.setLevel(logging.INFO)
+#IOLOG.setLevel(logging.INFO)
 
 GET_MODULE = 100
 CALL_FUNCTION = 101
 FORWARD_LOG = 102
+ADD_ROUTE = 103
+
+CHUNK_SIZE = 16384
+
 
-# When loaded as __main__, ensure classes and functions gain a __module__
-# attribute consistent with the host process, so that pickling succeeds.
-__name__ = 'econtext.core'
+if __name__ == 'econtext.core':
+    # When loaded using the import mechanism, ExternalContext.main() will not
+    # have a chance to set the synthetic econtext global, so import it here.
+    import econtext
+else:
+    # When loaded as __main__, ensure classes and functions gain a __module__
+    # attribute consistent with the host process, so that pickling succeeds.
+    __name__ = 'econtext.core'
 
 
 class Error(Exception):
@@ -89,57 +101,94 @@ def set_cloexec(fd):
     fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC)
 
 
+class Message(object):
+    dst_id = None
+    src_id = None
+    handle = None
+    reply_to = None
+    data = None
+
+    def __init__(self, **kwargs):
+        self.src_id = econtext.context_id
+        vars(self).update(kwargs)
+
+    _find_global = None
+
+    @classmethod
+    def pickled(cls, obj, **kwargs):
+        self = cls(**kwargs)
+        try:
+            self.data = cPickle.dumps(obj, protocol=2)
+        except cPickle.PicklingError, e:
+            self.data = cPickle.dumps(CallError(str(e)), protocol=2)
+        return self
+
+    def unpickle(self):
+        """Deserialize `data` into an object."""
+        IOLOG.debug('%r.unpickle()', self)
+        fp = cStringIO.StringIO(self.data)
+        unpickler = cPickle.Unpickler(fp)
+        if self._find_global:
+            unpickler.find_global = self._find_global
+        try:
+            return unpickler.load()
+        except (TypeError, ValueError), ex:
+            raise StreamError('invalid message: %s', ex)
+
+    def __repr__(self):
+        return 'Message(%r, %r, %r, %r, %r..)' % (
+            self.dst_id, self.src_id, self.handle, self.reply_to,
+            (self.data or '')[:50]
+        )
+
+
 class Channel(object):
     def __init__(self, context, handle):
         self._context = context
         self._handle = handle
         self._queue = Queue.Queue()
-        self._context.add_handle_cb(self._receive, handle)
+        self._context.add_handler(self._receive, handle)
 
-    def _receive(self, reply_to, data):
+    def _receive(self, msg):
         """Callback from the Stream; appends data to the internal queue."""
-        IOLOG.debug('%r._receive(%r)', self, data)
-        self._queue.put((reply_to, data))
+        IOLOG.debug('%r._receive(%r)', self, msg)
+        self._queue.put(msg)
 
     def close(self):
         """Indicate this channel is closed to the remote side."""
         IOLOG.debug('%r.close()', self)
-        self._context.enqueue(self._handle, _DEAD)
+        self._context.send(Message.pickled(_DEAD, handle=self._handle))
 
-    def send(self, data):
+    def put(self, data):
         """Send `data` to the remote."""
-        IOLOG.debug('%r.send(%r)', self, data)
-        self._context.enqueue(self._handle, data)
+        IOLOG.debug('%r.put(%r)', self, data)
+        self._context.send(Message.pickled(data, handle=self._handle))
 
-    def receive(self, timeout=None):
+    def get(self, timeout=None):
         """Receive an object, or ``None`` if `timeout` is reached."""
-        IOLOG.debug('%r.on_receive(timeout=%r)', self, timeout)
+        IOLOG.debug('%r.get(timeout=%r)', self, timeout)
         try:
-            reply_to, data = self._queue.get(True, timeout)
+            msg = self._queue.get(True, timeout)
         except Queue.Empty:
             return
-        IOLOG.debug('%r.on_receive() got %r', self, data)
-
-        # Must occur off the broker thread.
-        if not isinstance(data, Dead):
-            try:
-                data = self._context.broker.unpickle(data)
-            except (TypeError, ValueError), ex:
-                raise StreamError('invalid message: %s', ex)
+        IOLOG.debug('%r.get() got %r', self, msg)
 
-        if not isinstance(data, (Dead, CallError)):
-            return reply_to, data
-        elif data == _DEAD:
+        if msg == _DEAD:
             raise ChannelError('Channel is closed.')
-        else:
+
+        # Must occur off the broker thread.
+        data = msg.unpickle()
+        if data == _DEAD:
+            raise ChannelError('Channel is closed.')
+        if isinstance(data, CallError):
             raise data
 
+        return msg, data
+
     def __iter__(self):
         """Yield objects from this channel until it is closed."""
         while True:
             try:
-                yield self.receive()
+                yield self.get()
             except ChannelError:
                 return
@@ -154,20 +203,25 @@ class Importer(object):
 
     :param context: Context to communicate via.
""" - def __init__(self, context): + def __init__(self, context, core_src): self._context = context self._present = {'econtext': [ 'econtext.ansible', 'econtext.compat', 'econtext.compat.pkgutil', 'econtext.master', - 'econtext.proxy', 'econtext.ssh', 'econtext.sudo', 'econtext.utils', ]} self.tls = threading.local() - self._cache = {} + self._cache = { + 'econtext.core': ( + None, + 'econtext/core.py', + zlib.compress(core_src), + ) + } def __repr__(self): return 'Importer()' @@ -205,8 +259,10 @@ class Importer(object): try: ret = self._cache[fullname] except KeyError: - self._cache[fullname] = ret = cPickle.loads( - self._context.enqueue_await_reply_raw(GET_MODULE, None, fullname) + self._cache[fullname] = ret = ( + self._context.send_await( + Message(data=fullname, handle=GET_MODULE) + ).unpickle() ) if ret is None: @@ -250,7 +306,7 @@ class LogHandler(logging.Handler): try: msg = self.format(rec) encoded = '%s\x00%s\x00%s' % (rec.name, rec.levelno, msg) - self.context.enqueue(FORWARD_LOG, encoded) + self.context.send(Message(data=encoded, handle=FORWARD_LOG)) finally: self.local.in_emit = False @@ -352,11 +408,19 @@ class Stream(BasicStream): """ _input_buf = '' _output_buf = '' + message_class = Message - def __init__(self, context): - self._context = context - self._rhmac = hmac.new(context.key, digestmod=sha) + def __init__(self, router, remote_id, key, **kwargs): + self._router = router + self.remote_id = remote_id + self.key = key + self._rhmac = hmac.new(key, digestmod=sha) self._whmac = self._rhmac.copy() + self.name = 'default' + self.construct(**kwargs) + + def construct(self): + pass def on_receive(self, broker): """Handle the next complete message on the stream. Raise @@ -364,7 +428,7 @@ class Stream(BasicStream): IOLOG.debug('%r.on_receive()', self) try: - buf = os.read(self.receive_side.fd, 4096) + buf = os.read(self.receive_side.fd, CHUNK_SIZE) IOLOG.debug('%r.on_receive() -> len %d', self, len(buf)) except OSError, e: IOLOG.debug('%r.on_receive() -> OSError: %s', self, e) @@ -373,39 +437,36 @@ class Stream(BasicStream): # sockets or pipes. Ideally this will be replaced later by a # 'goodbye' message to avoid reading from a disconnected endpoint, # allowing for more robust error reporting. 
-            if e.errno != errno.EIO:
+            if e.errno not in (errno.EIO, errno.ECONNRESET):
                 raise
             LOG.error('%r.on_receive(): %s', self, e)
             buf = ''
 
         self._input_buf += buf
-        while self._receive_one():
+        while self._receive_one(broker):
             pass
         if not buf:
             return self.on_disconnect(broker)
 
-    HEADER_FMT = (
-        '>'
-        '20s'  # msg_mac
-        'L'    # handle
-        'L'    # reply_to
-        'L'    # msg_len
-    )
+    HEADER_FMT = '>20sLLLLL'
     HEADER_LEN = struct.calcsize(HEADER_FMT)
     MAC_LEN = sha.digest_size
 
-    def _receive_one(self):
+    def _receive_one(self, broker):
         if len(self._input_buf) < self.HEADER_LEN:
             return False
 
-        msg_mac, handle, reply_to, msg_len = struct.unpack(
+        msg = self.message_class()
+        (msg_mac, msg.dst_id, msg.src_id,
+         msg.handle, msg.reply_to, msg_len) = struct.unpack(
             self.HEADER_FMT,
             self._input_buf[:self.HEADER_LEN]
         )
 
         if (len(self._input_buf) - self.HEADER_LEN) < msg_len:
-            IOLOG.debug('Input too short')
+            IOLOG.debug('%r: Input too short (want %d, got %d)',
+                        self, msg_len, len(self._input_buf) - self.HEADER_LEN)
             return False
 
         self._rhmac.update(self._input_buf[
@@ -418,60 +479,34 @@
                 expected_mac.encode('hex'),
                 self._input_buf[24:msg_len+24])
 
-        data = self._input_buf[self.HEADER_LEN:self.HEADER_LEN+msg_len]
+        msg.data = self._input_buf[self.HEADER_LEN:self.HEADER_LEN+msg_len]
         self._input_buf = self._input_buf[self.HEADER_LEN+msg_len:]
-        self._invoke(handle, reply_to, data)
+        self._router.route(msg)
         return True
 
-    def _invoke(self, handle, reply_to, data):
-        IOLOG.debug('%r._invoke(%r, %r, %r)', self, handle, reply_to, data)
-        try:
-            persist, fn = self._context._handle_map[handle]
-        except KeyError:
-            raise StreamError('%r: invalid handle: %r', self, handle)
-
-        if not persist:
-            del self._context._handle_map[handle]
-
-        try:
-            fn(reply_to, data)
-        except Exception:
-            LOG.debug('%r._invoke(%r, %r): %r crashed', self, handle, data, fn)
-
     def on_transmit(self, broker):
         """Transmit buffered messages."""
         IOLOG.debug('%r.on_transmit()', self)
-        written = os.write(self.transmit_side.fd, self._output_buf[:4096])
+        written = os.write(self.transmit_side.fd,
+                           self._output_buf[:CHUNK_SIZE])
         IOLOG.debug('%r.on_transmit() -> len %d', self, written)
         self._output_buf = self._output_buf[written:]
         if not self._output_buf:
             broker.stop_transmit(self)
 
-    def _enqueue(self, handle, data, reply_to):
-        IOLOG.debug('%r._enqueue(%r, %r)', self, handle, data)
-        msg = struct.pack('>LLL', handle, reply_to, len(data)) + data
-        self._whmac.update(msg)
-        self._output_buf += self._whmac.digest() + msg
-        self._context.broker.start_transmit(self)
-
-    def enqueue_raw(self, handle, data, reply_to=0):
-        """Enqueue `data` to `handle`, and tell the broker we have output. May
-        be called from any thread."""
-        self._context.broker.on_thread(self._enqueue, handle, data, reply_to)
-
-    def enqueue(self, handle, obj, reply_to=0):
-        """Enqueue `obj` to `handle`, and tell the broker we have output. May
-        be called from any thread."""
-        try:
-            encoded = cPickle.dumps(obj, protocol=2)
-        except cPickle.PicklingError, e:
-            encoded = cPickle.dumps(CallError(e), protocol=2)
-        self.enqueue_raw(handle, encoded, reply_to)
+    def send(self, msg):
+        """Send `msg` on this stream, and tell the broker we have output.
+        May be called from any thread."""
+        IOLOG.debug('%r.send(%r)', self, msg)
+        pkt = struct.pack('>LLLLL', msg.dst_id, msg.src_id,
+                          msg.handle, msg.reply_to or 0,
+                          len(msg.data)) + msg.data
+        self._whmac.update(pkt)
+        self._output_buf += self._whmac.digest() + pkt
+        self._router.broker.start_transmit(self)
 
     def on_disconnect(self, broker):
         super(Stream, self).on_disconnect(broker)
-        if self._context.stream is self:
-            self._context.on_disconnect(broker)
+        self._router.on_disconnect(self, broker)
 
     def on_shutdown(self, broker):
         """Override BasicStream behaviour of immediately disconnecting."""
@@ -482,99 +517,95 @@
         self.transmit_side = Side(self, os.dup(wfd))
         set_cloexec(self.receive_side.fd)
         set_cloexec(self.transmit_side.fd)
-        self._context.stream = self
 
     def __repr__(self):
-        return '%s(%r)' % (self.__class__.__name__, self._context)
+        cls = type(self)
+        return '%s.%s(%r)' % (cls.__module__, cls.__name__, self.name)
 
 
 class Context(object):
     """
     Represent a remote context regardless of connection method.
     """
-    stream = None
     remote_name = None
 
-    def __init__(self, broker, name=None, hostname=None, username=None,
-                 key=None, parent_addr=None):
-        self.broker = broker
+    def __init__(self, router, context_id, name=None, key=None):
+        self.router = router
+        self.context_id = context_id
         self.name = name
-        self.hostname = hostname
-        self.username = username
         self.key = key or ('%016x' % random.getrandbits(128))
-        self.parent_addr = parent_addr
-        self._last_handle = itertools.count(1000)
+        #: handle -> (persistent?, func(msg))
         self._handle_map = {}
+        self._last_handle = itertools.count(1000)
+
+    def add_handler(self, fn, handle=None, persist=True):
+        """Invoke `fn(msg)` for each Message sent to `handle` from this
+        context. Unregister after one invocation if `persist` is ``False``.
+        If `handle` is ``None``, a new handle is allocated and returned.
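+
+        Example (a sketch; mirrors ``LogForwarder`` in
+        :py:mod:`econtext.master` and :py:meth:`send_await` below):
+
+        .. code-block:: python
+
+            # Persistent handler: invoked for every FORWARD_LOG message.
+            context.add_handler(forwarder.forward, FORWARD_LOG)
+
+            # One-shot reply handle, unregistered after first delivery.
+            reply_to = context.add_handler(queue.put, persist=False)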
May + be called from any thread.""" + msg.dst_id = self.context_id + msg.src_id = econtext.context_id + self.router.route(msg) - def enqueue_await_reply_raw(self, handle, deadline, data): - """Send `data` to `handle` and wait for a response with an optional - timeout.""" - if self.broker._thread == threading.currentThread(): # TODO + def send_await(self, msg, deadline=None): + """Send `msg` and wait for a response with an optional timeout.""" + if self.router.broker._thread == threading.currentThread(): # TODO raise SystemError('Cannot making blocking call on broker thread') - reply_to = self.alloc_handle() - LOG.debug('%r.enqueue_await_reply(%r, %r, %r) -> reply handle %d', - self, handle, deadline, data, reply_to) - queue = Queue.Queue() - receiver = lambda _reply_to, data: queue.put(data) - self.add_handle_cb(receiver, reply_to, persist=False) - self.stream.enqueue_raw(handle, data, reply_to) + msg.reply_to = self.add_handler(queue.put, persist=False) + LOG.debug('%r.send_await(%r)', self, msg) + self.send(msg) try: - data = queue.get(True, deadline) + msg = queue.get(True, deadline) except Queue.Empty: # self.broker.on_thread(self.stream.on_disconnect, self.broker) raise TimeoutError('deadline exceeded.') - if data == _DEAD: + if msg == _DEAD: raise StreamError('lost connection during call.') - IOLOG.debug('%r._enqueue_await_reply() -> %r', self, data) - return data + IOLOG.debug('%r._send_await() -> %r', self, msg) + return msg - def enqueue_await_reply(self, handle, deadline, obj): - """Like :py:meth:enqueue_await_reply_raw except that `data` is pickled - prior to sending, and the return value is unpickled on reception.""" - if self.broker._thread == threading.currentThread(): # TODO - raise SystemError('Cannot making blocking call on broker thread') + def _invoke(self, msg): + #IOLOG.debug('%r._invoke(%r)', self, msg) + try: + persist, fn = self._handle_map[msg.handle] + except KeyError: + LOG.error('%r: invalid handle: %r', self, msg) + return + + if not persist: + del self._handle_map[msg.handle] - encoded = cPickle.dumps(obj, protocol=2) - result = self.enqueue_await_reply_raw(handle, deadline, encoded) - decoded = self.broker.unpickle(result) - IOLOG.debug('%r.enqueue_await_reply() -> %r', self, decoded) - return decoded + try: + fn(msg) + except Exception: + LOG.exception('%r._invoke(%r): %r crashed', self, msg, fn) def __repr__(self): - bits = filter(None, (self.name, self.hostname, self.username)) - return 'Context(%s)' % ', '.join(map(repr, bits)) + return 'Context(%s, %r)' % (self.context_id, self.name) class Waker(BasicStream): @@ -651,7 +682,7 @@ class IoLogger(BasicStream): def on_receive(self, broker): LOG.debug('%r.on_receive()', self) - buf = os.read(self.receive_side.fd, 4096) + buf = os.read(self.receive_side.fd, CHUNK_SIZE) if not buf: return self.on_disconnect(broker) @@ -659,6 +690,75 @@ class IoLogger(BasicStream): self._log_lines() +class Router(object): + """ + Route messages between parent and child contexts, and invoke handlers + defined on our parent context. Router.route() straddles the Broker and user + threads, it is save to call from anywhere. + """ + parent_context = None + + def __init__(self, broker): + self.broker = broker + #: context ID -> Stream + self._stream_by_id = {} + #: List of contexts to notify of shutdown. 
+        self._context_by_id = {}
+
+    def __repr__(self):
+        return 'Router(%r)' % (self.broker,)
+
+    def set_parent(self, context):
+        self.parent_context = context
+        context.add_handler(self._on_add_route, ADD_ROUTE)
+
+    def on_disconnect(self, stream, broker):
+        """Invoked by Stream.on_disconnect()."""
+        if not self.parent_context:
+            return
+
+        parent_stream = self._stream_by_id[self.parent_context.context_id]
+        if parent_stream is stream:
+            self.parent_context.on_disconnect(broker)
+
+    def on_shutdown(self):
+        for context in self._context_by_id.itervalues():
+            context.on_shutdown(self.broker)
+
+    def add_route(self, target_id, via_id):
+        try:
+            self._stream_by_id[target_id] = self._stream_by_id[via_id]
+        except KeyError:
+            LOG.error("%r: can't add route to %r via %r: no such stream",
+                      self, target_id, via_id)
+
+    def _on_add_route(self, msg):
+        target_id, via_id = map(int, msg.data.split('\x00'))
+        self.add_route(target_id, via_id)
+
+    def register(self, context, stream):
+        self._stream_by_id[context.context_id] = stream
+        self._context_by_id[context.context_id] = context
+        self.broker.start_receive(stream)
+
+    def _route(self, msg):
+        #LOG.debug('%r._route(%r)', self, msg)
+        context = self._context_by_id.get(msg.src_id)
+        if msg.dst_id == econtext.context_id and context is not None:
+            context._invoke(msg)
+            return
+
+        stream = self._stream_by_id.get(msg.dst_id)
+        if stream is None:
+            LOG.error('%r: no route for %r', self, msg)
+            return
+
+        stream.send(msg)
+
+    def route(self, msg):
+        self.broker.on_thread(self._route, msg)
+
+
 class Broker(object):
     """
     Responsible for tracking contexts, their associated streams and I/O
@@ -671,10 +771,10 @@
     #: gracefully before force-disconnecting them during :py:meth:`shutdown`.
     shutdown_timeout = 3.0
 
-    def __init__(self):
+    def __init__(self, on_shutdown=None):
+        self._on_shutdown = on_shutdown or []
         self._alive = True
         self._queue = Queue.Queue()
-        self._contexts = {}
         self._readers = set()
         self._writers = set()
         self._waker = Waker(self)
@@ -682,17 +782,6 @@
                                         name='econtext-broker')
         self._thread.start()
 
-    _find_global = None
-
-    def unpickle(self, data):
-        """Deserialize `data` into an object."""
-        IOLOG.debug('%r.unpickle(%r)', self, data)
-        fp = cStringIO.StringIO(data)
-        unpickler = cPickle.Unpickler(fp)
-        if self._find_global:
-            unpickler.find_global = self._find_global
-        return unpickler.load()
-
     def on_thread(self, func, *args, **kwargs):
         if threading.currentThread() == self._thread:
             func(*args, **kwargs)
@@ -715,24 +804,13 @@
     def start_transmit(self, stream):
         IOLOG.debug('%r.start_transmit(%r)', self, stream)
+        assert stream.transmit_side
         self.on_thread(self._writers.add, stream.transmit_side)
 
     def stop_transmit(self, stream):
         IOLOG.debug('%r.stop_transmit(%r)', self, stream)
         self.on_thread(self._writers.discard, stream.transmit_side)
 
-    def register(self, context):
-        """Register `context` with this broker. Registration simply calls
-        :py:meth:`start_receive` on the context's :py:class:`Stream`, and
-        records a reference to it so that :py:meth:`Context.on_shutdown` can
-        be called during :py:meth:`shutdown`."""
-        LOG.debug('%r.register(%r) -> r=%r w=%r', self, context,
-                  context.stream.receive_side,
-                  context.stream.transmit_side)
-        self.start_receive(context.stream)
-        self._contexts[context.name] = context
-        return context
-
     def _call(self, stream, func):
         try:
             func(self)
@@ -771,9 +849,7 @@
-        attribute is ``True``, or any :py:class:`Context` is still registered
-        that is not the master. Used to delay shutdown while some important
-        work is in progress (e.g. log draining)."""
-        return sum(c.stream is not None and c.name != 'master'
-                   for c in self._contexts.itervalues()) or \
-            sum(side.keep_alive for side in self._readers)
+        attribute is ``True``. Used to delay shutdown while some important
+        work is in progress (e.g. log draining)."""
+        return sum(side.keep_alive for side in self._readers)
 
     def _broker_main(self):
         """Handle events until :py:meth:`shutdown`. On shutdown, invoke
@@ -798,9 +874,6 @@
                           'more child processes still connected to '
                           'our stdout/stderr pipes.', self)
 
-            for context in self._contexts.itervalues():
-                context.on_shutdown(self)
-
             for side in self._readers | self._writers:
                 LOG.error('_broker_main() force disconnecting %r', side)
                 side.stream.on_disconnect(self)
@@ -855,34 +928,47 @@
         The :py:class:`IoLogger` connected to ``stderr``.
     """
 
-    def _setup_master(self, key):
+    def _setup_master(self, parent_id, context_id, key):
         self.broker = Broker()
-        self.context = Context(self.broker, 'master', key=key)
+        self.router = Router(self.broker)
+        self.context = Context(self.router, parent_id, 'master', key=key)
+        self.router.set_parent(self.context)
         self.channel = Channel(self.context, CALL_FUNCTION)
-        self.context.stream = Stream(self.context)
-        self.context.stream.accept(100, 1)
+        self.stream = Stream(self.router, parent_id, key)
+        self.stream.accept(100, 1)
         os.wait()  # Reap first stage.
         os.close(100)
 
     def _setup_logging(self, log_level):
         root = logging.getLogger()
         root.setLevel(log_level)
         root.handlers = [LogHandler(self.context)]
-        LOG.debug('Connected to %s', self.context)
 
     def _setup_importer(self):
-        self.importer = Importer(self.context)
+        with os.fdopen(101, 'r', 1) as fp:
+            core_size = int(fp.readline())
+            core_src = fp.read(core_size)
+            # Strip "ExternalContext.main()" call from last line.
+            core_src = '\n'.join(core_src.splitlines()[:-1])
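+            # Counterpart to _first_stage() in econtext.master, which sends
+            # the slave a length-prefixed copy of its own bootstrap source
+            # down fd 101:
+            #
+            #     os.fdopen(W2, 'w', 0).write('%s\n%s' % (len(C), C))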
+
+        self.importer = Importer(self.context, core_src)
         sys.meta_path.append(self.importer)
 
-    def _setup_package(self):
+    def _setup_package(self, context_id):
+        global econtext
         econtext = imp.new_module('econtext')
         econtext.__package__ = 'econtext'
         econtext.__path__ = []
         econtext.__loader__ = self.importer
         econtext.slave = True
+        econtext.context_id = context_id
         econtext.core = sys.modules['__main__']
+        econtext.core.__file__ = 'x/econtext/core.py'  # For inspect.getsource()
+        econtext.core.__loader__ = self.importer
         sys.modules['econtext'] = econtext
         sys.modules['econtext.core'] = econtext.core
         del sys.modules['__main__']
@@ -900,9 +986,8 @@
             fp.close()
 
     def _dispatch_calls(self):
-        for data in self.channel:
-            LOG.debug('_dispatch_calls(%r)', data)
-            reply_to, data = data
+        for msg, data in self.channel:
+            LOG.debug('_dispatch_calls(%r)', msg)
             with_context, modname, klass, func, args, kwargs = data
             if with_context:
                 args = (self,) + args
@@ -912,20 +997,23 @@
                 if klass:
                     obj = getattr(obj, klass)
                 fn = getattr(obj, func)
-                self.context.enqueue(reply_to, fn(*args, **kwargs))
+                ret = fn(*args, **kwargs)
+                self.context.send(Message.pickled(ret, handle=msg.reply_to))
             except Exception, e:
-                self.context.enqueue(reply_to, CallError(e))
+                e = CallError(str(e))
+                self.context.send(Message.pickled(e, handle=msg.reply_to))
 
-    def main(self, key, log_level):
-        self._setup_master(key)
+    def main(self, parent_id, context_id, key, log_level):
+        self._setup_master(parent_id, context_id, key)
         try:
             try:
                 self._setup_logging(log_level)
                 self._setup_importer()
-                self._setup_package()
+                self._setup_package(context_id)
                 self._setup_stdio()
+                LOG.debug('Connected to %s', self.context)
 
-                self.broker.register(self.context)
+                self.router.register(self.context, self.stream)
                 self._dispatch_calls()
                 LOG.debug('ExternalContext.main() normal exit')
             except BaseException:
diff --git a/econtext/master.py b/econtext/master.py
index 934f4919..3f872b94 100644
--- a/econtext/master.py
+++ b/econtext/master.py
@@ -4,6 +4,7 @@ starting new contexts via SSH. Its size is also restricted, since it must
 be sent to any context that will be used to establish additional child
 contexts.
 """
+import errno
 import getpass
 import imp
 import inspect
@@ -94,11 +95,29 @@ def read_with_deadline(fd, size, deadline):
         raise econtext.core.TimeoutError('read timed out')
 
 
 def iter_read(fd, deadline):
+    if deadline is not None:
+        LOG.warning('iter_read(.., deadline=...) unimplemented')
+
+    bits = []
     while True:
-        s = os.read(fd, 4096)
+        try:
+            s = os.read(fd, 4096)
+        except OSError, e:
+            IOLOG.debug('iter_read(%r) -> OSError: %s', fd, e)
+            # See econtext.core.on_receive() EIO comment.
+            if e.errno != errno.EIO:
+                raise
+            s = ''
+
         if not s:
-            raise econtext.core.StreamError('EOF on stream')
+            raise econtext.core.StreamError(
+                'EOF on stream; last 100 bytes received: %r' %
+                (''.join(bits)[-100:],)
+            )
+
+        bits.append(s)
         yield s
 
@@ -109,26 +128,30 @@ def discard_until(fd, s, deadline):
 
 class LogForwarder(object):
+    _log = None
+
     def __init__(self, context):
         self._context = context
-        self._context.add_handle_cb(self.forward_log,
-                                    handle=econtext.core.FORWARD_LOG)
-        name = '%s.%s' % (RLOG.name, self._context.name)
-        self._log = logging.getLogger(name)
-
-    def forward_log(self, reply_to, data):
-        if data == econtext.core._DEAD:
-            return
+        context.add_handler(self.forward, econtext.core.FORWARD_LOG)
+
+    def forward(self, msg):
+        if not self._log:
+            # Delay initialization so Stream has a chance to set Context's
+            # default name, if one wasn't otherwise specified.
+            name = '%s.%s' % (RLOG.name, self._context.name)
+            self._log = logging.getLogger(name)
+        if msg != econtext.core._DEAD:
+            name, level_s, s = msg.data.split('\x00', 2)
+            self._log.log(int(level_s), '%s: %s', name, s)
 
-        name, level, s = data.split('\x00', 2)
-        self._log.log(level, '%s: %s', name, s)
+    def __repr__(self):
+        return 'LogForwarder(%r)' % (self._context,)
 
 
 class ModuleResponder(object):
     def __init__(self, context):
         self._context = context
-        self._context.add_handle_cb(self.get_module,
-                                    handle=econtext.core.GET_MODULE)
+        context.add_handler(self.get_module, econtext.core.GET_MODULE)
 
     def __repr__(self):
         return 'ModuleResponder(%r)' % (self._context,)
@@ -189,11 +212,12 @@
                          _get_module_via_sys_modules,
                          _get_module_via_parent_enumeration]
 
-    def get_module(self, reply_to, fullname):
-        LOG.debug('%r.get_module(%r, %r)', self, reply_to, fullname)
-        if fullname == econtext.core._DEAD:
+    def get_module(self, msg):
+        LOG.debug('%r.get_module(%r)', self, msg)
+        if msg == econtext.core._DEAD:
             return
 
+        fullname = msg.data
         try:
             for method in self.get_module_methods:
                 tup = method(self, fullname)
@@ -215,24 +239,64 @@
                 pkg_present = None
 
             compressed = zlib.compress(source)
-            reply = (pkg_present, path, compressed)
-            self._context.enqueue(reply_to, reply)
+            self._context.send(
+                econtext.core.Message.pickled(
+                    (pkg_present, path, compressed),
+                    handle=msg.reply_to
+                )
+            )
         except Exception:
             LOG.debug('While importing %r', fullname, exc_info=True)
-            self._context.enqueue(reply_to, None)
+            self._context.send(
+                econtext.core.Message.pickled(None, handle=msg.reply_to)
+            )
+
+
+class Message(econtext.core.Message):
+    """
+    Message subclass that controls unpickling.
+    """
+    def _find_global(self, module_name, class_name):
+        """Return the class implementing `module_name.class_name` or raise
+        `StreamError` if the module is not whitelisted."""
+        if (module_name, class_name) not in PERMITTED_CLASSES:
+            raise econtext.core.StreamError(
+                '%r attempted to unpickle %r in module %r',
+                self, class_name, module_name)
+        return getattr(sys.modules[module_name], class_name)
 
 
 class Stream(econtext.core.Stream):
     """
     Base for streams capable of starting new slaves.
     """
+    message_class = Message
+
     #: The path to the remote Python interpreter.
-    python_path = sys.executable
+    python_path = 'python2.7'
+
+    def construct(self, remote_name=None, python_path=None, **kwargs):
+        """Configure the stream, setting the remote interpreter path and the
+        name by which this process identifies itself to the slave."""
+        super(Stream, self).construct(**kwargs)
+        if python_path:
+            self.python_path = python_path
+
+        if remote_name is None:
+            remote_name = '%s@%s:%d'
+            remote_name %= (getpass.getuser(), socket.gethostname(), os.getpid())
+        self.remote_name = remote_name
+        self.name = 'local.default'
 
     def on_shutdown(self, broker):
         """Request the slave gracefully shut itself down."""
         LOG.debug('%r closing CALL_FUNCTION channel', self)
-        self.enqueue(econtext.core.CALL_FUNCTION, econtext.core._DEAD)
+        self.send(
+            econtext.core.Message.pickled(
+                econtext.core._DEAD,
+                src_id=econtext.context_id,
+                dst_id=self.remote_id,
+                handle=econtext.core.CALL_FUNCTION
+            )
+        )
 
 # base64'd and passed to 'python -c'. It forks, dups 0->100, creates a
 # pipe, then execs a new interpreter with a custom argv. 'CONTEXT_NAME' is
@@ -240,39 +304,36 @@
     def _first_stage():
         import os,sys,zlib
        R,W=os.pipe()
+        R2,W2=os.pipe()
         if os.fork():
             os.dup2(0,100)
             os.dup2(R,0)
-            os.close(R)
-            os.close(W)
+            os.dup2(R2,101)
+            for f in R,R2,W,W2: os.close(f)
             os.execv(sys.executable,['econtext:CONTEXT_NAME'])
         else:
             os.write(1, 'EC0\n')
-            os.fdopen(W,'wb',0).write(zlib.decompress(sys.stdin.read(input())))
+            C = zlib.decompress(sys.stdin.read(input()))
+            os.fdopen(W,'w',0).write(C)
+            os.fdopen(W2,'w',0).write('%s\n%s' % (len(C),C))
             os.write(1, 'EC1\n')
             sys.exit(0)
 
     def get_boot_command(self):
-        name = self._context.remote_name
-        if name is None:
-            name = '%s@%s:%d'
-            name %= (getpass.getuser(), socket.gethostname(), os.getpid())
-
         source = inspect.getsource(self._first_stage)
         source = textwrap.dedent('\n'.join(source.strip().split('\n')[1:]))
         source = source.replace('    ', '\t')
-        source = source.replace('CONTEXT_NAME', name)
+        source = source.replace('CONTEXT_NAME', self.remote_name)
         encoded = source.encode('base64').replace('\n', '')
         return [self.python_path, '-c',
                 'exec("%s".decode("base64"))' % (encoded,)]
 
-    def __repr__(self):
-        return '%s(%s)' % (self.__class__.__name__, self._context)
-
     def get_preamble(self):
         source = inspect.getsource(econtext.core)
         source += '\nExternalContext().main%r\n' % ((
-            self._context.key,
+            econtext.context_id,  # parent_id
+            self.remote_id,       # context_id
+            self.key,
             LOG.level or logging.getLogger().level or logging.INFO,
         ),)
 
@@ -304,22 +365,6 @@
 class Broker(econtext.core.Broker):
     shutdown_timeout = 5.0
 
-    def _find_global(self, module_name, class_name):
-        """Return the class implementing `module_name.class_name` or raise
-        `StreamError` if the module is not whitelisted."""
-        if (module_name, class_name) not in PERMITTED_CLASSES:
-            raise econtext.core.StreamError(
-                '%r attempted to unpickle %r in module %r',
-                self._context, class_name, module_name)
-        return getattr(sys.modules[module_name], class_name)
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, e_type, e_val, tb):
-        self.shutdown()
-        self.join()
-
 
 class Context(econtext.core.Context):
     def __init__(self, *args, **kwargs):
@@ -328,7 +373,7 @@
         self.log_forwarder = LogForwarder(self)
 
     def on_disconnect(self, broker):
-        self.stream = None
+        pass
 
     def call_with_deadline(self, deadline, with_context, fn, *args, **kwargs):
         """Invoke `fn([context,] *args, **kwargs)` in the external context.
@@ -350,23 +395,68 @@
         klass = None
 
         call = (with_context, fn.__module__, klass, fn.__name__, args, kwargs)
-        result = self.enqueue_await_reply(econtext.core.CALL_FUNCTION,
-                                          deadline, call)
-        if isinstance(result, econtext.core.CallError):
-            raise result
-        return result
+        response = self.send_await(
+            econtext.core.Message.pickled(
+                call,
+                handle=econtext.core.CALL_FUNCTION
+            ),
+            deadline
+        )
+
+        decoded = response.unpickle()
+        if isinstance(decoded, econtext.core.CallError):
+            raise decoded
+        return decoded
 
     def call(self, fn, *args, **kwargs):
         """Invoke `fn(*args, **kwargs)` in the external context."""
         return self.call_with_deadline(None, False, fn, *args, **kwargs)
 
 
-def connect(broker, name='default', python_path=None):
-    """Get the named context running on the local machine, creating it if
-    it does not exist."""
-    context = Context(broker, name)
-    context.stream = Stream(context)
-    if python_path:
-        context.stream.python_path = python_path
-    context.stream.connect()
-    return broker.register(context)
+def _proxy_connect(econtext, name, context_id, klass, kwargs):
+    econtext.router.__class__ = Router  # TODO
+    context = econtext.router._connect(
+        context_id,
+        klass,
+        name=name,
+        **kwargs
+    )
+    return context.name
+
+
+class Router(econtext.core.Router):
+    next_slave_id = 10
+
+    def alloc_slave_id(self):
+        """Allocate a context_id for a slave about to be created."""
+        self.next_slave_id += 1
+        return self.next_slave_id
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, e_type, e_val, tb):
+        self.broker.shutdown()
+        self.broker.join()
+
+    def _connect(self, context_id, klass, name=None, **kwargs):
+        context = Context(self, context_id)
+        stream = klass(self, context.context_id, context.key, **kwargs)
+        context.name = name or stream.name
+        stream.connect()
+        self.register(context, stream)
+        return context
+
+    def connect(self, klass, name=None, **kwargs):
+        context_id = self.alloc_slave_id()
+        return self._connect(context_id, klass, name=name, **kwargs)
+
+    def proxy_connect(self, via_context, klass, name=None, **kwargs):
+        context_id = self.alloc_slave_id()
+        name = via_context.call_with_deadline(3.0, True,
+            _proxy_connect, name, context_id, klass, kwargs
+        )
+        name = '%s.%s' % (via_context.name, name)
+        LOG.debug('%r.proxy_connect(): remote name: %r', self, name)
+        self.add_route(context_id, via_context.context_id)
+        return Context(self, context_id, name=name)
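+
+
+# Usage sketch (host name assumed): building the `master -> ssh1 -> sudo1`
+# chain described in the docs with the Router API:
+#
+#   broker = econtext.master.Broker()
+#   router = Router(broker)
+#   ssh1 = router.connect(econtext.ssh.Stream, hostname='host1')
+#   sudo1 = router.proxy_connect(ssh1, econtext.sudo.Stream)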
diff --git a/econtext/ssh.py b/econtext/ssh.py
index b4f10162..c58c9297 100644
--- a/econtext/ssh.py
+++ b/econtext/ssh.py
@@ -9,29 +9,25 @@
 import econtext.master
 
 
 class Stream(econtext.master.Stream):
     python_path = 'python'
+
+    #: The path to the SSH binary.
     ssh_path = 'ssh'
 
+    def construct(self, hostname, username=None, ssh_path=None, **kwargs):
+        super(Stream, self).construct(**kwargs)
+        self.hostname = hostname
+        self.username = username
+        if ssh_path:
+            self.ssh_path = ssh_path
+        self.name = 'ssh.' + hostname
+
+    def default_name(self):
+        return self.hostname
+
     def get_boot_command(self):
         bits = [self.ssh_path]
-        if self._context.username:
-            bits += ['-l', self._context.username]
-        bits.append(self._context.hostname)
+        if self.username:
+            bits += ['-l', self.username]
+        bits.append(self.hostname)
         base = super(Stream, self).get_boot_command()
         return bits + map(commands.mkarg, base)
-
-
-def connect(broker, hostname, username=None, name=None,
-            ssh_path=None, python_path=None):
-    """Get the named remote context, creating it if it does not exist."""
-    if name is None:
-        name = hostname
-
-    context = econtext.master.Context(broker, name, hostname, username)
-    context.stream = Stream(context)
-    if python_path:
-        context.stream.python_path = python_path
-    if ssh_path:
-        context.stream.ssh_path = ssh_path
-    context.stream.connect()
-    return broker.register(context)
diff --git a/econtext/sudo.py b/econtext/sudo.py
index 0e5dfe02..45b13635 100644
--- a/econtext/sudo.py
+++ b/econtext/sudo.py
@@ -94,9 +94,44 @@
     sudo_path = 'sudo'
     password = None
 
+    def construct(self, username=None, sudo_path=None, password=None,
+                  **kwargs):
+        """
+        Construct a stream that starts the slave via sudo on the local
+        machine.
+
+        :param str username:
+            Username to pass to sudo as the ``-u`` parameter, defaults to
+            ``root``.
+
+        :param str sudo_path:
+            Filename or complete path to the sudo binary. ``PATH`` will be
+            searched if given as a filename. Defaults to ``sudo``.
+
+        :param str python_path:
+            Filename or complete path to the Python binary. ``PATH`` will be
+            searched if given as a filename. Defaults to ``python2.7``.
+
+        :param str password:
+            The password to use when authenticating to sudo. Depending on the
+            sudo configuration, this is either the current account password
+            or the target account password.
+            :py:class:`econtext.sudo.PasswordError` will be raised if sudo
+            requests a password but none is provided.
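+
+        Example (a sketch; assumes a master ``router`` and an existing SSH
+        context ``ssh1``):
+
+        .. code-block:: python
+
+            sudo1 = router.proxy_connect(ssh1, econtext.sudo.Stream,
+                                         username='webapp',
+                                         password='secret')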
+        """
+        super(Stream, self).construct(**kwargs)
+        self.username = username or 'root'
+        if sudo_path:
+            self.sudo_path = sudo_path
+        if password:
+            self.password = password
+        self.name = 'sudo.' + self.username
+
     def get_boot_command(self):
-        bits = [self.sudo_path, '-u', self._context.username]
-        return bits + super(Stream, self).get_boot_command()
+        bits = [self.sudo_path, '-u', self.username]
+        bits = bits + super(Stream, self).get_boot_command()
+        LOG.debug('sudo command line: %r', bits)
+        return bits
 
     password_incorrect_msg = 'sudo password is incorrect'
     password_required_msg = 'sudo password is required'
@@ -118,47 +153,3 @@
                 password_sent = True
             else:
                 raise econtext.core.StreamError('bootstrap failed')
-
-
-def connect(broker, username=None, sudo_path=None, python_path=None, password=None):
-    """
-    Get the named sudo context, creating it if it does not exist.
-
-    :param econtext.core.Broker broker:
-        The broker that will own the context.
-
-    :param str username:
-        Username to pass to sudo as the ``-u`` parameter, defaults to ``root``.
-
-    :param str sudo_path:
-        Filename or complete path to the sudo binary. ``PATH`` will be searched
-        if given as a filename. Defaults to ``sudo``.
-
-    :param str python_path:
-        Filename or complete path to the Python binary. ``PATH`` will be
-        searched if given as a filename. Defaults to :py:data:`sys.executable`.
-
-    :param str password:
-        The password to use when authenticating to sudo. Depending on the sudo
-        configuration, this is either the current account password or the
-        target account password. :py:class:`econtext.sudo.PasswordError` will
-        be raised if sudo requests a password but none is provided.
-
-    """
-    if username is None:
-        username = 'root'
-
-    context = econtext.master.Context(
-        broker=broker,
-        name='sudo.' + username,
-        username=username)
-
-    context.stream = Stream(context)
-    if sudo_path:
-        context.stream.sudo_path = sudo_path
-    if password:
-        context.stream.password = password
-    if python_path:
-        context.stream.python_path = python_path
-    context.stream.connect()
-    return broker.register(context)
diff --git a/econtext/utils.py b/econtext/utils.py
index 093a1447..1fd4c5b7 100644
--- a/econtext/utils.py
+++ b/econtext/utils.py
@@ -22,6 +22,11 @@ def disable_site_packages():
         sys.path.remove(entry)
 
 
+def log_to_tmp():
+    import os
+    log_to_file(path='/tmp/econtext.%s.log' % (os.getpid(),))
+
+
 def log_to_file(path=None, io=True, level=logging.DEBUG):
-    """Install a new :py:class:`logging.Handler` writing applications logs to
+    """Install a new :py:class:`logging.Handler` writing application logs to
     the filesystem. Useful when debugging slave IO problems."""
@@ -43,19 +48,20 @@
     log.handlers.insert(0, handler)
 
 
-def run_with_broker(func, *args, **kwargs):
-    """Arrange for `func(broker, *args, **kwargs)` to run with a temporary
-    :py:class:`econtext.master.Broker`, ensuring the broker is correctly
-    shut down during normal or exceptional return."""
+def run_with_router(func, *args, **kwargs):
+    """Arrange for `func(router, *args, **kwargs)` to run with a temporary
+    :py:class:`econtext.master.Router`, ensuring the Router and Broker are
+    correctly shut down during normal or exceptional return."""
     broker = econtext.master.Broker()
+    router = econtext.master.Router(broker)
     try:
-        return func(broker, *args, **kwargs)
+        return func(router, *args, **kwargs)
     finally:
         broker.shutdown()
         broker.join()
 
 
-def with_broker(func):
-    """Decorator version of :py:func:`run_with_broker`. Example:
+def with_router(func):
+    """Decorator version of :py:func:`run_with_router`. Example:
 
     .. code-block:: python
 
@@ -67,6 +73,6 @@
         do_stuff(blah, 123)
     """
     def wrapper(*args, **kwargs):
-        return run_with_broker(func, *args, **kwargs)
+        return run_with_router(func, *args, **kwargs)
     wrapper.func_name = func.func_name
     return wrapper
diff --git a/preamble_size.py b/preamble_size.py
index 0efba5c9..99add0fe 100644
--- a/preamble_size.py
+++ b/preamble_size.py
@@ -9,21 +9,21 @@
 import zlib
 
 import econtext.master
 import econtext.ssh
 import econtext.sudo
-import econtext.proxy
 
-context = econtext.master.Context(None, name='default', hostname='default')
-stream = econtext.ssh.Stream(context)
-print 'SSH command size: %s' % (len(' '.join(stream.get_boot_command())),)
-print 'Preamble size: %s (%.2fKiB)' % (
-    len(stream.get_preamble()),
-    len(stream.get_preamble()) / 1024.0,
-)
+broker = econtext.master.Broker()
+with econtext.master.Router(broker) as router:
+    context = econtext.master.Context(router, 0)
+    stream = econtext.ssh.Stream(router, 0, context.key, hostname='foo')
+    print 'SSH command size: %s' % (len(' '.join(stream.get_boot_command())),)
+    print 'Preamble size: %s (%.2fKiB)' % (
+        len(stream.get_preamble()),
+        len(stream.get_preamble()) / 1024.0,
+    )
 
 for mod in (
         econtext.master,
         econtext.ssh,
         econtext.sudo,
-        econtext.proxy
 ):
     sz = len(zlib.compress(econtext.master.minimize_source(inspect.getsource(mod))))
     print '%s size: %s (%.2fKiB)' % (mod.__name__, sz, sz / 1024.0)