diff --git a/docs/howitworks.rst b/docs/howitworks.rst index 7420e4bc..fe812e41 100644 --- a/docs/howitworks.rst +++ b/docs/howitworks.rst @@ -231,19 +231,19 @@ Stream Protocol Once connected, a basic framing protocol is used to communicate between master and slave: -+------------+-------+-----------------------------------------------------+ -| Field | Size | Description | -+============+=======+=====================================================+ -| ``hmac`` | 20 | SHA-1 MAC over (``length || data``) | -+------------+-------+-----------------------------------------------------+ -| ``length`` | 4 | Message length | -+------------+-------+-----------------------------------------------------+ -| ``data`` | n/a | Pickled message data. | -+------------+-------+-----------------------------------------------------+ - -The ``data`` component always consists of a 2-tuple, `(handle, data)`, where -``handle`` is an integer describing the message target and ``data`` is the -value to be delivered to the target. ++----------------+------+------------------------------------------------------+ +| Field | Size | Description | ++================+======+======================================================+ +| ``hmac`` | 20 | SHA-1 over (``context_id || handle || len || data``) | ++----------------+------+------------------------------------------------------+ +| ``handle`` | 4 | Integer target handle in recipient. | ++----------------+------+------------------------------------------------------+ +| ``reply_to`` | 4 | Integer response target ID. | ++----------------+------+------------------------------------------------------+ +| ``length`` | 4 | Message length | ++----------------+------+------------------------------------------------------+ +| ``data`` | n/a | Pickled message data. 
| ++----------------+------+------------------------------------------------------+ Masters listen on the following handles: @@ -304,6 +304,14 @@ cannot be used securely, however few of those accounts appear to be expert, and none mention any additional attacks that would not be prevented by using a restrictive class whitelist. +.. note:: + + Since unpickling may trigger module loads, it is not possible to + deserialize data on the broker thread, as this will result in recursion + leading to a deadlock. Therefore any internal services (module loader, + logging forwarder, etc.) must rely on simple string formats, or only + perform serialization from within the broker thread. + Use of HMAC ########### diff --git a/econtext/core.py b/econtext/core.py index 11335b54..ec9682e2 100644 --- a/econtext/core.py +++ b/econtext/core.py @@ -96,10 +96,10 @@ class Channel(object): self._queue = Queue.Queue() self._context.add_handle_cb(self._receive, handle) - def _receive(self, data): + def _receive(self, reply_to, data): """Callback from the Stream; appends data to the internal queue.""" IOLOG.debug('%r._receive(%r)', self, data) - self._queue.put(data) + self._queue.put((reply_to, data)) def close(self): """Indicate this channel is closed to the remote side.""" @@ -115,13 +115,21 @@ class Channel(object): """Receive an object, or ``None`` if `timeout` is reached.""" IOLOG.debug('%r.on_receive(timeout=%r)', self, timeout) try: - data = self._queue.get(True, timeout) + reply_to, data = self._queue.get(True, timeout) except Queue.Empty: return IOLOG.debug('%r.on_receive() got %r', self, data) + + # Must occur off the broker thread. 
+ if not isinstance(data, Dead): + try: + data = self._context.broker.unpickle(data) + except (TypeError, ValueError), ex: + raise StreamError('invalid message: %s', ex) + if not isinstance(data, (Dead, CallError)): - return data + return reply_to, data elif data == _DEAD: raise ChannelError('Channel is closed.') else: @@ -153,6 +161,7 @@ class Importer(object): 'econtext.compat', 'econtext.compat.pkgutil', 'econtext.master', + 'econtext.proxy', 'econtext.ssh', 'econtext.sudo', 'econtext.utils', @@ -168,6 +177,7 @@ class Importer(object): return None self.tls.running = True + fullname = fullname.rstrip('.') try: pkgname, _, _ = fullname.rpartition('.') LOG.debug('%r.find_module(%r)', self, fullname) @@ -195,8 +205,9 @@ class Importer(object): try: ret = self._cache[fullname] except KeyError: - ret = self._context.enqueue_await_reply(GET_MODULE, None, (fullname,)) - self._cache[fullname] = ret + self._cache[fullname] = ret = cPickle.loads( + self._context.enqueue_await_reply_raw(GET_MODULE, None, fullname) + ) if ret is None: raise ImportError('Master does not have %r' % (fullname,)) @@ -238,7 +249,8 @@ class LogHandler(logging.Handler): self.local.in_emit = True try: msg = self.format(rec) - self.context.enqueue(FORWARD_LOG, (rec.name, rec.levelno, msg)) + encoded = '%s\x00%s\x00%s' % (rec.name, rec.levelno, msg) + self.context.enqueue(FORWARD_LOG, encoded) finally: self.local.in_emit = False @@ -346,17 +358,6 @@ class Stream(BasicStream): self._rhmac = hmac.new(context.key, digestmod=sha) self._whmac = self._rhmac.copy() - _find_global = None - - def unpickle(self, data): - """Deserialize `data` into an object.""" - IOLOG.debug('%r.unpickle(%r)', self, data) - fp = cStringIO.StringIO(data) - unpickler = cPickle.Unpickler(fp) - if self._find_global: - unpickler.find_global = self._find_global - return unpickler.load() - def on_receive(self, broker): """Handle the next complete message on the stream. 
Raise :py:class:`StreamError` on failure.""" @@ -364,7 +365,9 @@ class Stream(BasicStream): try: buf = os.read(self.receive_side.fd, 4096) + IOLOG.debug('%r.on_receive() -> len %d', self, len(buf)) except OSError, e: + IOLOG.debug('%r.on_receive() -> OSError: %s', self, e) # When connected over a TTY (i.e. sudo), disconnection of the # remote end is signalled by EIO, rather than an empty read like # sockets or pipes. Ideally this will be replaced later by a @@ -382,17 +385,32 @@ class Stream(BasicStream): if not buf: return self.on_disconnect(broker) + HEADER_FMT = ( + '>' + '20s' # msg_mac + 'L' # handle + 'L' # reply_to + 'L' # msg_len + ) + HEADER_LEN = struct.calcsize(HEADER_FMT) + MAC_LEN = sha.digest_size + def _receive_one(self): - if len(self._input_buf) < 24: + if len(self._input_buf) < self.HEADER_LEN: return False - msg_mac = self._input_buf[:20] - msg_len = struct.unpack('>L', self._input_buf[20:24])[0] - if len(self._input_buf)-24 < msg_len: + msg_mac, handle, reply_to, msg_len = struct.unpack( + self.HEADER_FMT, + self._input_buf[:self.HEADER_LEN] + ) + + if (len(self._input_buf) - self.HEADER_LEN) < msg_len: IOLOG.debug('Input too short') return False - self._rhmac.update(self._input_buf[20:msg_len+24]) + self._rhmac.update(self._input_buf[ + self.MAC_LEN : (msg_len + self.HEADER_LEN) + ]) expected_mac = self._rhmac.digest() if msg_mac != expected_mac: raise StreamError('bad MAC: %r != got %r; %r', @@ -400,17 +418,13 @@ class Stream(BasicStream): expected_mac.encode('hex'), self._input_buf[24:msg_len+24]) - try: - handle, data = self.unpickle(self._input_buf[24:msg_len+24]) - except (TypeError, ValueError), ex: - raise StreamError('invalid message: %s', ex) - - self._input_buf = self._input_buf[msg_len+24:] - self._invoke(handle, data) + data = self._input_buf[self.HEADER_LEN:self.HEADER_LEN+msg_len] + self._input_buf = self._input_buf[self.HEADER_LEN+msg_len:] + self._invoke(handle, reply_to, data) return True - def _invoke(self, handle, data): - 
IOLOG.debug('%r._invoke(%r, %r)', self, handle, data) + def _invoke(self, handle, reply_to, data): + IOLOG.debug('%r._invoke(%r, %r, %r)', self, handle, reply_to, data) try: persist, fn = self._context._handle_map[handle] except KeyError: @@ -420,7 +434,7 @@ class Stream(BasicStream): del self._context._handle_map[handle] try: - fn(data) + fn(reply_to, data) except Exception: LOG.debug('%r._invoke(%r, %r): %r crashed', self, handle, data, fn) @@ -428,26 +442,31 @@ class Stream(BasicStream): """Transmit buffered messages.""" IOLOG.debug('%r.on_transmit()', self) written = os.write(self.transmit_side.fd, self._output_buf[:4096]) + IOLOG.debug('%r.on_transmit() -> len %d', self, written) self._output_buf = self._output_buf[written:] if not self._output_buf: broker.stop_transmit(self) - def _enqueue(self, handle, obj): - IOLOG.debug('%r._enqueue(%r, %r)', self, handle, obj) - try: - encoded = cPickle.dumps((handle, obj), protocol=2) - except cPickle.PicklingError, e: - encoded = cPickle.dumps((handle, CallError(e)), protocol=2) - - msg = struct.pack('>L', len(encoded)) + encoded + def _enqueue(self, handle, data, reply_to): + IOLOG.debug('%r._enqueue(%r, %r)', self, handle, data) + msg = struct.pack('>LLL', handle, reply_to, len(data)) + data self._whmac.update(msg) self._output_buf += self._whmac.digest() + msg self._context.broker.start_transmit(self) - def enqueue(self, handle, obj): + def enqueue_raw(self, handle, data, reply_to=0): + """Enqueue `data` to `handle`, and tell the broker we have output. May + be called from any thread.""" + self._context.broker.on_thread(self._enqueue, handle, data, reply_to) + + def enqueue(self, handle, obj, reply_to=0): """Enqueue `obj` to `handle`, and tell the broker we have output. 
May be called from any thread.""" - self._context.broker.on_thread(self._enqueue, handle, obj) + try: + encoded = cPickle.dumps(obj, protocol=2) + except cPickle.PicklingError, e: + encoded = cPickle.dumps(CallError(e), protocol=2) + self.enqueue_raw(handle, encoded, reply_to) def on_disconnect(self, broker): super(Stream, self).on_disconnect(broker) @@ -493,7 +512,7 @@ class Context(object): LOG.debug('%r.on_shutdown(%r)', self, broker) for handle, (persist, fn) in self._handle_map.iteritems(): LOG.debug('%r.on_disconnect(): killing %r: %r', self, handle, fn) - fn(_DEAD) + fn(0, _DEAD) def on_disconnect(self, broker): self.stream = None @@ -510,34 +529,49 @@ class Context(object): IOLOG.debug('%r.add_handle_cb(%r, %r, %r)', self, fn, handle, persist) self._handle_map[handle] = persist, fn - def enqueue(self, handle, obj): + def enqueue(self, handle, obj, reply_to=0): if self.stream: - self.stream.enqueue(handle, obj) + self.stream.enqueue(handle, obj, reply_to) - def enqueue_await_reply(self, handle, deadline, data): + def enqueue_await_reply_raw(self, handle, deadline, data): """Send `data` to `handle` and wait for a response with an optional - timeout. 
The message contains `(reply_to, data)`, where `reply_to` is - the handle on which this function expects its reply.""" + timeout.""" + if self.broker._thread == threading.currentThread(): # TODO + raise SystemError('Cannot making blocking call on broker thread') + reply_to = self.alloc_handle() LOG.debug('%r.enqueue_await_reply(%r, %r, %r) -> reply handle %d', self, handle, deadline, data, reply_to) queue = Queue.Queue() - self.add_handle_cb(queue.put, reply_to, persist=False) - self.stream.enqueue(handle, (reply_to,) + data) + receiver = lambda _reply_to, data: queue.put(data) + self.add_handle_cb(receiver, reply_to, persist=False) + self.stream.enqueue_raw(handle, data, reply_to) try: data = queue.get(True, deadline) except Queue.Empty: - self.broker.on_thread(self.stream.on_disconnect, self.broker) + # self.broker.on_thread(self.stream.on_disconnect, self.broker) raise TimeoutError('deadline exceeded.') if data == _DEAD: raise StreamError('lost connection during call.') - IOLOG.debug('%r._enqueue_await_reply(): got reply: %r', self, data) + IOLOG.debug('%r._enqueue_await_reply() -> %r', self, data) return data + def enqueue_await_reply(self, handle, deadline, obj): + """Like :py:meth:`enqueue_await_reply_raw` except that `data` is pickled + prior to sending, and the return value is unpickled on reception.""" + if self.broker._thread == threading.currentThread(): # TODO + raise SystemError('Cannot making blocking call on broker thread') + + encoded = cPickle.dumps(obj, protocol=2) + result = self.enqueue_await_reply_raw(handle, deadline, encoded) + decoded = self.broker.unpickle(result) + IOLOG.debug('%r.enqueue_await_reply() -> %r', self, decoded) + return decoded + def __repr__(self): bits = filter(None, (self.name, self.hostname, self.username)) return 'Context(%s)' % ', '.join(map(repr, bits)) @@ -648,6 +682,17 @@ class Broker(object): name='econtext-broker') self._thread.start() + _find_global = None + + def unpickle(self, data): + """Deserialize `data` into
an object.""" + IOLOG.debug('%r.unpickle(%r)', self, data) + fp = cStringIO.StringIO(data) + unpickler = cPickle.Unpickler(fp) + if self._find_global: + unpickler.find_global = self._find_global + return unpickler.load() + def on_thread(self, func, *args, **kwargs): if threading.currentThread() == self._thread: func(*args, **kwargs) @@ -751,7 +796,7 @@ class Broker(object): LOG.error('%r: some streams did not close gracefully. ' 'The most likely cause for this is one or ' 'more child processes still connected to ' - 'ou stdout/stderr pipes.', self) + 'our stdout/stderr pipes.', self) for context in self._contexts.itervalues(): context.on_shutdown(self) @@ -857,7 +902,8 @@ class ExternalContext(object): def _dispatch_calls(self): for data in self.channel: LOG.debug('_dispatch_calls(%r)', data) - reply_to, with_context, modname, klass, func, args, kwargs = data + reply_to, data = data + with_context, modname, klass, func, args, kwargs = data if with_context: args = (self,) + args diff --git a/econtext/master.py b/econtext/master.py index f9942df6..934f4919 100644 --- a/econtext/master.py +++ b/econtext/master.py @@ -35,6 +35,11 @@ DOCSTRING_RE = re.compile(r'""".+?"""', re.M | re.S) COMMENT_RE = re.compile(r'^[ ]*#[^\n]*$', re.M) IOLOG_RE = re.compile(r'^[ ]*IOLOG.debug\(.+?\)$', re.M) +PERMITTED_CLASSES = set([ + ('econtext.core', 'CallError'), + ('econtext.core', 'Dead'), +]) + def minimize_source(source): """Remove comments and docstrings from Python `source`, preserving line @@ -111,11 +116,11 @@ class LogForwarder(object): name = '%s.%s' % (RLOG.name, self._context.name) self._log = logging.getLogger(name) - def forward_log(self, data): + def forward_log(self, reply_to, data): if data == econtext.core._DEAD: return - name, level, s = data + name, level, s = data.split('\x00', 2) self._log.log(level, '%s: %s', name, s) @@ -184,12 +189,11 @@ class ModuleResponder(object): _get_module_via_sys_modules, _get_module_via_parent_enumeration] - def get_module(self, 
data): - LOG.debug('%r.get_module(%r)', self, data) - if data == econtext.core._DEAD: + def get_module(self, reply_to, fullname): + LOG.debug('%r.get_module(%r, %r)', self, reply_to, fullname) + if fullname == econtext.core._DEAD: return - reply_to, fullname = data try: for method in self.get_module_methods: tup = method(self, fullname) @@ -225,31 +229,11 @@ class Stream(econtext.core.Stream): #: The path to the remote Python interpreter. python_path = sys.executable - def __init__(self, context): - super(Stream, self).__init__(context) - self._permitted_classes = set([ - ('econtext.core', 'CallError'), - ('econtext.core', 'Dead'), - ]) - def on_shutdown(self, broker): """Request the slave gracefully shut itself down.""" LOG.debug('%r closing CALL_FUNCTION channel', self) self.enqueue(econtext.core.CALL_FUNCTION, econtext.core._DEAD) - def _find_global(self, module_name, class_name): - """Return the class implementing `module_name.class_name` or raise - `StreamError` if the module is not whitelisted.""" - if (module_name, class_name) not in self._permitted_classes: - raise econtext.core.StreamError( - '%r attempted to unpickle %r in module %r', - self._context, class_name, module_name) - return getattr(sys.modules[module_name], class_name) - - def allow_class(self, module_name, class_name): - """Add `module_name` to the list of permitted modules.""" - self._permitted_modules.add((module_name, class_name)) - # base64'd and passed to 'python -c'. It forks, dups 0->100, creates a # pipe, then execs a new interpreter with a custom argv. 'CONTEXT_NAME' is # replaced with the context name. Optimized for size. 
@@ -320,6 +304,15 @@ class Stream(econtext.core.Stream): class Broker(econtext.core.Broker): shutdown_timeout = 5.0 + def _find_global(self, module_name, class_name): + """Return the class implementing `module_name.class_name` or raise + `StreamError` if the class is not whitelisted.""" + if (module_name, class_name) not in PERMITTED_CLASSES: + raise econtext.core.StreamError( + '%r attempted to unpickle %r in module %r', + self._context, class_name, module_name) + return getattr(sys.modules[module_name], class_name) + def __enter__(self): return self