Implement Importer.get_filename() and Importer.get_source()

Optional importer protocols required for Python to display annotated
tracebacks.
pull/35/head
David Wilson 8 years ago
parent 30991a6b42
commit 481ae1a933

@ -161,6 +161,7 @@ class Importer(object):
'econtext.utils',
]}
self.tls = threading.local()
self._cache = {}
def __repr__(self):
return 'Importer()'
@ -194,13 +195,18 @@ class Importer(object):
def load_module(self, fullname):
LOG.debug('Importer.load_module(%r)', fullname)
ret = self._context.enqueue_await_reply(GET_MODULE, None, (fullname,))
try:
ret = self._cache[fullname]
except KeyError:
ret = self._context.enqueue_await_reply(GET_MODULE, None, (fullname,))
self._cache[fullname] = ret
if ret is None:
raise ImportError('Master does not have %r' % (fullname,))
pkg_present, path, data = ret
pkg_present = ret[0]
mod = sys.modules.setdefault(fullname, imp.new_module(fullname))
mod.__file__ = path
mod.__file__ = self.get_filename(fullname)
mod.__loader__ = self
if pkg_present is not None: # it's a package.
mod.__path__ = []
@ -208,10 +214,18 @@ class Importer(object):
self._present[fullname] = pkg_present
else:
mod.__package__ = fullname.rpartition('.')[0] or None
code = compile(zlib.decompress(data), 'master:' + path, 'exec')
code = compile(self.get_source(fullname), mod.__file__, 'exec')
exec code in vars(mod)
return mod
def get_filename(self, fullname):
if fullname in self._cache:
return 'master:' + self._cache[fullname][1]
def get_source(self, fullname):
if fullname in self._cache:
return zlib.decompress(self._cache[fullname][2])
class LogHandler(logging.Handler):
def __init__(self, context):

@ -81,6 +81,18 @@ class LoadModulePackageTest(ImporterMixin, unittest.TestCase):
mod = self.importer.load_module(self.modname)
self.assertEquals(mod.__file__, self.path)
def test_get_filename(self):
self.context.enqueue_await_reply.return_value = self.response
mod = self.importer.load_module(self.modname)
filename = mod.__loader__.get_filename(self.modname)
self.assertEquals('master:fake_pkg/__init__.py', filename)
def test_get_source(self):
self.context.enqueue_await_reply.return_value = self.response
mod = self.importer.load_module(self.modname)
source = mod.__loader__.get_source(self.modname)
self.assertEquals(source, zlib.decompress(self.data))
def test_module_loader_set(self):
self.context.enqueue_await_reply.return_value = self.response
mod = self.importer.load_module(self.modname)

Loading…
Cancel
Save