pull/1342/merge
Marc Hartmayer 1 week ago committed by GitHub
commit 9564aad874
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -793,8 +793,19 @@ if PY3:
# In 3.x Unpickler is a class exposing find_class as an overridable, but it
# cannot be overridden without subclassing.
class _Unpickler(pickle.Unpickler):
def __init__(self, *args, insecure=False, **kwargs):
super().__init__(*args, **kwargs)
self.__insecure = insecure
def find_class(self, module, func):
return self.find_global(module, func)
try:
return self.find_global(module, func)
except Exception as error:
if not self.__insecure:
raise error
return super().find_class(module, func)
pickle__dumps = pickle.dumps
elif PY24:
# On Python 2.4, we must use a pure-Python pickler.
@ -978,7 +989,7 @@ class Message(object):
else:
raise ChannelError(ChannelError.remote_msg)
def unpickle(self, throw=True, throw_dead=True):
def unpickle(self, throw=True, throw_dead=True, *, insecure=False):
"""
Unpickle :attr:`data`, optionally raising any exceptions present.
@ -986,6 +997,9 @@ class Message(object):
If :data:`True`, raise exceptions, otherwise it is the caller's
responsibility.
:param bool insecure:
If :data:`True`, also use possibly unsecure unpickling methods.
:raises CallError:
The serialized data contained CallError exception.
:raises ChannelError:
@ -998,7 +1012,7 @@ class Message(object):
obj = self._unpickled
if obj is Message._unpickled:
fp = BytesIO(self.data)
unpickler = _Unpickler(fp, **self.UNPICKLER_KWARGS)
unpickler = _Unpickler(fp, insecure=insecure, **self.UNPICKLER_KWARGS)
unpickler.find_global = self._find_global
try:
# Must occur off the broker thread.
@ -3844,7 +3858,7 @@ class Dispatcher(object):
econtext.dispatcher._error_by_chain_id.pop(chain_id, None)
def _parse_request(self, msg):
data = msg.unpickle(throw=False)
data = msg.unpickle(throw=False, insecure=True)
_v and LOG.debug('%r: dispatching %r', self, data)
chain_id, modname, klass, func, args, kwargs = data

@ -42,6 +42,14 @@ class TargetClass:
def add_numbers_with_offset(cls, x, y):
return cls.offset + x + y
@classmethod
def passing_crazy_type(cls, crazy_cls):
return crazy_cls.__name__
@classmethod
def passing_crazy_type_instance(cls, crazy):
return crazy.__class__.__name__
class CallFunctionTest(testlib.RouterMixin, testlib.TestCase):
@ -58,6 +66,18 @@ class CallFunctionTest(testlib.RouterMixin, testlib.TestCase):
103,
)
def test_succeeds_passing_class(self):
self.assertEqual(
self.local.call(TargetClass.passing_crazy_type, CrazyType),
'CrazyType'
)
def test_succeeds_passing_class_instance(self):
self.assertEqual(
self.local.call(TargetClass.passing_crazy_type_instance, CrazyType()),
'CrazyType'
)
def test_crashes(self):
exc = self.assertRaises(mitogen.core.CallError,
lambda: self.local.call(function_that_fails))

Loading…
Cancel
Save