diff --git a/mitogen/core.py b/mitogen/core.py index 8ba811a8..c8fab715 100644 --- a/mitogen/core.py +++ b/mitogen/core.py @@ -823,6 +823,11 @@ class Stream(BasicStream): #: :py:attr:`Message.auth_id` of every message received on this stream. auth_id = None + #: If not :data:`False`, indicates the stream has :attr:`auth_id` set and + #: its value is the same as :data:`mitogen.context_id` or appears in + #: :data:`mitogen.parent_ids`. + is_privileged = False + def __init__(self, router, remote_id, **kwargs): self._router = router self.remote_id = remote_id @@ -1265,6 +1270,7 @@ class IoLogger(BasicStream): class Router(object): context_class = Context max_message_size = 128 * 1048576 + unidirectional = False def __init__(self, broker): self.broker = broker @@ -1369,47 +1375,57 @@ class Router(object): except Exception: LOG.exception('%r._invoke(%r): %r crashed', self, msg, fn) - def _async_route(self, msg, stream=None): - _vv and IOLOG.debug('%r._async_route(%r, %r)', self, msg, stream) + def _async_route(self, msg, in_stream=None): + _vv and IOLOG.debug('%r._async_route(%r, %r)', self, msg, in_stream) if len(msg.data) > self.max_message_size: LOG.error('message too large (max %d bytes): %r', self.max_message_size, msg) return # Perform source verification. - if stream: + if in_stream: parent = self._stream_by_id.get(mitogen.parent_id) expect = self._stream_by_id.get(msg.auth_id, parent) - if stream != expect: + if in_stream != expect: LOG.error('%r: bad auth_id: got %r via %r, not %r: %r', - self, msg.auth_id, stream, expect, msg) + self, msg.auth_id, in_stream, expect, msg) return if msg.src_id != msg.auth_id: expect = self._stream_by_id.get(msg.src_id, parent) - if stream != expect: + if in_stream != expect: LOG.error('%r: bad src_id: got %r via %r, not %r: %r', - self, msg.src_id, stream, expect, msg) + self, msg.src_id, in_stream, expect, msg) return - if stream.auth_id is not None: - msg.auth_id = stream.auth_id + if in_stream.auth_id is not None: + msg.auth_id = in_stream.auth_id if msg.dst_id == mitogen.context_id: - return self._invoke(msg, stream) + return self._invoke(msg, in_stream) - stream = self._stream_by_id.get(msg.dst_id) - if stream is None: - stream = self._stream_by_id.get(mitogen.parent_id) + out_stream = self._stream_by_id.get(msg.dst_id) + if out_stream is None: + out_stream = self._stream_by_id.get(mitogen.parent_id) - if stream is None: + dead = False + if out_stream is None: LOG.error('%r: no route for %r, my ID is %r', self, msg, mitogen.context_id) + dead = True + + if in_stream and self.unidirectional and not dead and \ + not (in_stream.is_privileged or out_stream.is_privileged): + LOG.error('routing mode prevents forward of %r from %r -> %r', + msg, in_stream, out_stream) + dead = True + + if dead: if msg.reply_to and not msg.is_dead: msg.reply(Message.dead(), router=self) return - stream._send(msg) + out_stream._send(msg) def route(self, msg): self.broker.defer(self._async_route, msg) @@ -1577,14 +1593,15 @@ class ExternalContext(object): LOG.error('Stream had %d bytes after 2000ms', pending) self.broker.defer(stream.on_disconnect, self.broker) - def _setup_master(self, max_message_size, profiling, parent_id, - context_id, in_fd, out_fd): + def _setup_master(self, max_message_size, profiling, unidirectional, + parent_id, context_id, in_fd, out_fd): Router.max_message_size = max_message_size self.profiling = profiling if profiling: enable_profiling() self.broker = Broker() self.router = Router(self.broker) + self.router.undirectional = unidirectional self.router.add_handler( fn=self._on_shutdown_msg, handle=SHUTDOWN, @@ -1720,11 +1737,11 @@ class ExternalContext(object): self.dispatch_stopped = True def main(self, parent_ids, context_id, debug, profiling, log_level, - max_message_size, version, in_fd=100, out_fd=1, core_src_fd=101, - setup_stdio=True, setup_package=True, importer=None, - whitelist=(), blacklist=()): - self._setup_master(max_message_size, profiling, parent_ids[0], - context_id, in_fd, out_fd) + unidirectional, max_message_size, version, in_fd=100, out_fd=1, + core_src_fd=101, setup_stdio=True, setup_package=True, + importer=None, whitelist=(), blacklist=()): + self._setup_master(max_message_size, profiling, unidirectional, + parent_ids[0], context_id, in_fd, out_fd) try: try: self._setup_logging(debug, log_level) diff --git a/mitogen/fakessh.py b/mitogen/fakessh.py index 0e737c21..3ee91015 100644 --- a/mitogen/fakessh.py +++ b/mitogen/fakessh.py @@ -349,6 +349,7 @@ def run(dest, router, args, deadline=None, econtext=None): 'out_fd': sock2.fileno(), 'parent_ids': parent_ids, 'profiling': getattr(router, 'profiling', False), + 'unidirectional': getattr(router, 'unidirectional', False), 'setup_stdio': False, 'version': mitogen.__version__, },)) diff --git a/mitogen/fork.py b/mitogen/fork.py index 70737fc8..4a5627dc 100644 --- a/mitogen/fork.py +++ b/mitogen/fork.py @@ -90,10 +90,11 @@ class Stream(mitogen.parent.Stream): on_fork = None def construct(self, old_router, max_message_size, on_fork=None, - debug=False, profiling=False): + debug=False, profiling=False, unidirectional=False): # fork method only supports a tiny subset of options. super(Stream, self).construct(max_message_size=max_message_size, - debug=debug, profiling=profiling) + debug=debug, profiling=profiling, + unidirectional=False) self.on_fork = on_fork responder = getattr(old_router, 'responder', None) diff --git a/mitogen/master.py b/mitogen/master.py index a0c9b91b..22117a50 100644 --- a/mitogen/master.py +++ b/mitogen/master.py @@ -678,7 +678,6 @@ class Broker(mitogen.core.Broker): class Router(mitogen.parent.Router): broker_class = Broker - debug = False profiling = False def __init__(self, broker=None, max_message_size=None): diff --git a/mitogen/parent.py b/mitogen/parent.py index feac28a8..5786b96a 100644 --- a/mitogen/parent.py +++ b/mitogen/parent.py @@ -563,7 +563,7 @@ class Stream(mitogen.core.Stream): def construct(self, max_message_size, remote_name=None, python_path=None, debug=False, connect_timeout=None, profiling=False, - old_router=None, **kwargs): + unidirectional=False, old_router=None, **kwargs): """Get the named context running on the local machine, creating it if it does not exist.""" super(Stream, self).construct(**kwargs) @@ -585,6 +585,7 @@ class Stream(mitogen.core.Stream): self.remote_name = remote_name self.debug = debug self.profiling = profiling + self.unidirectional = unidirectional self.max_message_size = max_message_size self.connect_deadline = time.time() + self.connect_timeout @@ -709,6 +710,7 @@ class Stream(mitogen.core.Stream): 'context_id': self.remote_id, 'debug': self.debug, 'profiling': self.profiling, + 'unidirectional': self.unidirectional, 'log_level': get_log_level(), 'whitelist': self._router.get_module_whitelist(), 'blacklist': self._router.get_module_blacklist(), @@ -1021,6 +1023,7 @@ class Router(mitogen.core.Router): klass = stream_by_method_name(method_name) kwargs.setdefault('debug', self.debug) kwargs.setdefault('profiling', self.profiling) + kwargs.setdefault('unidirectional', self.unidirectional) via = kwargs.pop('via', None) if via is not None: diff --git a/mitogen/unix.py b/mitogen/unix.py index 376ddf65..8eda7692 100644 --- a/mitogen/unix.py +++ b/mitogen/unix.py @@ -89,6 +89,7 @@ class Listener(mitogen.core.BasicStream): stream.accept(sock.fileno(), sock.fileno()) stream.name = 'unix_client.%d' % (pid,) stream.auth_id = mitogen.context_id + stream.is_privileged = True self._router.register(context, stream) sock.send(struct.pack('>LLL', context_id, mitogen.context_id, os.getpid())) diff --git a/tests/router_test.py b/tests/router_test.py index b32c5e4b..01f64d87 100644 --- a/tests/router_test.py +++ b/tests/router_test.py @@ -8,6 +8,7 @@ import unittest2 import testlib import mitogen.master +import mitogen.parent import mitogen.utils @@ -15,6 +16,12 @@ def ping(): return True +@mitogen.core.takes_router +def ping_context(other, router): + other = mitogen.parent.Context(router, other.context_id) + other.call(ping) + + @mitogen.core.takes_router def return_router_max_message_size(router): return router.max_message_size @@ -50,7 +57,7 @@ class SourceVerifyTest(testlib.RouterMixin, unittest2.TestCase): self.broker.defer(self.router._async_route, self.child2_msg, - stream=self.child1_stream) + in_stream=self.child1_stream) # Wait for IO loop to finish everything above. self.sync_with_broker() @@ -270,5 +277,39 @@ class NoRouteTest(testlib.RouterMixin, testlib.TestCase): self.assertEquals(e.args[0], mitogen.core.ChannelError.local_msg) +class UnidirectionalTest(testlib.RouterMixin, testlib.TestCase): + def test_siblings_cant_talk(self): + self.router.unidirectional = True + l1 = self.router.fork() + l2 = self.router.fork() + logs = testlib.LogCapturer() + logs.start() + e = self.assertRaises(mitogen.core.CallError, + lambda: l2.call(ping_context, l1)) + + msg = 'mitogen.core.ChannelError: Channel closed by remote end.' + self.assertTrue(msg in str(e)) + self.assertTrue('routing mode prevents forward of ' in logs.stop()) + + def test_auth_id_can_talk(self): + self.router.unidirectional = True + # One stream has auth_id stamped to that of the master, so it should be + # treated like a parent. + l1 = self.router.fork() + l1s = self.router.stream_by_id(l1.context_id) + l1s.auth_id = mitogen.context_id + l1s.is_privileged = True + + l2 = self.router.fork() + logs = testlib.LogCapturer() + logs.start() + e = self.assertRaises(mitogen.core.CallError, + lambda: l2.call(ping_context, l1)) + + msg = 'mitogen.core.CallError: Refused by policy.' + self.assertTrue(msg in str(e)) + self.assertTrue('policy refused message: ' in logs.stop()) + + if __name__ == '__main__': unittest2.main()