Add maximum message size checks. Closes #151.

pull/175/head
David Wilson 6 years ago
parent e1af2db4ae
commit 1ff27ada49

@ -134,7 +134,7 @@ class MuxProcess(object):
"""
Construct a Router, Broker, and mitogen.unix listener
"""
self.router = mitogen.master.Router()
self.router = mitogen.master.Router(max_message_size=4096*1048576)
self.router.responder.whitelist_prefix('ansible')
self.router.responder.whitelist_prefix('ansible_mitogen')
mitogen.core.listen(self.router.broker, 'shutdown', self.on_broker_shutdown)

@ -812,6 +812,12 @@ class Stream(BasicStream):
self._input_buf[0][:self.HEADER_LEN],
)
if msg_len > self._router.max_message_size:
LOG.error('Maximum message size exceeded (got %d, max %d)',
msg_len, self._router.max_message_size)
self.on_disconnect(broker)
return False
total_len = msg_len + self.HEADER_LEN
if self._input_buf_len < total_len:
_vv and IOLOG.debug(
@ -1191,6 +1197,7 @@ class IoLogger(BasicStream):
class Router(object):
context_class = Context
max_message_size = 128 * 1048576
def __init__(self, broker):
self.broker = broker
@ -1274,6 +1281,11 @@ class Router(object):
def _async_route(self, msg, stream=None):
_vv and IOLOG.debug('%r._async_route(%r, %r)', self, msg, stream)
if len(msg.data) > self.max_message_size:
LOG.error('message too large (max %d bytes): %r',
self.max_message_size, msg)
return
# Perform source verification.
if stream is not None:
expected_stream = self._stream_by_id.get(msg.auth_id,
@ -1438,7 +1450,9 @@ class ExternalContext(object):
_v and LOG.debug('%r: parent stream is gone, dying.', self)
self.broker.shutdown()
def _setup_master(self, profiling, parent_id, context_id, in_fd, out_fd):
def _setup_master(self, max_message_size, profiling, parent_id,
context_id, in_fd, out_fd):
Router.max_message_size = max_message_size
self.profiling = profiling
if profiling:
enable_profiling()
@ -1571,9 +1585,11 @@ class ExternalContext(object):
self.dispatch_stopped = True
def main(self, parent_ids, context_id, debug, profiling, log_level,
in_fd=100, out_fd=1, core_src_fd=101, setup_stdio=True,
setup_package=True, importer=None, whitelist=(), blacklist=()):
self._setup_master(profiling, parent_ids[0], context_id, in_fd, out_fd)
max_message_size, in_fd=100, out_fd=1, core_src_fd=101,
setup_stdio=True, setup_package=True, importer=None,
whitelist=(), blacklist=()):
self._setup_master(max_message_size, profiling, parent_ids[0],
context_id, in_fd, out_fd)
try:
try:
self._setup_logging(debug, log_level)

@ -343,14 +343,15 @@ def run(dest, router, args, deadline=None, econtext=None):
fp.write(inspect.getsource(mitogen.core))
fp.write('\n')
fp.write('ExternalContext().main(**%r)\n' % ({
'parent_ids': parent_ids,
'context_id': context_id,
'core_src_fd': None,
'debug': getattr(router, 'debug', False),
'profiling': getattr(router, 'profiling', False),
'log_level': mitogen.parent.get_log_level(),
'in_fd': sock2.fileno(),
'log_level': mitogen.parent.get_log_level(),
'max_message_size': router.max_message_size,
'out_fd': sock2.fileno(),
'core_src_fd': None,
'parent_ids': parent_ids,
'profiling': getattr(router, 'profiling', False),
'setup_stdio': False,
},))
finally:

@ -85,9 +85,11 @@ class Stream(mitogen.parent.Stream):
#: User-supplied function for cleaning up child process state.
on_fork = None
def construct(self, old_router, on_fork=None, debug=False, profiling=False):
def construct(self, old_router, max_message_size, on_fork=None,
debug=False, profiling=False):
# fork method only supports a tiny subset of options.
super(Stream, self).construct(debug=debug, profiling=profiling)
super(Stream, self).construct(max_message_size=max_message_size,
debug=debug, profiling=profiling)
self.on_fork = on_fork
responder = getattr(old_router, 'responder', None)

@ -646,9 +646,11 @@ class Router(mitogen.parent.Router):
debug = False
profiling = False
def __init__(self, broker=None):
def __init__(self, broker=None, max_message_size=None):
if broker is None:
broker = self.broker_class()
if max_message_size:
self.max_message_size = max_message_size
super(Router, self).__init__(broker)
self.upgrade()

@ -337,6 +337,10 @@ class Stream(mitogen.core.Stream):
#: Set to the child's PID by connect().
pid = None
#: Passed via Router wrapper methods, must eventually be passed to
#: ExternalContext.main().
max_message_size = None
def __init__(self, *args, **kwargs):
super(Stream, self).__init__(*args, **kwargs)
self.sent_modules = set(['mitogen', 'mitogen.core'])
@ -344,12 +348,13 @@ class Stream(mitogen.core.Stream):
#: during disconnection.
self.routes = set([self.remote_id])
def construct(self, remote_name=None, python_path=None, debug=False,
connect_timeout=None, profiling=False,
def construct(self, max_message_size, remote_name=None, python_path=None,
debug=False, connect_timeout=None, profiling=False,
old_router=None, **kwargs):
"""Get the named context running on the local machine, creating it if
it does not exist."""
super(Stream, self).construct(**kwargs)
self.max_message_size = max_message_size
if python_path:
self.python_path = python_path
if sys.platform == 'darwin' and self.python_path == '/usr/bin/python':
@ -367,6 +372,7 @@ class Stream(mitogen.core.Stream):
self.remote_name = remote_name
self.debug = debug
self.profiling = profiling
self.max_message_size = max_message_size
self.connect_deadline = time.time() + self.connect_timeout
def on_shutdown(self, broker):
@ -441,6 +447,7 @@ class Stream(mitogen.core.Stream):
]
def get_main_kwargs(self):
assert self.max_message_size is not None
parent_ids = mitogen.parent_ids[:]
parent_ids.insert(0, mitogen.context_id)
return {
@ -451,6 +458,7 @@ class Stream(mitogen.core.Stream):
'log_level': get_log_level(),
'whitelist': self._router.get_module_whitelist(),
'blacklist': self._router.get_module_blacklist(),
'max_message_size': self.max_message_size,
}
def get_preamble(self):
@ -703,7 +711,9 @@ class Router(mitogen.core.Router):
def _connect(self, klass, name=None, **kwargs):
context_id = self.allocate_id()
context = self.context_class(self, context_id)
stream = klass(self, context_id, old_router=self, **kwargs)
kwargs['old_router'] = self
kwargs['max_message_size'] = self.max_message_size
stream = klass(self, context_id, **kwargs)
if name is not None:
stream.name = name
stream.connect()

@ -1,4 +1,6 @@
import Queue
import StringIO
import logging
import subprocess
import time
@ -8,7 +10,16 @@ import testlib
import mitogen.master
import mitogen.utils
mitogen.utils.log_to_file()
@mitogen.core.takes_router
def return_router_max_message_size(router):
return router.max_message_size
def send_n_sized_reply(sender, n):
sender.send(' ' * n)
return 123
class AddHandlerTest(unittest2.TestCase):
klass = mitogen.master.Router
@ -21,6 +32,44 @@ class AddHandlerTest(unittest2.TestCase):
self.assertEquals(queue.get(timeout=5), mitogen.core._DEAD)
class MessageSizeTest(testlib.BrokerMixin, unittest2.TestCase):
klass = mitogen.master.Router
def test_local_exceeded(self):
router = self.klass(broker=self.broker, max_message_size=4096)
recv = mitogen.core.Receiver(router)
logs = testlib.LogCapturer()
logs.start()
sem = mitogen.core.Latch()
router.route(mitogen.core.Message.pickled(' '*8192))
router.broker.defer(sem.put, ' ') # wlil always run after _async_route
sem.get()
expect = 'message too large (max 4096 bytes)'
self.assertTrue(expect in logs.stop())
def test_remote_configured(self):
router = self.klass(broker=self.broker, max_message_size=4096)
remote = router.fork()
size = remote.call(return_router_max_message_size)
self.assertEquals(size, 4096)
def test_remote_exceeded(self):
# Ensure new contexts receive a router with the same value.
router = self.klass(broker=self.broker, max_message_size=4096)
recv = mitogen.core.Receiver(router)
logs = testlib.LogCapturer()
logs.start()
remote = router.fork()
remote.call(send_n_sized_reply, recv.to_sender(), 8192)
expect = 'message too large (max 4096 bytes)'
self.assertTrue(expect in logs.stop())
if __name__ == '__main__':
unittest2.main()

@ -1,4 +1,6 @@
import StringIO
import logging
import os
import random
import re
@ -113,6 +115,24 @@ def wait_for_port(
% (host, port))
class LogCapturer(object):
def __init__(self, name=None):
self.sio = StringIO.StringIO()
self.logger = logging.getLogger(name)
self.handler = logging.StreamHandler(self.sio)
self.old_propagate = self.logger.propagate
self.old_handlers = self.logger.handlers
def start(self):
self.logger.handlers = [self.handler]
self.logger.propagate = False
def stop(self):
self.logger.handlers = self.old_handlers
self.logger.propagate = self.old_propagate
return self.sio.getvalue()
class TestCase(unittest2.TestCase):
def assertRaises(self, exc, func, *args, **kwargs):
"""Like regular assertRaises, except return the exception that was
@ -156,19 +176,25 @@ class DockerizedSshDaemon(object):
self.container.remove()
class RouterMixin(object):
class BrokerMixin(object):
broker_class = mitogen.master.Broker
router_class = mitogen.master.Router
def setUp(self):
super(RouterMixin, self).setUp()
super(BrokerMixin, self).setUp()
self.broker = self.broker_class()
self.router = self.router_class(self.broker)
def tearDown(self):
self.broker.shutdown()
self.broker.join()
super(RouterMixin, self).tearDown()
super(BrokerMixin, self).tearDown()
class RouterMixin(BrokerMixin):
router_class = mitogen.master.Router
def setUp(self):
super(RouterMixin, self).setUp()
self.router = self.router_class(self.broker)
class DockerMixin(RouterMixin):

Loading…
Cancel
Save