Fix module loader deadlock

Stop using cPickle on the broker thread where it is not known whether
the pickle data would cause the import machinery to be invoked, which
currently relies on blocking calls. Huge mess but it works.

This is due to:

        context.call(some.module.func, another.module.func)

We stringify ("some.module", "func"), but the reference to
another.module.func is passed into the pickle machinery, and there's no
way to generically stringify all function references in user data for
reification on the main thread, without doing something like this
instead.
pull/35/head
David Wilson 7 years ago
parent 48b4ac17b7
commit e374f85888

@ -231,19 +231,19 @@ Stream Protocol
Once connected, a basic framing protocol is used to communicate between
master and slave:
+------------+-------+-----------------------------------------------------+
| Field | Size | Description |
+============+=======+=====================================================+
| ``hmac`` | 20 | SHA-1 MAC over (``length || data``) |
+------------+-------+-----------------------------------------------------+
| ``length`` | 4 | Message length |
+------------+-------+-----------------------------------------------------+
| ``data`` | n/a | Pickled message data. |
+------------+-------+-----------------------------------------------------+
The ``data`` component always consists of a 2-tuple, `(handle, data)`, where
``handle`` is an integer describing the message target and ``data`` is the
value to be delivered to the target.
+----------------+------+------------------------------------------------------+
| Field | Size | Description |
+================+======+======================================================+
| ``hmac`` | 20 | SHA-1 over (``handle || reply_to || length || data``) |
+----------------+------+------------------------------------------------------+
| ``handle`` | 4 | Integer target handle in recipient. |
+----------------+------+------------------------------------------------------+
| ``reply_to`` | 4 | Integer response target ID. |
+----------------+------+------------------------------------------------------+
| ``length`` | 4 | Message length |
+----------------+------+------------------------------------------------------+
| ``data`` | n/a | Pickled message data. |
+----------------+------+------------------------------------------------------+
Masters listen on the following handles:
@ -304,6 +304,14 @@ cannot be used securely, however few of those accounts appear to be expert, and
none mention any additional attacks that would not be prevented by using a
restrictive class whitelist.
.. note::
Since unpickling may trigger module loads, it is not possible to
deserialize data on the broker thread, as this will result in recursion
leading to a deadlock. Therefore any internal services (module loader,
logging forwarder, etc.) must rely on simple string formats, or only
perform serialization from within the broker thread.
Use of HMAC
###########

@ -96,10 +96,10 @@ class Channel(object):
self._queue = Queue.Queue()
self._context.add_handle_cb(self._receive, handle)
def _receive(self, data):
def _receive(self, reply_to, data):
"""Callback from the Stream; appends data to the internal queue."""
IOLOG.debug('%r._receive(%r)', self, data)
self._queue.put(data)
self._queue.put((reply_to, data))
def close(self):
"""Indicate this channel is closed to the remote side."""
@ -115,13 +115,21 @@ class Channel(object):
"""Receive an object, or ``None`` if `timeout` is reached."""
IOLOG.debug('%r.on_receive(timeout=%r)', self, timeout)
try:
data = self._queue.get(True, timeout)
reply_to, data = self._queue.get(True, timeout)
except Queue.Empty:
return
IOLOG.debug('%r.on_receive() got %r', self, data)
# Must occur off the broker thread.
if not isinstance(data, Dead):
try:
data = self._context.broker.unpickle(data)
except (TypeError, ValueError), ex:
raise StreamError('invalid message: %s', ex)
if not isinstance(data, (Dead, CallError)):
return data
return reply_to, data
elif data == _DEAD:
raise ChannelError('Channel is closed.')
else:
@ -153,6 +161,7 @@ class Importer(object):
'econtext.compat',
'econtext.compat.pkgutil',
'econtext.master',
'econtext.proxy',
'econtext.ssh',
'econtext.sudo',
'econtext.utils',
@ -168,6 +177,7 @@ class Importer(object):
return None
self.tls.running = True
fullname = fullname.rstrip('.')
try:
pkgname, _, _ = fullname.rpartition('.')
LOG.debug('%r.find_module(%r)', self, fullname)
@ -195,8 +205,9 @@ class Importer(object):
try:
ret = self._cache[fullname]
except KeyError:
ret = self._context.enqueue_await_reply(GET_MODULE, None, (fullname,))
self._cache[fullname] = ret
self._cache[fullname] = ret = cPickle.loads(
self._context.enqueue_await_reply_raw(GET_MODULE, None, fullname)
)
if ret is None:
raise ImportError('Master does not have %r' % (fullname,))
@ -238,7 +249,8 @@ class LogHandler(logging.Handler):
self.local.in_emit = True
try:
msg = self.format(rec)
self.context.enqueue(FORWARD_LOG, (rec.name, rec.levelno, msg))
encoded = '%s\x00%s\x00%s' % (rec.name, rec.levelno, msg)
self.context.enqueue(FORWARD_LOG, encoded)
finally:
self.local.in_emit = False
@ -346,17 +358,6 @@ class Stream(BasicStream):
self._rhmac = hmac.new(context.key, digestmod=sha)
self._whmac = self._rhmac.copy()
_find_global = None
def unpickle(self, data):
"""Deserialize `data` into an object."""
IOLOG.debug('%r.unpickle(%r)', self, data)
fp = cStringIO.StringIO(data)
unpickler = cPickle.Unpickler(fp)
if self._find_global:
unpickler.find_global = self._find_global
return unpickler.load()
def on_receive(self, broker):
"""Handle the next complete message on the stream. Raise
:py:class:`StreamError` on failure."""
@ -364,7 +365,9 @@ class Stream(BasicStream):
try:
buf = os.read(self.receive_side.fd, 4096)
IOLOG.debug('%r.on_receive() -> len %d', self, len(buf))
except OSError, e:
IOLOG.debug('%r.on_receive() -> OSError: %s', self, e)
# When connected over a TTY (i.e. sudo), disconnection of the
# remote end is signalled by EIO, rather than an empty read like
# sockets or pipes. Ideally this will be replaced later by a
@ -382,17 +385,32 @@ class Stream(BasicStream):
if not buf:
return self.on_disconnect(broker)
HEADER_FMT = (
'>'
'20s' # msg_mac
'L' # handle
'L' # reply_to
'L' # msg_len
)
HEADER_LEN = struct.calcsize(HEADER_FMT)
MAC_LEN = sha.digest_size
def _receive_one(self):
if len(self._input_buf) < 24:
if len(self._input_buf) < self.HEADER_LEN:
return False
msg_mac = self._input_buf[:20]
msg_len = struct.unpack('>L', self._input_buf[20:24])[0]
if len(self._input_buf)-24 < msg_len:
msg_mac, handle, reply_to, msg_len = struct.unpack(
self.HEADER_FMT,
self._input_buf[:self.HEADER_LEN]
)
if (len(self._input_buf) - self.HEADER_LEN) < msg_len:
IOLOG.debug('Input too short')
return False
self._rhmac.update(self._input_buf[20:msg_len+24])
self._rhmac.update(self._input_buf[
self.MAC_LEN : (msg_len + self.HEADER_LEN)
])
expected_mac = self._rhmac.digest()
if msg_mac != expected_mac:
raise StreamError('bad MAC: %r != got %r; %r',
@ -400,17 +418,13 @@ class Stream(BasicStream):
expected_mac.encode('hex'),
self._input_buf[24:msg_len+24])
try:
handle, data = self.unpickle(self._input_buf[24:msg_len+24])
except (TypeError, ValueError), ex:
raise StreamError('invalid message: %s', ex)
self._input_buf = self._input_buf[msg_len+24:]
self._invoke(handle, data)
data = self._input_buf[self.HEADER_LEN:self.HEADER_LEN+msg_len]
self._input_buf = self._input_buf[self.HEADER_LEN+msg_len:]
self._invoke(handle, reply_to, data)
return True
def _invoke(self, handle, data):
IOLOG.debug('%r._invoke(%r, %r)', self, handle, data)
def _invoke(self, handle, reply_to, data):
IOLOG.debug('%r._invoke(%r, %r, %r)', self, handle, reply_to, data)
try:
persist, fn = self._context._handle_map[handle]
except KeyError:
@ -420,7 +434,7 @@ class Stream(BasicStream):
del self._context._handle_map[handle]
try:
fn(data)
fn(reply_to, data)
except Exception:
LOG.debug('%r._invoke(%r, %r): %r crashed', self, handle, data, fn)
@ -428,26 +442,31 @@ class Stream(BasicStream):
"""Transmit buffered messages."""
IOLOG.debug('%r.on_transmit()', self)
written = os.write(self.transmit_side.fd, self._output_buf[:4096])
IOLOG.debug('%r.on_transmit() -> len %d', self, written)
self._output_buf = self._output_buf[written:]
if not self._output_buf:
broker.stop_transmit(self)
def _enqueue(self, handle, obj):
IOLOG.debug('%r._enqueue(%r, %r)', self, handle, obj)
try:
encoded = cPickle.dumps((handle, obj), protocol=2)
except cPickle.PicklingError, e:
encoded = cPickle.dumps((handle, CallError(e)), protocol=2)
msg = struct.pack('>L', len(encoded)) + encoded
def _enqueue(self, handle, data, reply_to):
IOLOG.debug('%r._enqueue(%r, %r)', self, handle, data)
msg = struct.pack('>LLL', handle, reply_to, len(data)) + data
self._whmac.update(msg)
self._output_buf += self._whmac.digest() + msg
self._context.broker.start_transmit(self)
def enqueue(self, handle, obj):
def enqueue_raw(self, handle, data, reply_to=0):
"""Enqueue `data` to `handle`, and tell the broker we have output. May
be called from any thread."""
self._context.broker.on_thread(self._enqueue, handle, data, reply_to)
def enqueue(self, handle, obj, reply_to=0):
"""Enqueue `obj` to `handle`, and tell the broker we have output. May
be called from any thread."""
self._context.broker.on_thread(self._enqueue, handle, obj)
try:
encoded = cPickle.dumps(obj, protocol=2)
except cPickle.PicklingError, e:
encoded = cPickle.dumps(CallError(e), protocol=2)
self.enqueue_raw(handle, encoded, reply_to)
def on_disconnect(self, broker):
super(Stream, self).on_disconnect(broker)
@ -493,7 +512,7 @@ class Context(object):
LOG.debug('%r.on_shutdown(%r)', self, broker)
for handle, (persist, fn) in self._handle_map.iteritems():
LOG.debug('%r.on_disconnect(): killing %r: %r', self, handle, fn)
fn(_DEAD)
fn(0, _DEAD)
def on_disconnect(self, broker):
self.stream = None
@ -510,34 +529,49 @@ class Context(object):
IOLOG.debug('%r.add_handle_cb(%r, %r, %r)', self, fn, handle, persist)
self._handle_map[handle] = persist, fn
def enqueue(self, handle, obj):
def enqueue(self, handle, obj, reply_to=0):
if self.stream:
self.stream.enqueue(handle, obj)
self.stream.enqueue(handle, obj, reply_to)
def enqueue_await_reply(self, handle, deadline, data):
def enqueue_await_reply_raw(self, handle, deadline, data):
"""Send `data` to `handle` and wait for a response with an optional
timeout. The message contains `(reply_to, data)`, where `reply_to` is
the handle on which this function expects its reply."""
timeout."""
if self.broker._thread == threading.currentThread(): # TODO
raise SystemError('Cannot making blocking call on broker thread')
reply_to = self.alloc_handle()
LOG.debug('%r.enqueue_await_reply(%r, %r, %r) -> reply handle %d',
self, handle, deadline, data, reply_to)
queue = Queue.Queue()
self.add_handle_cb(queue.put, reply_to, persist=False)
self.stream.enqueue(handle, (reply_to,) + data)
receiver = lambda _reply_to, data: queue.put(data)
self.add_handle_cb(receiver, reply_to, persist=False)
self.stream.enqueue_raw(handle, data, reply_to)
try:
data = queue.get(True, deadline)
except Queue.Empty:
self.broker.on_thread(self.stream.on_disconnect, self.broker)
# self.broker.on_thread(self.stream.on_disconnect, self.broker)
raise TimeoutError('deadline exceeded.')
if data == _DEAD:
raise StreamError('lost connection during call.')
IOLOG.debug('%r._enqueue_await_reply(): got reply: %r', self, data)
IOLOG.debug('%r._enqueue_await_reply() -> %r', self, data)
return data
def enqueue_await_reply(self, handle, deadline, obj):
"""Like :py:meth:enqueue_await_reply_raw except that `data` is pickled
prior to sending, and the return value is unpickled on reception."""
if self.broker._thread == threading.currentThread(): # TODO
raise SystemError('Cannot making blocking call on broker thread')
encoded = cPickle.dumps(obj, protocol=2)
result = self.enqueue_await_reply_raw(handle, deadline, encoded)
decoded = self.broker.unpickle(result)
IOLOG.debug('%r.enqueue_await_reply() -> %r', self, decoded)
return decoded
def __repr__(self):
bits = filter(None, (self.name, self.hostname, self.username))
return 'Context(%s)' % ', '.join(map(repr, bits))
@ -648,6 +682,17 @@ class Broker(object):
name='econtext-broker')
self._thread.start()
_find_global = None
def unpickle(self, data):
"""Deserialize `data` into an object."""
IOLOG.debug('%r.unpickle(%r)', self, data)
fp = cStringIO.StringIO(data)
unpickler = cPickle.Unpickler(fp)
if self._find_global:
unpickler.find_global = self._find_global
return unpickler.load()
def on_thread(self, func, *args, **kwargs):
if threading.currentThread() == self._thread:
func(*args, **kwargs)
@ -751,7 +796,7 @@ class Broker(object):
LOG.error('%r: some streams did not close gracefully. '
'The most likely cause for this is one or '
'more child processes still connected to '
'ou stdout/stderr pipes.', self)
'our stdout/stderr pipes.', self)
for context in self._contexts.itervalues():
context.on_shutdown(self)
@ -857,7 +902,8 @@ class ExternalContext(object):
def _dispatch_calls(self):
for data in self.channel:
LOG.debug('_dispatch_calls(%r)', data)
reply_to, with_context, modname, klass, func, args, kwargs = data
reply_to, data = data
with_context, modname, klass, func, args, kwargs = data
if with_context:
args = (self,) + args

@ -35,6 +35,11 @@ DOCSTRING_RE = re.compile(r'""".+?"""', re.M | re.S)
COMMENT_RE = re.compile(r'^[ ]*#[^\n]*$', re.M)
IOLOG_RE = re.compile(r'^[ ]*IOLOG.debug\(.+?\)$', re.M)
PERMITTED_CLASSES = set([
('econtext.core', 'CallError'),
('econtext.core', 'Dead'),
])
def minimize_source(source):
"""Remove comments and docstrings from Python `source`, preserving line
@ -111,11 +116,11 @@ class LogForwarder(object):
name = '%s.%s' % (RLOG.name, self._context.name)
self._log = logging.getLogger(name)
def forward_log(self, data):
def forward_log(self, reply_to, data):
if data == econtext.core._DEAD:
return
name, level, s = data
name, level, s = data.split('\x00', 2)
self._log.log(level, '%s: %s', name, s)
@ -184,12 +189,11 @@ class ModuleResponder(object):
_get_module_via_sys_modules,
_get_module_via_parent_enumeration]
def get_module(self, data):
LOG.debug('%r.get_module(%r)', self, data)
if data == econtext.core._DEAD:
def get_module(self, reply_to, fullname):
LOG.debug('%r.get_module(%r, %r)', self, reply_to, fullname)
if fullname == econtext.core._DEAD:
return
reply_to, fullname = data
try:
for method in self.get_module_methods:
tup = method(self, fullname)
@ -225,31 +229,11 @@ class Stream(econtext.core.Stream):
#: The path to the remote Python interpreter.
python_path = sys.executable
def __init__(self, context):
super(Stream, self).__init__(context)
self._permitted_classes = set([
('econtext.core', 'CallError'),
('econtext.core', 'Dead'),
])
def on_shutdown(self, broker):
"""Request the slave gracefully shut itself down."""
LOG.debug('%r closing CALL_FUNCTION channel', self)
self.enqueue(econtext.core.CALL_FUNCTION, econtext.core._DEAD)
def _find_global(self, module_name, class_name):
"""Return the class implementing `module_name.class_name` or raise
`StreamError` if the module is not whitelisted."""
if (module_name, class_name) not in self._permitted_classes:
raise econtext.core.StreamError(
'%r attempted to unpickle %r in module %r',
self._context, class_name, module_name)
return getattr(sys.modules[module_name], class_name)
def allow_class(self, module_name, class_name):
"""Add `module_name` to the list of permitted modules."""
self._permitted_modules.add((module_name, class_name))
# base64'd and passed to 'python -c'. It forks, dups 0->100, creates a
# pipe, then execs a new interpreter with a custom argv. 'CONTEXT_NAME' is
# replaced with the context name. Optimized for size.
@ -320,6 +304,15 @@ class Stream(econtext.core.Stream):
class Broker(econtext.core.Broker):
shutdown_timeout = 5.0
def _find_global(self, module_name, class_name):
"""Return the class implementing `module_name.class_name` or raise
`StreamError` if the module is not whitelisted."""
if (module_name, class_name) not in PERMITTED_CLASSES:
raise econtext.core.StreamError(
'%r attempted to unpickle %r in module %r',
self._context, class_name, module_name)
return getattr(sys.modules[module_name], class_name)
def __enter__(self):
return self

Loading…
Cancel
Save