From e374f85888d4cf36b778846ae961407b997cef8e Mon Sep 17 00:00:00 2001 From: David Wilson Date: Sun, 20 Aug 2017 20:50:03 +0530 Subject: [PATCH] Fix module loader deadlock Stop using cPickle on the broker thread where it is not known whether the pickle data would cause the import machinery to be invoked, which currently relies on blocking calls. Huge mess but it works. This is due to: context.call(some.module.func, another.module.func) We stringify ("some.module", "func"), but the reference to another.module.func is passed into the pickle machinery, and there's no way to generically stringify all function references in user data for reification on the main thread, without doing something like this instead. --- docs/howitworks.rst | 34 ++++++---- econtext/core.py | 156 ++++++++++++++++++++++++++++---------------- econtext/master.py | 45 ++++++------- 3 files changed, 141 insertions(+), 94 deletions(-) diff --git a/docs/howitworks.rst b/docs/howitworks.rst index 7420e4bc..fe812e41 100644 --- a/docs/howitworks.rst +++ b/docs/howitworks.rst @@ -231,19 +231,19 @@ Stream Protocol Once connected, a basic framing protocol is used to communicate between master and slave: -+------------+-------+-----------------------------------------------------+ -| Field | Size | Description | -+============+=======+=====================================================+ -| ``hmac`` | 20 | SHA-1 MAC over (``length || data``) | -+------------+-------+-----------------------------------------------------+ -| ``length`` | 4 | Message length | -+------------+-------+-----------------------------------------------------+ -| ``data`` | n/a | Pickled message data. | -+------------+-------+-----------------------------------------------------+ - -The ``data`` component always consists of a 2-tuple, `(handle, data)`, where -``handle`` is an integer describing the message target and ``data`` is the -value to be delivered to the target. ++----------------+------+------------------------------------------------------+ +| Field | Size | Description | ++================+======+======================================================+ +| ``hmac`` | 20 | SHA-1 over (``context_id || handle || len || data``) | ++----------------+------+------------------------------------------------------+ +| ``handle`` | 4 | Integer target handle in recipient. | ++----------------+------+------------------------------------------------------+ +| ``reply_to`` | 4 | Integer response target ID. | ++----------------+------+------------------------------------------------------+ +| ``length`` | 4 | Message length | ++----------------+------+------------------------------------------------------+ +| ``data`` | n/a | Pickled message data. | ++----------------+------+------------------------------------------------------+ Masters listen on the following handles: @@ -304,6 +304,14 @@ cannot be used securely, however few of those accounts appear to be expert, and none mention any additional attacks that would not be prevented by using a restrictive class whitelist. +.. note:: + + Since unpickling may trigger module loads, it is not possible to + deserialize data on the broker thread, as this will result in recursion + leading to a deadlock. Therefore any internal services (module loader, + logging forwarder, etc.) must rely on simple string formats, or only + perform serialization from within the broker thread. + Use of HMAC ########### diff --git a/econtext/core.py b/econtext/core.py index 11335b54..ec9682e2 100644 --- a/econtext/core.py +++ b/econtext/core.py @@ -96,10 +96,10 @@ class Channel(object): self._queue = Queue.Queue() self._context.add_handle_cb(self._receive, handle) - def _receive(self, data): + def _receive(self, reply_to, data): """Callback from the Stream; appends data to the internal queue.""" IOLOG.debug('%r._receive(%r)', self, data) - self._queue.put(data) + self._queue.put((reply_to, data)) def close(self): """Indicate this channel is closed to the remote side.""" @@ -115,13 +115,21 @@ class Channel(object): """Receive an object, or ``None`` if `timeout` is reached.""" IOLOG.debug('%r.on_receive(timeout=%r)', self, timeout) try: - data = self._queue.get(True, timeout) + reply_to, data = self._queue.get(True, timeout) except Queue.Empty: return IOLOG.debug('%r.on_receive() got %r', self, data) + + # Must occur off the broker thread. + if not isinstance(data, Dead): + try: + data = self._context.broker.unpickle(data) + except (TypeError, ValueError), ex: + raise StreamError('invalid message: %s', ex) + if not isinstance(data, (Dead, CallError)): - return data + return reply_to, data elif data == _DEAD: raise ChannelError('Channel is closed.') else: @@ -153,6 +161,7 @@ class Importer(object): 'econtext.compat', 'econtext.compat.pkgutil', 'econtext.master', + 'econtext.proxy', 'econtext.ssh', 'econtext.sudo', 'econtext.utils', @@ -168,6 +177,7 @@ class Importer(object): return None self.tls.running = True + fullname = fullname.rstrip('.') try: pkgname, _, _ = fullname.rpartition('.') LOG.debug('%r.find_module(%r)', self, fullname) @@ -195,8 +205,9 @@ class Importer(object): try: ret = self._cache[fullname] except KeyError: - ret = self._context.enqueue_await_reply(GET_MODULE, None, (fullname,)) - self._cache[fullname] = ret + self._cache[fullname] = ret = cPickle.loads( + self._context.enqueue_await_reply_raw(GET_MODULE, None, fullname) + ) if ret is None: raise ImportError('Master does not have %r' % (fullname,)) @@ -238,7 +249,8 @@ class LogHandler(logging.Handler): self.local.in_emit = True try: msg = self.format(rec) - self.context.enqueue(FORWARD_LOG, (rec.name, rec.levelno, msg)) + encoded = '%s\x00%s\x00%s' % (rec.name, rec.levelno, msg) + self.context.enqueue(FORWARD_LOG, encoded) finally: self.local.in_emit = False @@ -346,17 +358,6 @@ class Stream(BasicStream): self._rhmac = hmac.new(context.key, digestmod=sha) self._whmac = self._rhmac.copy() - _find_global = None - - def unpickle(self, data): - """Deserialize `data` into an object.""" - IOLOG.debug('%r.unpickle(%r)', self, data) - fp = cStringIO.StringIO(data) - unpickler = cPickle.Unpickler(fp) - if self._find_global: - unpickler.find_global = self._find_global - return unpickler.load() - def on_receive(self, broker): """Handle the next complete message on the stream. Raise :py:class:`StreamError` on failure.""" @@ -364,7 +365,9 @@ class Stream(BasicStream): try: buf = os.read(self.receive_side.fd, 4096) + IOLOG.debug('%r.on_receive() -> len %d', self, len(buf)) except OSError, e: + IOLOG.debug('%r.on_receive() -> OSError: %s', self, e) # When connected over a TTY (i.e. sudo), disconnection of the # remote end is signalled by EIO, rather than an empty read like # sockets or pipes. Ideally this will be replaced later by a @@ -382,17 +385,32 @@ class Stream(BasicStream): if not buf: return self.on_disconnect(broker) + HEADER_FMT = ( + '>' + '20s' # msg_mac + 'L' # handle + 'L' # reply_to + 'L' # msg_len + ) + HEADER_LEN = struct.calcsize(HEADER_FMT) + MAC_LEN = sha.digest_size + def _receive_one(self): - if len(self._input_buf) < 24: + if len(self._input_buf) < self.HEADER_LEN: return False - msg_mac = self._input_buf[:20] - msg_len = struct.unpack('>L', self._input_buf[20:24])[0] - if len(self._input_buf)-24 < msg_len: + msg_mac, handle, reply_to, msg_len = struct.unpack( + self.HEADER_FMT, + self._input_buf[:self.HEADER_LEN] + ) + + if (len(self._input_buf) - self.HEADER_LEN) < msg_len: IOLOG.debug('Input too short') return False - self._rhmac.update(self._input_buf[20:msg_len+24]) + self._rhmac.update(self._input_buf[ + self.MAC_LEN : (msg_len + self.HEADER_LEN) + ]) expected_mac = self._rhmac.digest() if msg_mac != expected_mac: raise StreamError('bad MAC: %r != got %r; %r', @@ -400,17 +418,13 @@ class Stream(BasicStream): expected_mac.encode('hex'), self._input_buf[24:msg_len+24]) - try: - handle, data = self.unpickle(self._input_buf[24:msg_len+24]) - except (TypeError, ValueError), ex: - raise StreamError('invalid message: %s', ex) - - self._input_buf = self._input_buf[msg_len+24:] - self._invoke(handle, data) + data = self._input_buf[self.HEADER_LEN:self.HEADER_LEN+msg_len] + self._input_buf = self._input_buf[self.HEADER_LEN+msg_len:] + self._invoke(handle, reply_to, data) return True - def _invoke(self, handle, data): - IOLOG.debug('%r._invoke(%r, %r)', self, handle, data) + def _invoke(self, handle, reply_to, data): + IOLOG.debug('%r._invoke(%r, %r, %r)', self, handle, reply_to, data) try: persist, fn = self._context._handle_map[handle] except KeyError: @@ -420,7 +434,7 @@ class Stream(BasicStream): del self._context._handle_map[handle] try: - fn(data) + fn(reply_to, data) except Exception: LOG.debug('%r._invoke(%r, %r): %r crashed', self, handle, data, fn) @@ -428,26 +442,31 @@ class Stream(BasicStream): """Transmit buffered messages.""" IOLOG.debug('%r.on_transmit()', self) written = os.write(self.transmit_side.fd, self._output_buf[:4096]) + IOLOG.debug('%r.on_transmit() -> len %d', self, written) self._output_buf = self._output_buf[written:] if not self._output_buf: broker.stop_transmit(self) - def _enqueue(self, handle, obj): - IOLOG.debug('%r._enqueue(%r, %r)', self, handle, obj) - try: - encoded = cPickle.dumps((handle, obj), protocol=2) - except cPickle.PicklingError, e: - encoded = cPickle.dumps((handle, CallError(e)), protocol=2) - - msg = struct.pack('>L', len(encoded)) + encoded + def _enqueue(self, handle, data, reply_to): + IOLOG.debug('%r._enqueue(%r, %r)', self, handle, data) + msg = struct.pack('>LLL', handle, reply_to, len(data)) + data self._whmac.update(msg) self._output_buf += self._whmac.digest() + msg self._context.broker.start_transmit(self) - def enqueue(self, handle, obj): + def enqueue_raw(self, handle, data, reply_to=0): + """Enqueue `data` to `handle`, and tell the broker we have output. May + be called from any thread.""" + self._context.broker.on_thread(self._enqueue, handle, data, reply_to) + + def enqueue(self, handle, obj, reply_to=0): """Enqueue `obj` to `handle`, and tell the broker we have output. May be called from any thread.""" - self._context.broker.on_thread(self._enqueue, handle, obj) + try: + encoded = cPickle.dumps(obj, protocol=2) + except cPickle.PicklingError, e: + encoded = cPickle.dumps(CallError(e), protocol=2) + self.enqueue_raw(handle, encoded, reply_to) def on_disconnect(self, broker): super(Stream, self).on_disconnect(broker) @@ -493,7 +512,7 @@ class Context(object): LOG.debug('%r.on_shutdown(%r)', self, broker) for handle, (persist, fn) in self._handle_map.iteritems(): LOG.debug('%r.on_disconnect(): killing %r: %r', self, handle, fn) - fn(_DEAD) + fn(0, _DEAD) def on_disconnect(self, broker): self.stream = None @@ -510,34 +529,49 @@ class Context(object): IOLOG.debug('%r.add_handle_cb(%r, %r, %r)', self, fn, handle, persist) self._handle_map[handle] = persist, fn - def enqueue(self, handle, obj): + def enqueue(self, handle, obj, reply_to=0): if self.stream: - self.stream.enqueue(handle, obj) + self.stream.enqueue(handle, obj, reply_to) - def enqueue_await_reply(self, handle, deadline, data): + def enqueue_await_reply_raw(self, handle, deadline, data): """Send `data` to `handle` and wait for a response with an optional - timeout. The message contains `(reply_to, data)`, where `reply_to` is - the handle on which this function expects its reply.""" + timeout.""" + if self.broker._thread == threading.currentThread(): # TODO + raise SystemError('Cannot making blocking call on broker thread') + reply_to = self.alloc_handle() LOG.debug('%r.enqueue_await_reply(%r, %r, %r) -> reply handle %d', self, handle, deadline, data, reply_to) queue = Queue.Queue() - self.add_handle_cb(queue.put, reply_to, persist=False) - self.stream.enqueue(handle, (reply_to,) + data) + receiver = lambda _reply_to, data: queue.put(data) + self.add_handle_cb(receiver, reply_to, persist=False) + self.stream.enqueue_raw(handle, data, reply_to) try: data = queue.get(True, deadline) except Queue.Empty: - self.broker.on_thread(self.stream.on_disconnect, self.broker) + # self.broker.on_thread(self.stream.on_disconnect, self.broker) raise TimeoutError('deadline exceeded.') if data == _DEAD: raise StreamError('lost connection during call.') - IOLOG.debug('%r._enqueue_await_reply(): got reply: %r', self, data) + IOLOG.debug('%r._enqueue_await_reply() -> %r', self, data) return data + def enqueue_await_reply(self, handle, deadline, obj): + """Like :py:meth:enqueue_await_reply_raw except that `data` is pickled + prior to sending, and the return value is unpickled on reception.""" + if self.broker._thread == threading.currentThread(): # TODO + raise SystemError('Cannot making blocking call on broker thread') + + encoded = cPickle.dumps(obj, protocol=2) + result = self.enqueue_await_reply_raw(handle, deadline, encoded) + decoded = self.broker.unpickle(result) + IOLOG.debug('%r.enqueue_await_reply() -> %r', self, decoded) + return decoded + def __repr__(self): bits = filter(None, (self.name, self.hostname, self.username)) return 'Context(%s)' % ', '.join(map(repr, bits)) @@ -648,6 +682,17 @@ class Broker(object): name='econtext-broker') self._thread.start() + _find_global = None + + def unpickle(self, data): + """Deserialize `data` into an object.""" + IOLOG.debug('%r.unpickle(%r)', self, data) + fp = cStringIO.StringIO(data) + unpickler = cPickle.Unpickler(fp) + if self._find_global: + unpickler.find_global = self._find_global + return unpickler.load() + def on_thread(self, func, *args, **kwargs): if threading.currentThread() == self._thread: func(*args, **kwargs) @@ -751,7 +796,7 @@ class Broker(object): LOG.error('%r: some streams did not close gracefully. ' 'The most likely cause for this is one or ' 'more child processes still connected to ' - 'ou stdout/stderr pipes.', self) + 'our stdout/stderr pipes.', self) for context in self._contexts.itervalues(): context.on_shutdown(self) @@ -857,7 +902,8 @@ class ExternalContext(object): def _dispatch_calls(self): for data in self.channel: LOG.debug('_dispatch_calls(%r)', data) - reply_to, with_context, modname, klass, func, args, kwargs = data + reply_to, data = data + with_context, modname, klass, func, args, kwargs = data if with_context: args = (self,) + args diff --git a/econtext/master.py b/econtext/master.py index f9942df6..934f4919 100644 --- a/econtext/master.py +++ b/econtext/master.py @@ -35,6 +35,11 @@ DOCSTRING_RE = re.compile(r'""".+?"""', re.M | re.S) COMMENT_RE = re.compile(r'^[ ]*#[^\n]*$', re.M) IOLOG_RE = re.compile(r'^[ ]*IOLOG.debug\(.+?\)$', re.M) +PERMITTED_CLASSES = set([ + ('econtext.core', 'CallError'), + ('econtext.core', 'Dead'), +]) + def minimize_source(source): """Remove comments and docstrings from Python `source`, preserving line @@ -111,11 +116,11 @@ class LogForwarder(object): name = '%s.%s' % (RLOG.name, self._context.name) self._log = logging.getLogger(name) - def forward_log(self, data): + def forward_log(self, reply_to, data): if data == econtext.core._DEAD: return - name, level, s = data + name, level, s = data.split('\x00', 2) self._log.log(level, '%s: %s', name, s) @@ -184,12 +189,11 @@ class ModuleResponder(object): _get_module_via_sys_modules, _get_module_via_parent_enumeration] - def get_module(self, data): - LOG.debug('%r.get_module(%r)', self, data) - if data == econtext.core._DEAD: + def get_module(self, reply_to, fullname): + LOG.debug('%r.get_module(%r, %r)', self, reply_to, fullname) + if fullname == econtext.core._DEAD: return - reply_to, fullname = data try: for method in self.get_module_methods: tup = method(self, fullname) @@ -225,31 +229,11 @@ class Stream(econtext.core.Stream): #: The path to the remote Python interpreter. python_path = sys.executable - def __init__(self, context): - super(Stream, self).__init__(context) - self._permitted_classes = set([ - ('econtext.core', 'CallError'), - ('econtext.core', 'Dead'), - ]) - def on_shutdown(self, broker): """Request the slave gracefully shut itself down.""" LOG.debug('%r closing CALL_FUNCTION channel', self) self.enqueue(econtext.core.CALL_FUNCTION, econtext.core._DEAD) - def _find_global(self, module_name, class_name): - """Return the class implementing `module_name.class_name` or raise - `StreamError` if the module is not whitelisted.""" - if (module_name, class_name) not in self._permitted_classes: - raise econtext.core.StreamError( - '%r attempted to unpickle %r in module %r', - self._context, class_name, module_name) - return getattr(sys.modules[module_name], class_name) - - def allow_class(self, module_name, class_name): - """Add `module_name` to the list of permitted modules.""" - self._permitted_modules.add((module_name, class_name)) - # base64'd and passed to 'python -c'. It forks, dups 0->100, creates a # pipe, then execs a new interpreter with a custom argv. 'CONTEXT_NAME' is # replaced with the context name. Optimized for size. @@ -320,6 +304,15 @@ class Stream(econtext.core.Stream): class Broker(econtext.core.Broker): shutdown_timeout = 5.0 + def _find_global(self, module_name, class_name): + """Return the class implementing `module_name.class_name` or raise + `StreamError` if the module is not whitelisted.""" + if (module_name, class_name) not in PERMITTED_CLASSES: + raise econtext.core.StreamError( + '%r attempted to unpickle %r in module %r', + self._context, class_name, module_name) + return getattr(sys.modules[module_name], class_name) + def __enter__(self): return self