From 481ae1a933559f47748f247872d46aa21c698054 Mon Sep 17 00:00:00 2001 From: David Wilson Date: Thu, 18 Aug 2016 17:12:43 +0100 Subject: [PATCH] Implement Importer.get_filename() and Importer.get_source() Optional importer protocols required for Python to display annotated tracebacks. --- econtext/core.py | 22 ++++++++++++++++++---- tests/importer_test.py | 12 ++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/econtext/core.py b/econtext/core.py index d601a93d..cfb3b9b2 100644 --- a/econtext/core.py +++ b/econtext/core.py @@ -161,6 +161,7 @@ class Importer(object): 'econtext.utils', ]} self.tls = threading.local() + self._cache = {} def __repr__(self): return 'Importer()' @@ -194,13 +195,18 @@ class Importer(object): def load_module(self, fullname): LOG.debug('Importer.load_module(%r)', fullname) - ret = self._context.enqueue_await_reply(GET_MODULE, None, (fullname,)) + try: + ret = self._cache[fullname] + except KeyError: + ret = self._context.enqueue_await_reply(GET_MODULE, None, (fullname,)) + self._cache[fullname] = ret + if ret is None: raise ImportError('Master does not have %r' % (fullname,)) - pkg_present, path, data = ret + pkg_present = ret[0] mod = sys.modules.setdefault(fullname, imp.new_module(fullname)) - mod.__file__ = path + mod.__file__ = self.get_filename(fullname) mod.__loader__ = self if pkg_present is not None: # it's a package. mod.__path__ = [] @@ -208,10 +214,18 @@ class Importer(object): self._present[fullname] = pkg_present else: mod.__package__ = fullname.rpartition('.')[0] or None - code = compile(zlib.decompress(data), 'master:' + path, 'exec') + code = compile(self.get_source(fullname), mod.__file__, 'exec') exec code in vars(mod) return mod + def get_filename(self, fullname): + if fullname in self._cache: + return 'master:' + self._cache[fullname][1] + + def get_source(self, fullname): + if fullname in self._cache: + return zlib.decompress(self._cache[fullname][2]) + class LogHandler(logging.Handler): def __init__(self, context): diff --git a/tests/importer_test.py b/tests/importer_test.py index 647d95ff..18258094 100644 --- a/tests/importer_test.py +++ b/tests/importer_test.py @@ -81,6 +81,18 @@ class LoadModulePackageTest(ImporterMixin, unittest.TestCase): mod = self.importer.load_module(self.modname) self.assertEquals(mod.__file__, self.path) + def test_get_filename(self): + self.context.enqueue_await_reply.return_value = self.response + mod = self.importer.load_module(self.modname) + filename = mod.__loader__.get_filename(self.modname) + self.assertEquals('master:fake_pkg/__init__.py', filename) + + def test_get_source(self): + self.context.enqueue_await_reply.return_value = self.response + mod = self.importer.load_module(self.modname) + source = mod.__loader__.get_source(self.modname) + self.assertEquals(source, zlib.decompress(self.data)) + def test_module_loader_set(self): self.context.enqueue_await_reply.return_value = self.response mod = self.importer.load_module(self.modname)