diff --git a/mitogen/core.py b/mitogen/core.py index b0fb8603..4c1e049b 100644 --- a/mitogen/core.py +++ b/mitogen/core.py @@ -1292,12 +1292,21 @@ class Router(object): return # Perform source verification. - if stream is not None: - expected_stream = self._stream_by_id.get(msg.auth_id, - self._stream_by_id.get(mitogen.parent_id)) - if stream != expected_stream: - LOG.error('%r: bad source: got auth ID %r from %r, should be from %r', - self, msg, stream, expected_stream) + if stream: + parent = self._stream_by_id.get(mitogen.parent_id) + expect = self._stream_by_id.get(msg.auth_id, parent) + if stream != expect: + LOG.error('%r: bad auth_id: got %r via %r, not %r: %r', + self, msg.auth_id, stream, expect, msg) + return + + if msg.src_id != msg.auth_id: + expect = self._stream_by_id.get(msg.src_id, parent) + if stream != expect: + LOG.error('%r: bad src_id: got %r via %r, not %r: %r', + self, msg.src_id, stream, expect, msg) + return + if stream.auth_id is not None: msg.auth_id = stream.auth_id diff --git a/tests/router_test.py b/tests/router_test.py index 053037fb..c3d17b8b 100644 --- a/tests/router_test.py +++ b/tests/router_test.py @@ -11,6 +11,10 @@ import mitogen.master import mitogen.utils +def ping(): + return True + + @mitogen.core.takes_router def return_router_max_message_size(router): return router.max_message_size @@ -21,6 +25,70 @@ def send_n_sized_reply(sender, n): return 123 +class SourceVerifyTest(testlib.RouterMixin, unittest2.TestCase): + def setUp(self): + super(SourceVerifyTest, self).setUp() + # Create some children, ping them, and store what their messages look + # like so we can mess with them later. + self.child1 = self.router.fork() + self.child1_msg = self.child1.call_async(ping).get() + self.child1_stream = self.router._stream_by_id[self.child1.context_id] + + self.child2 = self.router.fork() + self.child2_msg = self.child2.call_async(ping).get() + self.child2_stream = self.router._stream_by_id[self.child2.context_id] + + def test_bad_auth_id(self): + # Deliver a message locally from child2, but using child1's stream. + log = testlib.LogCapturer() + log.start() + + # Used to ensure the message was dropped rather than routed after the + # error is logged. + recv = mitogen.core.Receiver(self.router) + self.child2_msg.handle = recv.handle + + self.broker.defer(self.router._async_route, + self.child2_msg, + stream=self.child1_stream) + + # Wait for IO loop to finish everything above. + self.sync_with_broker() + + # Ensure message wasn't forwarded. + self.assertTrue(recv.empty()) + + # Ensure error was logged. + expect = 'bad auth_id: got %d via' % (self.child2_msg.auth_id,) + self.assertTrue(expect in log.stop()) + + def test_bad_src_id(self): + # Deliver a message locally from child2 with the correct auth_id, but + # the wrong src_id. + log = testlib.LogCapturer() + log.start() + + # Used to ensure the message was dropped rather than routed after the + # error is logged. + recv = mitogen.core.Receiver(self.router) + self.child2_msg.handle = recv.handle + self.child2_msg.src_id = self.child1.context_id + + self.broker.defer(self.router._async_route, + self.child2_msg, + self.child2_stream) + + # Wait for IO loop to finish everything above. + self.sync_with_broker() + + # Ensure message wasn't forwarded. + self.assertTrue(recv.empty()) + + # Ensure error was lgoged. + expect = 'bad src_id: got %d via' % (self.child1_msg.src_id,) + self.assertTrue(expect in log.stop()) + + class CrashTest(testlib.BrokerMixin, unittest2.TestCase): # This is testing both Broker's ability to crash nicely, and Router's # ability to respond to the crash event. diff --git a/tests/testlib.py b/tests/testlib.py index 5d959b8d..cac5b1e9 100644 --- a/tests/testlib.py +++ b/tests/testlib.py @@ -11,6 +11,7 @@ import urlparse import unittest2 +import mitogen.core import mitogen.master import mitogen.utils @@ -115,6 +116,20 @@ def wait_for_port( % (host, port)) +def sync_with_broker(broker, timeout=10.0): + """ + Insert a synchronization barrier between the calling thread and the Broker + thread, ensuring it has completed at least one full IO loop before + returning. + + Used to block while asynchronous stuff (like defer()) happens on the + broker. + """ + sem = mitogen.core.Latch() + broker.defer(sem.put, None) + sem.get(timeout=10.0) + + class LogCapturer(object): def __init__(self, name=None): self.sio = StringIO.StringIO() @@ -188,6 +203,9 @@ class BrokerMixin(object): self.broker.join() super(BrokerMixin, self).tearDown() + def sync_with_broker(self): + sync_with_broker(self.broker) + class RouterMixin(BrokerMixin): router_class = mitogen.master.Router