diff --git a/mitogen/core.py b/mitogen/core.py index dba5648b..9a9cc0a1 100644 --- a/mitogen/core.py +++ b/mitogen/core.py @@ -162,11 +162,27 @@ class Message(object): reply_to = None data = None + router = None + def __init__(self, **kwargs): self.src_id = mitogen.context_id vars(self).update(kwargs) - _find_global = None + def _unpickle_context(self, context_id, name): + return _unpickle_context(self.router, context_id, name) + + def _find_global(self, module, func): + """Return the class implementing `module_name.class_name` or raise + `StreamError` if the module is not whitelisted.""" + if module == __name__: + if func == '_unpickle_call_error': + return _unpickle_call_error + elif func == '_unpickle_dead': + return _unpickle_dead + elif func == '_unpickle_context': + return self._unpickle_context + + raise StreamError('cannot unpickle %r/%r', module, func) @classmethod def pickled(cls, obj, **kwargs): @@ -182,8 +198,7 @@ class Message(object): IOLOG.debug('%r.unpickle()', self) fp = cStringIO.StringIO(self.data) unpickler = cPickle.Unpickler(fp) - if self._find_global: - unpickler.find_global = self._find_global + unpickler.find_global = self._find_global try: return unpickler.load() except (TypeError, ValueError), ex: @@ -521,7 +536,6 @@ class Stream(BasicStream): """ _input_buf = '' _output_buf = '' - message_class = Message def __init__(self, router, remote_id, key, **kwargs): self._router = router @@ -556,7 +570,10 @@ class Stream(BasicStream): if len(self._input_buf) < self.HEADER_LEN: return False - msg = self.message_class() + msg = Message() + # To support unpickling Contexts. + msg.router = self._router + (msg.dst_id, msg.src_id, msg.handle, msg.reply_to, msg_len) = struct.unpack( self.HEADER_FMT, @@ -628,6 +645,9 @@ class Context(object): self.name = name self.key = key or ('%016x' % random.getrandbits(128)) + def __reduce__(self): + return _unpickle_context, (self.context_id, self.name) + def on_disconnect(self, broker): LOG.debug('Parent stream is gone, dying.') fire(self, 'disconnect') @@ -672,6 +692,13 @@ class Context(object): return 'Context(%s, %r)' % (self.context_id, self.name) +def _unpickle_context(router, context_id, name): + assert isinstance(router, Router) + assert isinstance(context_id, (int, long)) and context_id > 0 + assert type(name) is str and len(name) < 100 + return Context(router, context_id, name) + + class Waker(BasicStream): """ :py:class:`BasicStream` subclass implementing the diff --git a/mitogen/master.py b/mitogen/master.py index 5d8e18d9..f9e3341d 100644 --- a/mitogen/master.py +++ b/mitogen/master.py @@ -38,11 +38,6 @@ DOCSTRING_RE = re.compile(r'""".+?"""', re.M | re.S) COMMENT_RE = re.compile(r'^[ ]*#[^\n]*$', re.M) IOLOG_RE = re.compile(r'^[ ]*IOLOG.debug\(.+?\)$', re.M) -PERMITTED_CLASSES = set([ - ('mitogen.core', '_unpickle_call_error'), - ('mitogen.core', '_unpickle_dead'), -]) - def minimize_source(source): """Remove comments and docstrings from Python `source`, preserving line @@ -316,27 +311,10 @@ class ModuleForwarder(object): ) -class Message(mitogen.core.Message): - """ - Message subclass that controls unpickling. - """ - def _find_global(self, module_name, class_name): - """Return the class implementing `module_name.class_name` or raise - `StreamError` if the module is not whitelisted.""" - if (module_name, class_name) not in PERMITTED_CLASSES: - raise mitogen.core.StreamError( - 'attempted to unpickle %r in module %r', - class_name, module_name - ) - return getattr(sys.modules[module_name], class_name) - - class Stream(mitogen.core.Stream): """ Base for streams capable of starting new slaves. """ - message_class = Message - #: The path to the remote Python interpreter. python_path = 'python2.7' diff --git a/tests/call_function_test.py b/tests/call_function_test.py index 7c63dd70..9cbfe23e 100644 --- a/tests/call_function_test.py +++ b/tests/call_function_test.py @@ -20,6 +20,10 @@ def func_returns_dead(): return mitogen.core._DEAD +def func_accepts_returns_context(context): + return context + + class CallFunctionTest(unittest.TestCase): @classmethod def setUpClass(cls): @@ -59,8 +63,7 @@ class CallFunctionTest(unittest.TestCase): pass assert e[0] == ( - "attempted to unpickle 'CrazyType' " - "in module 'call_function_test'" + "attempted unpickle from 'call_function_test'" ) def test_returns_dead(self): @@ -75,3 +78,9 @@ class CallFunctionTest(unittest.TestCase): def test_aborted_on_local_broker_shutdown(self): assert 0, 'todo' + + def test_accepts_returns_context(self): + context = self.local.call(func_accepts_returns_context, self.local) + assert context is not self.local + assert context.context_id == self.local.context_id + assert context.name == self.local.name