diff --git a/mitogen/core.py b/mitogen/core.py index 029d7106..6da9a6e6 100644 --- a/mitogen/core.py +++ b/mitogen/core.py @@ -394,7 +394,7 @@ class Importer(object): :param context: Context to communicate via. """ - def __init__(self, router, context, core_src): + def __init__(self, router, context, core_src, whitelist=(), blacklist=()): self._context = context self._present = {'mitogen': [ 'mitogen.compat', @@ -407,6 +407,15 @@ class Importer(object): 'mitogen.utils', ]} self._lock = threading.Lock() + self.whitelist = whitelist or [''] + self.blacklist = list(blacklist) + [ + # 2.x generates needless imports for 'builtins', while 3.x does the + # same for '__builtin__'. The correct one is built-in, the other + # always a negative round-trip. + 'builtins', + '__builtin__', + ] + # Presence of an entry in this map indicates in-flight GET_MODULE. self._callbacks = {} router.add_handler(self._on_load_module, LOAD_MODULE) @@ -451,12 +460,9 @@ class Importer(object): finally: del _tls.running - def _load_module_hacks(self, fullname): - if fullname in ('builtins', '__builtin__'): - # Python 2.x will generate needless imports for 'builtins', while - # Python 3.x will generate needless imports for '__builtin__'. The - # correct one is already present in sys.modules, the other is - # always a negative round-trip. + def _refuse_imports(self, fullname): + if ((not any(fullname.startswith(s) for s in self.whitelist)) or + (any(fullname.startswith(s) for s in self.blacklist))): raise ImportError('Refused') f = sys._getframe(2) @@ -515,7 +521,7 @@ class Importer(object): def load_module(self, fullname): _v and LOG.debug('Importer.load_module(%r)', fullname) - self._load_module_hacks(fullname) + self._refuse_imports(fullname) event = threading.Event() self._request_module(fullname, event.set) @@ -1260,7 +1266,7 @@ class ExternalContext(object): if debug: enable_debug_logging() - def _setup_importer(self, core_src_fd): + def _setup_importer(self, core_src_fd, whitelist, blacklist): if core_src_fd: with os.fdopen(101, 'r', 1) as fp: core_size = int(fp.readline()) @@ -1271,7 +1277,9 @@ class ExternalContext(object): else: core_src = None - self.importer = Importer(self.router, self.parent, core_src) + self.importer = Importer(self.router, self.parent, core_src, + whitelist, blacklist) + self.router.importer = self.importer sys.meta_path.append(self.importer) def _setup_package(self, context_id, parent_ids): @@ -1328,12 +1336,13 @@ class ExternalContext(object): self.dispatch_stopped = True def main(self, parent_ids, context_id, debug, profiling, log_level, - in_fd=100, out_fd=1, core_src_fd=101, setup_stdio=True): + in_fd=100, out_fd=1, core_src_fd=101, setup_stdio=True, + whitelist=(), blacklist=()): self._setup_master(profiling, parent_ids[0], context_id, in_fd, out_fd) try: try: self._setup_logging(debug, log_level) - self._setup_importer(core_src_fd) + self._setup_importer(core_src_fd, whitelist, blacklist) self._setup_package(context_id, parent_ids) if setup_stdio: self._setup_stdio() @@ -1342,7 +1351,7 @@ class ExternalContext(object): sys.executable = os.environ.pop('ARGV0', sys.executable) _v and LOG.debug('Connected to %s; my ID is %r, PID is %r', - self.parent, context_id, os.getpid()) + self.parent, context_id, os.getpid()) _v and LOG.debug('Recovered sys.executable: %r', sys.executable) _profile_hook('main', self._dispatch_calls) diff --git a/mitogen/fakessh.py b/mitogen/fakessh.py index 3d974e97..26504c6c 100644 --- a/mitogen/fakessh.py +++ b/mitogen/fakessh.py @@ -341,17 +341,17 @@ def run(dest, router, args, deadline=None, econtext=None): fp.write('#!%s\n' % (sys.executable,)) fp.write(inspect.getsource(mitogen.core)) fp.write('\n') - fp.write('ExternalContext().main%r\n' % (( - parent_ids, # parent_ids - context_id, # context_id - router.debug, # debug - router.profiling, # profiling - logging.getLogger().level, # log_level - sock2.fileno(), # in_fd - sock2.fileno(), # out_fd - None, # core_src_fd - False, # setup_stdio - ),)) + fp.write('ExternalContext().main(**%r)\n' % ({ + 'parent_ids': parent_ids, + 'context_id': context_id, + 'debug': router.debug, + 'profiling': router.profiling, + 'log_level': mitogen.parent.get_log_level(), + 'in_fd': sock2.fileno(), + 'out_fd': sock2.fileno(), + 'core_src_fd': None, + 'setup_stdio': False, + },)) finally: fp.close() diff --git a/mitogen/master.py b/mitogen/master.py index bbec3b38..09982e6e 100644 --- a/mitogen/master.py +++ b/mitogen/master.py @@ -441,6 +441,8 @@ class ModuleResponder(object): self._router = router self._finder = ModuleFinder() self._cache = {} # fullname -> pickled + self.blacklist = [] + self.whitelist = [] router.add_handler(self._on_get_module, mitogen.core.GET_MODULE) def __repr__(self): @@ -448,6 +450,12 @@ class ModuleResponder(object): MAIN_RE = re.compile(r'^if\s+__name__\s*==\s*.__main__.\s*:', re.M) + def whitelist_prefix(self, fullname): + self.whitelist.append(fullname) + + def blacklist_prefix(self, fullname): + self.blacklist.append(fullname) + def neutralize_main(self, src): """Given the source for the __main__ module, try to find where it begins conditional execution based on a "if __name__ == '__main__'" @@ -458,6 +466,9 @@ class ModuleResponder(object): return src def _build_tuple(self, fullname): + if fullname in self._blacklist: + raise ImportError('blacklisted') + if fullname in self._cache: return self._cache[fullname] diff --git a/mitogen/parent.py b/mitogen/parent.py index be2be6bf..ec42c912 100644 --- a/mitogen/parent.py +++ b/mitogen/parent.py @@ -63,6 +63,10 @@ class Argv(object): return ' '.join(map(self.escape, self.argv)) +def get_log_level(): + return (LOG.level or logging.getLogger().level or logging.INFO) + + def minimize_source(source): subber = lambda match: '""' + ('\n' * match.group(0).count('\n')) source = DOCSTRING_RE.sub(subber, source) @@ -336,14 +340,17 @@ class Stream(mitogen.core.Stream): def get_preamble(self): parent_ids = mitogen.parent_ids[:] parent_ids.insert(0, mitogen.context_id) + source = inspect.getsource(mitogen.core) - source += '\nExternalContext().main%r\n' % (( - parent_ids, # parent_ids - self.remote_id, # context_id - self.debug, - self.profiling, - LOG.level or logging.getLogger().level or logging.INFO, - ),) + source += '\nExternalContext().main(**%r)\n' % ({ + 'parent_ids': parent_ids, + 'context_id': self.remote_id, + 'debug': self.debug, + 'profiling': self.profiling, + 'log_level': get_log_level(), + 'whitelist': self._router.get_module_whitelist(), + 'blacklist': self._router.get_module_blacklist(), + },) compressed = zlib.compress(minimize_source(source)) return str(len(compressed)) + '\n' + compressed @@ -385,6 +392,16 @@ class ChildIdAllocator(object): class Router(mitogen.core.Router): context_class = mitogen.core.Context + def get_module_blacklist(self): + if mitogen.context_id == 0: + return self.responder.blacklist + return self.importer.blacklist + + def get_module_whitelist(self): + if mitogen.context_id == 0: + return self.responder.whitelist + return self.importer.whitelist + def allocate_id(self): return self.id_allocator.allocate()