pickle: support Context(), use same unpickler everywhere.

* Support passing Context() objects in function calls and return values.
  Now the fakessh demo from the documentation index would work
  correctly.

* Since slaves can communicate with each other now, they should also use
  the same approach to unpickling as the master already used. Collapse
  away all the unpickle extension crap and hard-wire just the 3 types
  that support unpickling.
wip-fakessh-exit-status
David Wilson 7 years ago
parent 11acc031a9
commit be9e55fe8c

@ -162,11 +162,27 @@ class Message(object):
reply_to = None
data = None
router = None
def __init__(self, **kwargs):
self.src_id = mitogen.context_id
vars(self).update(kwargs)
_find_global = None
def _unpickle_context(self, context_id, name):
return _unpickle_context(self.router, context_id, name)
def _find_global(self, module, func):
"""Return the class implementing `module_name.class_name` or raise
`StreamError` if the module is not whitelisted."""
if module == __name__:
if func == '_unpickle_call_error':
return _unpickle_call_error
elif func == '_unpickle_dead':
return _unpickle_dead
elif func == '_unpickle_context':
return self._unpickle_context
raise StreamError('cannot unpickle %r/%r', module, func)
@classmethod
def pickled(cls, obj, **kwargs):
@ -182,8 +198,7 @@ class Message(object):
IOLOG.debug('%r.unpickle()', self)
fp = cStringIO.StringIO(self.data)
unpickler = cPickle.Unpickler(fp)
if self._find_global:
unpickler.find_global = self._find_global
unpickler.find_global = self._find_global
try:
return unpickler.load()
except (TypeError, ValueError), ex:
@ -521,7 +536,6 @@ class Stream(BasicStream):
"""
_input_buf = ''
_output_buf = ''
message_class = Message
def __init__(self, router, remote_id, key, **kwargs):
self._router = router
@ -556,7 +570,10 @@ class Stream(BasicStream):
if len(self._input_buf) < self.HEADER_LEN:
return False
msg = self.message_class()
msg = Message()
# To support unpickling Contexts.
msg.router = self._router
(msg.dst_id, msg.src_id,
msg.handle, msg.reply_to, msg_len) = struct.unpack(
self.HEADER_FMT,
@ -628,6 +645,9 @@ class Context(object):
self.name = name
self.key = key or ('%016x' % random.getrandbits(128))
def __reduce__(self):
return _unpickle_context, (self.context_id, self.name)
def on_disconnect(self, broker):
LOG.debug('Parent stream is gone, dying.')
fire(self, 'disconnect')
@ -672,6 +692,13 @@ class Context(object):
return 'Context(%s, %r)' % (self.context_id, self.name)
def _unpickle_context(router, context_id, name):
assert isinstance(router, Router)
assert isinstance(context_id, (int, long)) and context_id > 0
assert type(name) is str and len(name) < 100
return Context(router, context_id, name)
class Waker(BasicStream):
"""
:py:class:`BasicStream` subclass implementing the

@ -38,11 +38,6 @@ DOCSTRING_RE = re.compile(r'""".+?"""', re.M | re.S)
COMMENT_RE = re.compile(r'^[ ]*#[^\n]*$', re.M)
IOLOG_RE = re.compile(r'^[ ]*IOLOG.debug\(.+?\)$', re.M)
PERMITTED_CLASSES = set([
('mitogen.core', '_unpickle_call_error'),
('mitogen.core', '_unpickle_dead'),
])
def minimize_source(source):
"""Remove comments and docstrings from Python `source`, preserving line
@ -316,27 +311,10 @@ class ModuleForwarder(object):
)
class Message(mitogen.core.Message):
"""
Message subclass that controls unpickling.
"""
def _find_global(self, module_name, class_name):
"""Return the class implementing `module_name.class_name` or raise
`StreamError` if the module is not whitelisted."""
if (module_name, class_name) not in PERMITTED_CLASSES:
raise mitogen.core.StreamError(
'attempted to unpickle %r in module %r',
class_name, module_name
)
return getattr(sys.modules[module_name], class_name)
class Stream(mitogen.core.Stream):
"""
Base for streams capable of starting new slaves.
"""
message_class = Message
#: The path to the remote Python interpreter.
python_path = 'python2.7'

@ -20,6 +20,10 @@ def func_returns_dead():
return mitogen.core._DEAD
def func_accepts_returns_context(context):
return context
class CallFunctionTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
@ -59,8 +63,7 @@ class CallFunctionTest(unittest.TestCase):
pass
assert e[0] == (
"attempted to unpickle 'CrazyType' "
"in module 'call_function_test'"
"attempted unpickle from 'call_function_test'"
)
def test_returns_dead(self):
@ -75,3 +78,9 @@ class CallFunctionTest(unittest.TestCase):
def test_aborted_on_local_broker_shutdown(self):
assert 0, 'todo'
def test_accepts_returns_context(self):
context = self.local.call(func_accepts_returns_context, self.local)
assert context is not self.local
assert context.context_id == self.local.context_id
assert context.name == self.local.name

Loading…
Cancel
Save