diff --git a/docs/api.rst b/docs/api.rst index 50c5bdab..2c4dc42f 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -412,7 +412,7 @@ Router Class receive side to the I/O multiplexer. This This method remains public for now while hte design has not yet settled. - .. method:: add_handler (fn, handle=None, persist=True, respondent=None) + .. method:: add_handler (fn, handle=None, persist=True, respondent=None, policy=None) Invoke `fn(msg)` for each Message sent to `handle` from this context. Unregister after one invocation if `persist` is ``False``. If `handle` @@ -435,6 +435,28 @@ Router Class In future `respondent` will likely also be used to prevent other contexts from sending messages to the handle. + :param function policy: + Function invoked as `policy(msg, stream)` where `msg` is a + :py:class:`mitogen.core.Message` about to be delivered, and + `stream` is the :py:class:`mitogen.core.Stream` on which it was + received. The function must return :py:data:`True`, otherwise an + error is logged and delivery is refused. + + Two built-in policy functions exist: + + * :py:func:`mitogen.core.has_parent_authority`: requires the + message arrived from a parent context, or a context acting with a + parent context's authority (``auth_id``). + + * :py:func:`mitogen.parent.is_immediate_child`: requires the + message arrived from an immediately connected child, for use in + messaging patterns where either something becomes buggy or + insecure by permitting indirect upstream communication. + + In case of refusal, and the message's ``reply_to`` field is + nonzero, a :py:class:`mitogen.core.CallError` is delivered to the + sender indicating refusal occurred. + :return: `handle`, or if `handle` was ``None``, the newly allocated handle. diff --git a/mitogen/core.py b/mitogen/core.py index 4f2bbd75..2085a850 100644 --- a/mitogen/core.py +++ b/mitogen/core.py @@ -157,6 +157,10 @@ def _unpickle_dead(): _DEAD = Dead() +def has_parent_authority(msg, _stream): + return msg.auth_id in mitogen.parent_ids + + def listen(obj, name, func): signals = vars(obj).setdefault('_signals', {}) signals.setdefault(name, []).append(func) @@ -407,11 +411,17 @@ class Receiver(object): notify = None raise_channelerror = True - def __init__(self, router, handle=None, persist=True, respondent=None): + def __init__(self, router, handle=None, persist=True, + respondent=None, policy=None): self.router = router self.handle = handle # Avoid __repr__ crash in add_handler() - self.handle = router.add_handler(self._on_receive, handle, - persist, respondent) + self.handle = router.add_handler( + fn=self._on_receive, + handle=handle, + policy=policy, + persist=persist, + respondent=respondent, + ) self._latch = Latch() def __repr__(self): @@ -497,7 +507,11 @@ class Importer(object): # Presence of an entry in this map indicates in-flight GET_MODULE. self._callbacks = {} - router.add_handler(self._on_load_module, LOAD_MODULE) + router.add_handler( + fn=self._on_load_module, + handle=LOAD_MODULE, + policy=has_parent_authority, + ) self._cache = {} if core_src: self._cache['mitogen.core'] = ( @@ -1235,7 +1249,7 @@ class Router(object): def _cleanup_handlers(self): while self._handle_map: - _, (_, func) = self._handle_map.popitem() + _, (_, func, _) = self._handle_map.popitem() func(_DEAD) def register(self, context, stream): @@ -1245,18 +1259,22 @@ class Router(object): self.broker.start_receive(stream) listen(stream, 'disconnect', lambda: self.on_stream_disconnect(stream)) - def add_handler(self, fn, handle=None, persist=True, respondent=None): + def add_handler(self, fn, handle=None, persist=True, + policy=None, respondent=None): handle = handle or self._last_handle.next() _vv and IOLOG.debug('%r.add_handler(%r, %r, %r)', self, fn, handle, persist) - self._handle_map[handle] = persist, fn if respondent: + assert policy is None + def policy(msg, _stream): + return msg.src_id == respondent.context_id def on_disconnect(): if handle in self._handle_map: fn(_DEAD) del self._handle_map[handle] listen(respondent, 'disconnect', on_disconnect) + self._handle_map[handle] = persist, fn, policy return handle def on_shutdown(self, broker): @@ -1268,14 +1286,26 @@ class Router(object): _v and LOG.debug('%r.on_shutdown(): killing %r: %r', self, handle, fn) fn(_DEAD) - def _invoke(self, msg): + refused_msg = 'Refused by policy.' + + def _invoke(self, msg, stream): #IOLOG.debug('%r._invoke(%r)', self, msg) try: - persist, fn = self._handle_map[msg.handle] + persist, fn, policy = self._handle_map[msg.handle] except KeyError: LOG.error('%r: invalid handle: %r', self, msg) return + if policy and not policy(msg, stream): + LOG.error('%r: policy refused message: %r', self, msg) + if msg.reply_to: + self.route(Message.pickled( + CallError(self.refused_msg), + dst_id=msg.src_id, + handle=msg.reply_to + )) + return + if not persist: del self._handle_map[msg.handle] @@ -1311,7 +1341,7 @@ class Router(object): msg.auth_id = stream.auth_id if msg.dst_id == mitogen.context_id: - return self._invoke(msg) + return self._invoke(msg, stream) stream = self._stream_by_id.get(msg.dst_id) if stream is None: @@ -1456,10 +1486,8 @@ class ExternalContext(object): def _on_shutdown_msg(self, msg): _v and LOG.debug('_on_shutdown_msg(%r)', msg) - if msg != _DEAD and msg.auth_id not in mitogen.parent_ids: - LOG.warning('Ignoring SHUTDOWN from non-parent: %r', msg) - return - self.broker.shutdown() + if msg != _DEAD: + self.broker.shutdown() def _on_parent_disconnect(self): _v and LOG.debug('%r: parent stream is gone, dying.', self) @@ -1473,14 +1501,20 @@ class ExternalContext(object): enable_profiling() self.broker = Broker() self.router = Router(self.broker) - self.router.add_handler(self._on_shutdown_msg, SHUTDOWN) + self.router.add_handler( + fn=self._on_shutdown_msg, + handle=SHUTDOWN, + policy=has_parent_authority, + ) self.master = Context(self.router, 0, 'master') if parent_id == 0: self.parent = self.master else: self.parent = Context(self.router, parent_id, 'parent') - self.channel = Receiver(self.router, CALL_FUNCTION) + self.channel = Receiver(router=self.router, + handle=CALL_FUNCTION, + policy=has_parent_authority) self.stream = Stream(self.router, parent_id) self.stream.name = 'parent' self.stream.accept(in_fd, out_fd) @@ -1576,8 +1610,6 @@ class ExternalContext(object): def _dispatch_one(self, msg): data = msg.unpickle(throw=False) _v and LOG.debug('_dispatch_calls(%r)', data) - if msg.auth_id not in mitogen.parent_ids: - LOG.warning('CALL_FUNCTION from non-parent %r', msg.auth_id) modname, klass, func, args, kwargs = data obj = __import__(modname, {}, {}, ['']) diff --git a/mitogen/master.py b/mitogen/master.py index 0cf5d451..dca4eb46 100644 --- a/mitogen/master.py +++ b/mitogen/master.py @@ -288,7 +288,10 @@ class LogForwarder(object): def __init__(self, router): self._router = router self._cache = {} - router.add_handler(self._on_forward_log, mitogen.core.FORWARD_LOG) + router.add_handler( + fn=self._on_forward_log, + handle=mitogen.core.FORWARD_LOG, + ) def _on_forward_log(self, msg): if msg == mitogen.core._DEAD: @@ -524,7 +527,10 @@ class ModuleResponder(object): self._cache = {} # fullname -> pickled self.blacklist = [] self.whitelist = [''] - router.add_handler(self._on_get_module, mitogen.core.GET_MODULE) + router.add_handler( + fn=self._on_get_module, + handle=mitogen.core.GET_MODULE, + ) def __repr__(self): return 'ModuleResponder(%r)' % (self._router,) @@ -684,7 +690,10 @@ class IdAllocator(object): self.router = router self.next_id = 1 self.lock = threading.Lock() - router.add_handler(self.on_allocate_id, mitogen.core.ALLOCATE_ID) + router.add_handler( + fn=self.on_allocate_id, + handle=mitogen.core.ALLOCATE_ID, + ) def __repr__(self): return 'IdAllocator(%r)' % (self.router,) diff --git a/mitogen/parent.py b/mitogen/parent.py index 599a4603..41fe3676 100644 --- a/mitogen/parent.py +++ b/mitogen/parent.py @@ -78,6 +78,14 @@ def get_log_level(): return (LOG.level or logging.getLogger().level or logging.INFO) +def is_immediate_child(msg, stream): + """ + Handler policy that requires messages to arrive only from immediately + connected children. + """ + return msg.src_id == stream.remote_id + + def minimize_source(source): subber = lambda match: '""' + ('\n' * match.group(0).count('\n')) source = DOCSTRING_RE.sub(subber, source) @@ -554,11 +562,13 @@ class RouteMonitor(object): fn=self._on_add_route, handle=mitogen.core.ADD_ROUTE, persist=True, + policy=is_immediate_child, ) self.router.add_handler( fn=self._on_del_route, handle=mitogen.core.DEL_ROUTE, persist=True, + policy=is_immediate_child, ) def propagate(self, handle, target_id, name=None): @@ -795,7 +805,12 @@ class ModuleForwarder(object): self.router = router self.parent_context = parent_context self.importer = importer - router.add_handler(self._on_get_module, mitogen.core.GET_MODULE) + router.add_handler( + fn=self._on_get_module, + handle=mitogen.core.GET_MODULE, + persist=True, + policy=is_immediate_child, + ) def __repr__(self): return 'ModuleForwarder(%r)' % (self.router,) diff --git a/tests/router_test.py b/tests/router_test.py index c3d17b8b..2c0b7e60 100644 --- a/tests/router_test.py +++ b/tests/router_test.py @@ -89,6 +89,55 @@ class SourceVerifyTest(testlib.RouterMixin, unittest2.TestCase): self.assertTrue(expect in log.stop()) +class PolicyTest(testlib.RouterMixin, testlib.TestCase): + def test_allow_any(self): + # This guy gets everything. + recv = mitogen.core.Receiver(self.router) + recv.to_sender().send(123) + self.sync_with_broker() + self.assertFalse(recv.empty()) + self.assertEquals(123, recv.get().unpickle()) + + def test_refuse_all(self): + # Deliver a message locally from child2 with the correct auth_id, but + # the wrong src_id. + log = testlib.LogCapturer() + log.start() + + # This guy never gets anything. + recv = mitogen.core.Receiver( + router=self.router, + policy=(lambda msg, stream: False), + ) + + # This guy becomes the reply_to of our refused message. + reply_target = mitogen.core.Receiver(self.router) + + # Send the message. + self.router.route( + mitogen.core.Message( + dst_id=mitogen.context_id, + handle=recv.handle, + reply_to=reply_target.handle, + ) + ) + + # Wait for IO loop. + self.sync_with_broker() + + # Verify log. + expect = '%r: policy refused message: ' % (self.router,) + self.assertTrue(expect in log.stop()) + + # Verify message was not delivered. + self.assertTrue(recv.empty()) + + # Verify CallError received by reply_to target. + e = self.assertRaises(mitogen.core.CallError, + lambda: reply_target.get().unpickle()) + self.assertEquals(e[0], self.router.refused_msg) + + class CrashTest(testlib.BrokerMixin, unittest2.TestCase): # This is testing both Broker's ability to crash nicely, and Router's # ability to respond to the crash event.