pickle: support Context(), use same unpickler everywhere.

* Support passing Context() objects in function calls and return values.
  Now the fakessh demo from the documentation index would work
  correctly.

* Since slaves can communicate with each other now, they should also use
  the same approach to unpickling as the master already used. Collapse
  away all the unpickle extension crap and hard-wire just the 3 types
  that support unpickling.
pull/35/head
David Wilson 7 years ago
parent ed90f3fa90
commit 066b39d570

@ -162,11 +162,27 @@ class Message(object):
reply_to = None reply_to = None
data = None data = None
router = None
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.src_id = mitogen.context_id self.src_id = mitogen.context_id
vars(self).update(kwargs) vars(self).update(kwargs)
_find_global = None def _unpickle_context(self, context_id, name):
return _unpickle_context(self.router, context_id, name)
def _find_global(self, module, func):
"""Return the class implementing `module_name.class_name` or raise
`StreamError` if the module is not whitelisted."""
if module == __name__:
if func == '_unpickle_call_error':
return _unpickle_call_error
elif func == '_unpickle_dead':
return _unpickle_dead
elif func == '_unpickle_context':
return self._unpickle_context
raise StreamError('cannot unpickle %r/%r', module, func)
@classmethod @classmethod
def pickled(cls, obj, **kwargs): def pickled(cls, obj, **kwargs):
@ -182,8 +198,7 @@ class Message(object):
IOLOG.debug('%r.unpickle()', self) IOLOG.debug('%r.unpickle()', self)
fp = cStringIO.StringIO(self.data) fp = cStringIO.StringIO(self.data)
unpickler = cPickle.Unpickler(fp) unpickler = cPickle.Unpickler(fp)
if self._find_global: unpickler.find_global = self._find_global
unpickler.find_global = self._find_global
try: try:
return unpickler.load() return unpickler.load()
except (TypeError, ValueError), ex: except (TypeError, ValueError), ex:
@ -521,7 +536,6 @@ class Stream(BasicStream):
""" """
_input_buf = '' _input_buf = ''
_output_buf = '' _output_buf = ''
message_class = Message
def __init__(self, router, remote_id, key, **kwargs): def __init__(self, router, remote_id, key, **kwargs):
self._router = router self._router = router
@ -556,7 +570,10 @@ class Stream(BasicStream):
if len(self._input_buf) < self.HEADER_LEN: if len(self._input_buf) < self.HEADER_LEN:
return False return False
msg = self.message_class() msg = Message()
# To support unpickling Contexts.
msg.router = self._router
(msg.dst_id, msg.src_id, (msg.dst_id, msg.src_id,
msg.handle, msg.reply_to, msg_len) = struct.unpack( msg.handle, msg.reply_to, msg_len) = struct.unpack(
self.HEADER_FMT, self.HEADER_FMT,
@ -628,6 +645,9 @@ class Context(object):
self.name = name self.name = name
self.key = key or ('%016x' % random.getrandbits(128)) self.key = key or ('%016x' % random.getrandbits(128))
def __reduce__(self):
return _unpickle_context, (self.context_id, self.name)
def on_disconnect(self, broker): def on_disconnect(self, broker):
LOG.debug('Parent stream is gone, dying.') LOG.debug('Parent stream is gone, dying.')
fire(self, 'disconnect') fire(self, 'disconnect')
@ -672,6 +692,13 @@ class Context(object):
return 'Context(%s, %r)' % (self.context_id, self.name) return 'Context(%s, %r)' % (self.context_id, self.name)
def _unpickle_context(router, context_id, name):
assert isinstance(router, Router)
assert isinstance(context_id, (int, long)) and context_id > 0
assert type(name) is str and len(name) < 100
return Context(router, context_id, name)
class Waker(BasicStream): class Waker(BasicStream):
""" """
:py:class:`BasicStream` subclass implementing the :py:class:`BasicStream` subclass implementing the

@ -38,11 +38,6 @@ DOCSTRING_RE = re.compile(r'""".+?"""', re.M | re.S)
COMMENT_RE = re.compile(r'^[ ]*#[^\n]*$', re.M) COMMENT_RE = re.compile(r'^[ ]*#[^\n]*$', re.M)
IOLOG_RE = re.compile(r'^[ ]*IOLOG.debug\(.+?\)$', re.M) IOLOG_RE = re.compile(r'^[ ]*IOLOG.debug\(.+?\)$', re.M)
PERMITTED_CLASSES = set([
('mitogen.core', '_unpickle_call_error'),
('mitogen.core', '_unpickle_dead'),
])
def minimize_source(source): def minimize_source(source):
"""Remove comments and docstrings from Python `source`, preserving line """Remove comments and docstrings from Python `source`, preserving line
@ -316,27 +311,10 @@ class ModuleForwarder(object):
) )
class Message(mitogen.core.Message):
"""
Message subclass that controls unpickling.
"""
def _find_global(self, module_name, class_name):
"""Return the class implementing `module_name.class_name` or raise
`StreamError` if the module is not whitelisted."""
if (module_name, class_name) not in PERMITTED_CLASSES:
raise mitogen.core.StreamError(
'attempted to unpickle %r in module %r',
class_name, module_name
)
return getattr(sys.modules[module_name], class_name)
class Stream(mitogen.core.Stream): class Stream(mitogen.core.Stream):
""" """
Base for streams capable of starting new slaves. Base for streams capable of starting new slaves.
""" """
message_class = Message
#: The path to the remote Python interpreter. #: The path to the remote Python interpreter.
python_path = 'python2.7' python_path = 'python2.7'

@ -20,6 +20,10 @@ def func_returns_dead():
return mitogen.core._DEAD return mitogen.core._DEAD
def func_accepts_returns_context(context):
return context
class CallFunctionTest(unittest.TestCase): class CallFunctionTest(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
@ -59,8 +63,7 @@ class CallFunctionTest(unittest.TestCase):
pass pass
assert e[0] == ( assert e[0] == (
"attempted to unpickle 'CrazyType' " "attempted unpickle from 'call_function_test'"
"in module 'call_function_test'"
) )
def test_returns_dead(self): def test_returns_dead(self):
@ -75,3 +78,9 @@ class CallFunctionTest(unittest.TestCase):
def test_aborted_on_local_broker_shutdown(self): def test_aborted_on_local_broker_shutdown(self):
assert 0, 'todo' assert 0, 'todo'
def test_accepts_returns_context(self):
context = self.local.call(func_accepts_returns_context, self.local)
assert context is not self.local
assert context.context_id == self.local.context_id
assert context.name == self.local.name

Loading…
Cancel
Save