Add maximum message size checks. Closes #151.

pull/175/head
David Wilson 8 years ago
parent e1af2db4ae
commit 1ff27ada49

@ -134,7 +134,7 @@ class MuxProcess(object):
""" """
Construct a Router, Broker, and mitogen.unix listener Construct a Router, Broker, and mitogen.unix listener
""" """
self.router = mitogen.master.Router() self.router = mitogen.master.Router(max_message_size=4096*1048576)
self.router.responder.whitelist_prefix('ansible') self.router.responder.whitelist_prefix('ansible')
self.router.responder.whitelist_prefix('ansible_mitogen') self.router.responder.whitelist_prefix('ansible_mitogen')
mitogen.core.listen(self.router.broker, 'shutdown', self.on_broker_shutdown) mitogen.core.listen(self.router.broker, 'shutdown', self.on_broker_shutdown)

@ -812,6 +812,12 @@ class Stream(BasicStream):
self._input_buf[0][:self.HEADER_LEN], self._input_buf[0][:self.HEADER_LEN],
) )
if msg_len > self._router.max_message_size:
LOG.error('Maximum message size exceeded (got %d, max %d)',
msg_len, self._router.max_message_size)
self.on_disconnect(broker)
return False
total_len = msg_len + self.HEADER_LEN total_len = msg_len + self.HEADER_LEN
if self._input_buf_len < total_len: if self._input_buf_len < total_len:
_vv and IOLOG.debug( _vv and IOLOG.debug(
@ -1191,6 +1197,7 @@ class IoLogger(BasicStream):
class Router(object): class Router(object):
context_class = Context context_class = Context
max_message_size = 128 * 1048576
def __init__(self, broker): def __init__(self, broker):
self.broker = broker self.broker = broker
@ -1274,6 +1281,11 @@ class Router(object):
def _async_route(self, msg, stream=None): def _async_route(self, msg, stream=None):
_vv and IOLOG.debug('%r._async_route(%r, %r)', self, msg, stream) _vv and IOLOG.debug('%r._async_route(%r, %r)', self, msg, stream)
if len(msg.data) > self.max_message_size:
LOG.error('message too large (max %d bytes): %r',
self.max_message_size, msg)
return
# Perform source verification. # Perform source verification.
if stream is not None: if stream is not None:
expected_stream = self._stream_by_id.get(msg.auth_id, expected_stream = self._stream_by_id.get(msg.auth_id,
@ -1438,7 +1450,9 @@ class ExternalContext(object):
_v and LOG.debug('%r: parent stream is gone, dying.', self) _v and LOG.debug('%r: parent stream is gone, dying.', self)
self.broker.shutdown() self.broker.shutdown()
def _setup_master(self, profiling, parent_id, context_id, in_fd, out_fd): def _setup_master(self, max_message_size, profiling, parent_id,
context_id, in_fd, out_fd):
Router.max_message_size = max_message_size
self.profiling = profiling self.profiling = profiling
if profiling: if profiling:
enable_profiling() enable_profiling()
@ -1571,9 +1585,11 @@ class ExternalContext(object):
self.dispatch_stopped = True self.dispatch_stopped = True
def main(self, parent_ids, context_id, debug, profiling, log_level, def main(self, parent_ids, context_id, debug, profiling, log_level,
in_fd=100, out_fd=1, core_src_fd=101, setup_stdio=True, max_message_size, in_fd=100, out_fd=1, core_src_fd=101,
setup_package=True, importer=None, whitelist=(), blacklist=()): setup_stdio=True, setup_package=True, importer=None,
self._setup_master(profiling, parent_ids[0], context_id, in_fd, out_fd) whitelist=(), blacklist=()):
self._setup_master(max_message_size, profiling, parent_ids[0],
context_id, in_fd, out_fd)
try: try:
try: try:
self._setup_logging(debug, log_level) self._setup_logging(debug, log_level)

@ -343,14 +343,15 @@ def run(dest, router, args, deadline=None, econtext=None):
fp.write(inspect.getsource(mitogen.core)) fp.write(inspect.getsource(mitogen.core))
fp.write('\n') fp.write('\n')
fp.write('ExternalContext().main(**%r)\n' % ({ fp.write('ExternalContext().main(**%r)\n' % ({
'parent_ids': parent_ids,
'context_id': context_id, 'context_id': context_id,
'core_src_fd': None,
'debug': getattr(router, 'debug', False), 'debug': getattr(router, 'debug', False),
'profiling': getattr(router, 'profiling', False),
'log_level': mitogen.parent.get_log_level(),
'in_fd': sock2.fileno(), 'in_fd': sock2.fileno(),
'log_level': mitogen.parent.get_log_level(),
'max_message_size': router.max_message_size,
'out_fd': sock2.fileno(), 'out_fd': sock2.fileno(),
'core_src_fd': None, 'parent_ids': parent_ids,
'profiling': getattr(router, 'profiling', False),
'setup_stdio': False, 'setup_stdio': False,
},)) },))
finally: finally:

@ -85,9 +85,11 @@ class Stream(mitogen.parent.Stream):
#: User-supplied function for cleaning up child process state. #: User-supplied function for cleaning up child process state.
on_fork = None on_fork = None
def construct(self, old_router, on_fork=None, debug=False, profiling=False): def construct(self, old_router, max_message_size, on_fork=None,
debug=False, profiling=False):
# fork method only supports a tiny subset of options. # fork method only supports a tiny subset of options.
super(Stream, self).construct(debug=debug, profiling=profiling) super(Stream, self).construct(max_message_size=max_message_size,
debug=debug, profiling=profiling)
self.on_fork = on_fork self.on_fork = on_fork
responder = getattr(old_router, 'responder', None) responder = getattr(old_router, 'responder', None)

@ -646,9 +646,11 @@ class Router(mitogen.parent.Router):
debug = False debug = False
profiling = False profiling = False
def __init__(self, broker=None): def __init__(self, broker=None, max_message_size=None):
if broker is None: if broker is None:
broker = self.broker_class() broker = self.broker_class()
if max_message_size:
self.max_message_size = max_message_size
super(Router, self).__init__(broker) super(Router, self).__init__(broker)
self.upgrade() self.upgrade()

@ -337,6 +337,10 @@ class Stream(mitogen.core.Stream):
#: Set to the child's PID by connect(). #: Set to the child's PID by connect().
pid = None pid = None
#: Passed via Router wrapper methods, must eventually be passed to
#: ExternalContext.main().
max_message_size = None
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(Stream, self).__init__(*args, **kwargs) super(Stream, self).__init__(*args, **kwargs)
self.sent_modules = set(['mitogen', 'mitogen.core']) self.sent_modules = set(['mitogen', 'mitogen.core'])
@ -344,12 +348,13 @@ class Stream(mitogen.core.Stream):
#: during disconnection. #: during disconnection.
self.routes = set([self.remote_id]) self.routes = set([self.remote_id])
def construct(self, remote_name=None, python_path=None, debug=False, def construct(self, max_message_size, remote_name=None, python_path=None,
connect_timeout=None, profiling=False, debug=False, connect_timeout=None, profiling=False,
old_router=None, **kwargs): old_router=None, **kwargs):
"""Get the named context running on the local machine, creating it if """Get the named context running on the local machine, creating it if
it does not exist.""" it does not exist."""
super(Stream, self).construct(**kwargs) super(Stream, self).construct(**kwargs)
self.max_message_size = max_message_size
if python_path: if python_path:
self.python_path = python_path self.python_path = python_path
if sys.platform == 'darwin' and self.python_path == '/usr/bin/python': if sys.platform == 'darwin' and self.python_path == '/usr/bin/python':
@ -367,6 +372,7 @@ class Stream(mitogen.core.Stream):
self.remote_name = remote_name self.remote_name = remote_name
self.debug = debug self.debug = debug
self.profiling = profiling self.profiling = profiling
self.max_message_size = max_message_size
self.connect_deadline = time.time() + self.connect_timeout self.connect_deadline = time.time() + self.connect_timeout
def on_shutdown(self, broker): def on_shutdown(self, broker):
@ -441,6 +447,7 @@ class Stream(mitogen.core.Stream):
] ]
def get_main_kwargs(self): def get_main_kwargs(self):
assert self.max_message_size is not None
parent_ids = mitogen.parent_ids[:] parent_ids = mitogen.parent_ids[:]
parent_ids.insert(0, mitogen.context_id) parent_ids.insert(0, mitogen.context_id)
return { return {
@ -451,6 +458,7 @@ class Stream(mitogen.core.Stream):
'log_level': get_log_level(), 'log_level': get_log_level(),
'whitelist': self._router.get_module_whitelist(), 'whitelist': self._router.get_module_whitelist(),
'blacklist': self._router.get_module_blacklist(), 'blacklist': self._router.get_module_blacklist(),
'max_message_size': self.max_message_size,
} }
def get_preamble(self): def get_preamble(self):
@ -703,7 +711,9 @@ class Router(mitogen.core.Router):
def _connect(self, klass, name=None, **kwargs): def _connect(self, klass, name=None, **kwargs):
context_id = self.allocate_id() context_id = self.allocate_id()
context = self.context_class(self, context_id) context = self.context_class(self, context_id)
stream = klass(self, context_id, old_router=self, **kwargs) kwargs['old_router'] = self
kwargs['max_message_size'] = self.max_message_size
stream = klass(self, context_id, **kwargs)
if name is not None: if name is not None:
stream.name = name stream.name = name
stream.connect() stream.connect()

@ -1,4 +1,6 @@
import Queue import Queue
import StringIO
import logging
import subprocess import subprocess
import time import time
@ -8,7 +10,16 @@ import testlib
import mitogen.master import mitogen.master
import mitogen.utils import mitogen.utils
mitogen.utils.log_to_file()
@mitogen.core.takes_router
def return_router_max_message_size(router):
return router.max_message_size
def send_n_sized_reply(sender, n):
sender.send(' ' * n)
return 123
class AddHandlerTest(unittest2.TestCase): class AddHandlerTest(unittest2.TestCase):
klass = mitogen.master.Router klass = mitogen.master.Router
@ -21,6 +32,44 @@ class AddHandlerTest(unittest2.TestCase):
self.assertEquals(queue.get(timeout=5), mitogen.core._DEAD) self.assertEquals(queue.get(timeout=5), mitogen.core._DEAD)
class MessageSizeTest(testlib.BrokerMixin, unittest2.TestCase):
klass = mitogen.master.Router
def test_local_exceeded(self):
router = self.klass(broker=self.broker, max_message_size=4096)
recv = mitogen.core.Receiver(router)
logs = testlib.LogCapturer()
logs.start()
sem = mitogen.core.Latch()
router.route(mitogen.core.Message.pickled(' '*8192))
router.broker.defer(sem.put, ' ') # wlil always run after _async_route
sem.get()
expect = 'message too large (max 4096 bytes)'
self.assertTrue(expect in logs.stop())
def test_remote_configured(self):
router = self.klass(broker=self.broker, max_message_size=4096)
remote = router.fork()
size = remote.call(return_router_max_message_size)
self.assertEquals(size, 4096)
def test_remote_exceeded(self):
# Ensure new contexts receive a router with the same value.
router = self.klass(broker=self.broker, max_message_size=4096)
recv = mitogen.core.Receiver(router)
logs = testlib.LogCapturer()
logs.start()
remote = router.fork()
remote.call(send_n_sized_reply, recv.to_sender(), 8192)
expect = 'message too large (max 4096 bytes)'
self.assertTrue(expect in logs.stop())
if __name__ == '__main__': if __name__ == '__main__':
unittest2.main() unittest2.main()

@ -1,4 +1,6 @@
import StringIO
import logging
import os import os
import random import random
import re import re
@ -113,6 +115,24 @@ def wait_for_port(
% (host, port)) % (host, port))
class LogCapturer(object):
def __init__(self, name=None):
self.sio = StringIO.StringIO()
self.logger = logging.getLogger(name)
self.handler = logging.StreamHandler(self.sio)
self.old_propagate = self.logger.propagate
self.old_handlers = self.logger.handlers
def start(self):
self.logger.handlers = [self.handler]
self.logger.propagate = False
def stop(self):
self.logger.handlers = self.old_handlers
self.logger.propagate = self.old_propagate
return self.sio.getvalue()
class TestCase(unittest2.TestCase): class TestCase(unittest2.TestCase):
def assertRaises(self, exc, func, *args, **kwargs): def assertRaises(self, exc, func, *args, **kwargs):
"""Like regular assertRaises, except return the exception that was """Like regular assertRaises, except return the exception that was
@ -156,19 +176,25 @@ class DockerizedSshDaemon(object):
self.container.remove() self.container.remove()
class RouterMixin(object): class BrokerMixin(object):
broker_class = mitogen.master.Broker broker_class = mitogen.master.Broker
router_class = mitogen.master.Router
def setUp(self): def setUp(self):
super(RouterMixin, self).setUp() super(BrokerMixin, self).setUp()
self.broker = self.broker_class() self.broker = self.broker_class()
self.router = self.router_class(self.broker)
def tearDown(self): def tearDown(self):
self.broker.shutdown() self.broker.shutdown()
self.broker.join() self.broker.join()
super(RouterMixin, self).tearDown() super(BrokerMixin, self).tearDown()
class RouterMixin(BrokerMixin):
router_class = mitogen.master.Router
def setUp(self):
super(RouterMixin, self).setUp()
self.router = self.router_class(self.broker)
class DockerMixin(RouterMixin): class DockerMixin(RouterMixin):

Loading…
Cancel
Save