Support using importlib on py>=3 to avoid imp deprecation (#54883)

* Support using importlib on py>=3 to avoid imp deprecation

* Add changelog fragment

* importlib coverage for py3

* Ansiballz execute should use importlib too

* recursive module_utils finder should utilize importlib too

* don't be dumb

* Fix up units

* Clean up tests

* Prefer importlib.util in plugin loader when available

* insert the module into sys.modules

* 3 before 2 for consistency

* ci_complete

* Address importlib.util.find_spec returning None
pull/55773/head
Matt Martz 6 years ago committed by GitHub
parent 6d645c127f
commit 2732cde031
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,2 @@
bugfixes:
- AnsiballZ - Use ``importlib`` to load the module instead of ``imp`` on Python3+

@ -23,7 +23,6 @@ __metaclass__ = type
import ast import ast
import base64 import base64
import datetime import datetime
import imp
import json import json
import os import os
import shlex import shlex
@ -46,6 +45,15 @@ from ansible.executor import action_write_locks
from ansible.utils.display import Display from ansible.utils.display import Display
try:
import importlib.util
import importlib.machinery
imp = None
except ImportError:
import imp
# HACK: keep Python 2.6 controller tests happy in CI until they're properly split # HACK: keep Python 2.6 controller tests happy in CI until they're properly split
try: try:
from importlib import import_module from importlib import import_module
@ -144,18 +152,20 @@ def _ansiballz_main():
sys.path = [p for p in sys.path if p != scriptdir] sys.path = [p for p in sys.path if p != scriptdir]
import base64 import base64
import imp
import shutil import shutil
import tempfile import tempfile
import zipfile import zipfile
if sys.version_info < (3,): if sys.version_info < (3,):
# imp is used on Python<3
import imp
bytes = str bytes = str
MOD_DESC = ('.py', 'U', imp.PY_SOURCE) MOD_DESC = ('.py', 'U', imp.PY_SOURCE)
PY3 = False PY3 = False
else: else:
# importlib is only used on Python>=3
import importlib.util
unicode = str unicode = str
MOD_DESC = ('.py', 'r', imp.PY_SOURCE)
PY3 = True PY3 = True
ZIPDATA = """%(zipdata)s""" ZIPDATA = """%(zipdata)s"""
@ -195,8 +205,13 @@ def _ansiballz_main():
basic._ANSIBLE_ARGS = json_params basic._ANSIBLE_ARGS = json_params
%(coverage)s %(coverage)s
# Run the module! By importing it as '__main__', it thinks it is executing as a script # Run the module! By importing it as '__main__', it thinks it is executing as a script
with open(module, 'rb') as mod: if sys.version_info >= (3,):
imp.load_module('__main__', mod, module, MOD_DESC) spec = importlib.util.spec_from_file_location('__main__', module)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
else:
with open(module, 'rb') as mod:
imp.load_module('__main__', mod, module, MOD_DESC)
# Ansible modules must exit themselves # Ansible modules must exit themselves
print('{"msg": "New-style module did not handle its own exit", "failed": true}') print('{"msg": "New-style module did not handle its own exit", "failed": true}')
@ -291,9 +306,15 @@ def _ansiballz_main():
basic._ANSIBLE_ARGS = json_params basic._ANSIBLE_ARGS = json_params
# Run the module! By importing it as '__main__', it thinks it is executing as a script # Run the module! By importing it as '__main__', it thinks it is executing as a script
import imp if PY3:
with open(script_path, 'r') as f: import importlib.util
importer = imp.load_module('__main__', f, script_path, ('.py', 'r', imp.PY_SOURCE)) spec = importlib.util.spec_from_file_location('__main__', script_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
else:
import imp
with open(script_path, 'r') as f:
imp.load_module('__main__', f, script_path, ('.py', 'r', imp.PY_SOURCE))
# Ansible modules must exit themselves # Ansible modules must exit themselves
print('{"msg": "New-style module did not handle its own exit", "failed": true}') print('{"msg": "New-style module did not handle its own exit", "failed": true}')
@ -372,7 +393,11 @@ ANSIBALLZ_COVERAGE_TEMPLATE = '''
ANSIBALLZ_COVERAGE_CHECK_TEMPLATE = ''' ANSIBALLZ_COVERAGE_CHECK_TEMPLATE = '''
try: try:
imp.find_module('coverage') if PY3:
if importlib.util.find_spec('coverage') is None:
raise ImportError
else:
imp.find_module('coverage')
except ImportError: except ImportError:
print('{"msg": "Could not find `coverage` module.", "failed": true}') print('{"msg": "Could not find `coverage` module.", "failed": true}')
sys.exit(1) sys.exit(1)
@ -488,9 +513,8 @@ class ModuleDepFinder(ast.NodeVisitor):
def _slurp(path): def _slurp(path):
if not os.path.exists(path): if not os.path.exists(path):
raise AnsibleError("imported module support code does not exist at %s" % os.path.abspath(path)) raise AnsibleError("imported module support code does not exist at %s" % os.path.abspath(path))
fd = open(path, 'rb') with open(path, 'rb') as fd:
data = fd.read() data = fd.read()
fd.close()
return data return data
@ -544,6 +568,40 @@ def _get_shebang(interpreter, task_vars, templar, args=tuple()):
return shebang, interpreter_out return shebang, interpreter_out
class ModuleInfo:
def __init__(self, name, paths):
self.py_src = False
self.pkg_dir = False
path = None
if imp is None:
self._info = info = importlib.machinery.PathFinder.find_spec(name, paths)
if info is not None:
self.py_src = os.path.splitext(info.origin)[1] in importlib.machinery.SOURCE_SUFFIXES
self.pkg_dir = info.origin.endswith('/__init__.py')
path = info.origin
else:
raise ImportError("No module named '%s'" % name)
else:
self._info = info = imp.find_module(name, paths)
self.py_src = info[2][2] == imp.PY_SOURCE
self.pkg_dir = info[2][2] == imp.PKG_DIRECTORY
if self.pkg_dir:
path = os.path.join(info[1], '__init__.py')
else:
path = info[1]
self.path = path
def get_source(self):
if imp and self.py_src:
try:
return self._info[0].read()
finally:
self._info[0].close()
return _slurp(self.path)
def recursive_finder(name, data, py_module_names, py_module_cache, zf): def recursive_finder(name, data, py_module_names, py_module_cache, zf):
""" """
Using ModuleDepFinder, make sure we have all of the module_utils files that Using ModuleDepFinder, make sure we have all of the module_utils files that
@ -575,13 +633,13 @@ def recursive_finder(name, data, py_module_names, py_module_cache, zf):
if py_module_name[0] == 'six': if py_module_name[0] == 'six':
# Special case the python six library because it messes up the # Special case the python six library because it messes up the
# import process in an incompatible way # import process in an incompatible way
module_info = imp.find_module('six', module_utils_paths) module_info = ModuleInfo('six', module_utils_paths)
py_module_name = ('six',) py_module_name = ('six',)
idx = 0 idx = 0
elif py_module_name[0] == '_six': elif py_module_name[0] == '_six':
# Special case the python six library because it messes up the # Special case the python six library because it messes up the
# import process in an incompatible way # import process in an incompatible way
module_info = imp.find_module('_six', [os.path.join(p, 'six') for p in module_utils_paths]) module_info = ModuleInfo('_six', [os.path.join(p, 'six') for p in module_utils_paths])
py_module_name = ('six', '_six') py_module_name = ('six', '_six')
idx = 0 idx = 0
elif py_module_name[0] == 'ansible_collections': elif py_module_name[0] == 'ansible_collections':
@ -605,8 +663,8 @@ def recursive_finder(name, data, py_module_names, py_module_cache, zf):
if len(py_module_name) < idx: if len(py_module_name) < idx:
break break
try: try:
module_info = imp.find_module(py_module_name[-idx], module_info = ModuleInfo(py_module_name[-idx],
[os.path.join(p, *py_module_name[:-idx]) for p in module_utils_paths]) [os.path.join(p, *py_module_name[:-idx]) for p in module_utils_paths])
break break
except ImportError: except ImportError:
continue continue
@ -647,7 +705,7 @@ def recursive_finder(name, data, py_module_names, py_module_cache, zf):
# imp.find_module seems to prefer to return source packages so we just # imp.find_module seems to prefer to return source packages so we just
# error out if imp.find_module returns byte compiled files (This is # error out if imp.find_module returns byte compiled files (This is
# fragile as it depends on undocumented imp.find_module behaviour) # fragile as it depends on undocumented imp.find_module behaviour)
if module_info[2][2] not in (imp.PY_SOURCE, imp.PKG_DIRECTORY): if not module_info.pkg_dir and not module_info.py_src:
msg = ['Could not find python source for imported module support code for %s. Looked for' % name] msg = ['Could not find python source for imported module support code for %s. Looked for' % name]
if idx == 2: if idx == 2:
msg.append('either %s.py or %s.py' % (py_module_name[-1], py_module_name[-2])) msg.append('either %s.py or %s.py' % (py_module_name[-1], py_module_name[-2]))
@ -665,22 +723,19 @@ def recursive_finder(name, data, py_module_names, py_module_cache, zf):
# We already have a file handle for the module open so it makes # We already have a file handle for the module open so it makes
# sense to read it now # sense to read it now
if py_module_name not in py_module_cache: if py_module_name not in py_module_cache:
if module_info[2][2] == imp.PKG_DIRECTORY: if module_info.pkg_dir:
# Read the __init__.py instead of the module file as this is # Read the __init__.py instead of the module file as this is
# a python package # a python package
normalized_name = py_module_name + ('__init__',) normalized_name = py_module_name + ('__init__',)
if normalized_name not in py_module_names: if normalized_name not in py_module_names:
normalized_path = os.path.join(module_info[1], '__init__.py') normalized_data = module_info.get_source()
normalized_data = _slurp(normalized_path) py_module_cache[normalized_name] = (normalized_data, module_info.path)
py_module_cache[normalized_name] = (normalized_data, normalized_path)
normalized_modules.add(normalized_name) normalized_modules.add(normalized_name)
else: else:
normalized_name = py_module_name normalized_name = py_module_name
if normalized_name not in py_module_names: if normalized_name not in py_module_names:
normalized_path = module_info[1] normalized_data = module_info.get_source()
normalized_data = module_info[0].read() py_module_cache[normalized_name] = (normalized_data, module_info.path)
module_info[0].close()
py_module_cache[normalized_name] = (normalized_data, normalized_path)
normalized_modules.add(normalized_name) normalized_modules.add(normalized_name)
# Make sure that all the packages that this module is a part of # Make sure that all the packages that this module is a part of
@ -688,10 +743,10 @@ def recursive_finder(name, data, py_module_names, py_module_cache, zf):
for i in range(1, len(py_module_name)): for i in range(1, len(py_module_name)):
py_pkg_name = py_module_name[:-i] + ('__init__',) py_pkg_name = py_module_name[:-i] + ('__init__',)
if py_pkg_name not in py_module_names: if py_pkg_name not in py_module_names:
pkg_dir_info = imp.find_module(py_pkg_name[-1], pkg_dir_info = ModuleInfo(py_pkg_name[-1],
[os.path.join(p, *py_pkg_name[:-1]) for p in module_utils_paths]) [os.path.join(p, *py_pkg_name[:-1]) for p in module_utils_paths])
normalized_modules.add(py_pkg_name) normalized_modules.add(py_pkg_name)
py_module_cache[py_pkg_name] = (_slurp(pkg_dir_info[1]), pkg_dir_info[1]) py_module_cache[py_pkg_name] = (pkg_dir_info.get_source(), pkg_dir_info.path)
# FIXME: Currently the AnsiBallZ wrapper monkeypatches module args into a global # FIXME: Currently the AnsiBallZ wrapper monkeypatches module args into a global
# variable in basic.py. If a module doesn't import basic.py, then the AnsiBallZ wrapper will # variable in basic.py. If a module doesn't import basic.py, then the AnsiBallZ wrapper will
@ -704,9 +759,9 @@ def recursive_finder(name, data, py_module_names, py_module_cache, zf):
# from the separate python module and mirror the args into its global variable for backwards # from the separate python module and mirror the args into its global variable for backwards
# compatibility. # compatibility.
if ('basic',) not in py_module_names: if ('basic',) not in py_module_names:
pkg_dir_info = imp.find_module('basic', module_utils_paths) pkg_dir_info = ModuleInfo('basic', module_utils_paths)
normalized_modules.add(('basic',)) normalized_modules.add(('basic',))
py_module_cache[('basic',)] = (_slurp(pkg_dir_info[1]), pkg_dir_info[1]) py_module_cache[('basic',)] = (pkg_dir_info.get_source(), pkg_dir_info.path)
# End of AnsiballZ hack # End of AnsiballZ hack
# #

@ -8,7 +8,6 @@ from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
import glob import glob
import imp
import os import os
import os.path import os.path
import pkgutil import pkgutil
@ -28,6 +27,12 @@ from ansible.utils.collection_loader import AnsibleCollectionLoader, AnsibleFlat
from ansible.utils.display import Display from ansible.utils.display import Display
from ansible.utils.plugin_docs import add_fragments from ansible.utils.plugin_docs import add_fragments
try:
import importlib.util
imp = None
except ImportError:
import imp
# HACK: keep Python 2.6 controller tests happy in CI until they're properly split # HACK: keep Python 2.6 controller tests happy in CI until they're properly split
try: try:
from importlib import import_module from importlib import import_module
@ -535,9 +540,15 @@ class PluginLoader:
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore", RuntimeWarning) warnings.simplefilter("ignore", RuntimeWarning)
with open(to_bytes(path), 'rb') as module_file: if imp is None:
# to_native is used here because imp.load_source's path is for tracebacks and python's traceback formatting uses native strings spec = importlib.util.spec_from_file_location(to_native(full_name), to_native(path))
module = imp.load_source(to_native(full_name), to_native(path), module_file) module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
sys.modules[full_name] = module
else:
with open(to_bytes(path), 'rb') as module_file:
# to_native is used here because imp.load_source's path is for tracebacks and python's traceback formatting uses native strings
module = imp.load_source(to_native(full_name), to_native(path), module_file)
return module return module
def _update_object(self, obj, name, path): def _update_object(self, obj, name, path):

@ -19,23 +19,18 @@
from __future__ import (absolute_import, division, print_function) from __future__ import (absolute_import, division, print_function)
__metaclass__ = type __metaclass__ = type
import imp
import pytest import pytest
import zipfile import zipfile
from collections import namedtuple from collections import namedtuple
from functools import partial from io import BytesIO
from io import BytesIO, StringIO
import ansible.errors import ansible.errors
from ansible.executor.module_common import recursive_finder from ansible.executor.module_common import recursive_finder
from ansible.module_utils.six import PY2 from ansible.module_utils.six import PY2
from ansible.module_utils.six.moves import builtins
original_find_module = imp.find_module
# These are the modules that are brought in by module_utils/basic.py This may need to be updated # These are the modules that are brought in by module_utils/basic.py This may need to be updated
# when basic.py gains new imports # when basic.py gains new imports
# We will remove these when we modify AnsiBallZ to store its args in a separate file instead of in # We will remove these when we modify AnsiBallZ to store its args in a separate file instead of in
@ -108,18 +103,6 @@ def finder_containers():
return FinderContainers(py_module_names, py_module_cache, zf) return FinderContainers(py_module_names, py_module_cache, zf)
def find_module_foo(module_utils_data, *args, **kwargs):
if args[0] == 'foo':
return (module_utils_data, '/usr/lib/python2.7/site-packages/ansible/module_utils/foo.py', ('.py', 'r', imp.PY_SOURCE))
return original_find_module(*args, **kwargs)
def find_package_foo(module_utils_data, *args, **kwargs):
if args[0] == 'foo':
return (module_utils_data, '/usr/lib/python2.7/site-packages/ansible/module_utils/foo', ('', '', imp.PKG_DIRECTORY))
return original_find_module(*args, **kwargs)
class TestRecursiveFinder(object): class TestRecursiveFinder(object):
def test_no_module_utils(self, finder_containers): def test_no_module_utils(self, finder_containers):
name = 'ping' name = 'ping'
@ -145,11 +128,15 @@ class TestRecursiveFinder(object):
def test_from_import_toplevel_package(self, finder_containers, mocker): def test_from_import_toplevel_package(self, finder_containers, mocker):
if PY2: if PY2:
module_utils_data = BytesIO(b'# License\ndef do_something():\n pass\n') module_utils_data = b'# License\ndef do_something():\n pass\n'
else: else:
module_utils_data = StringIO(u'# License\ndef do_something():\n pass\n') module_utils_data = u'# License\ndef do_something():\n pass\n'
mocker.patch('imp.find_module', side_effect=partial(find_package_foo, module_utils_data)) mi_mock = mocker.patch('ansible.executor.module_common.ModuleInfo')
mocker.patch('ansible.executor.module_common._slurp', side_effect=lambda x: b'# License\ndef do_something():\n pass\n') mi_inst = mi_mock()
mi_inst.pkg_dir = True
mi_inst.py_src = False
mi_inst.path = '/path/to/ansible/module_utils/foo/__init__.py'
mi_inst.get_source.return_value = module_utils_data
name = 'ping' name = 'ping'
data = b'#!/usr/bin/python\nfrom ansible.module_utils import foo' data = b'#!/usr/bin/python\nfrom ansible.module_utils import foo'
@ -161,20 +148,22 @@ class TestRecursiveFinder(object):
assert frozenset(finder_containers.zf.namelist()) == frozenset(('ansible/module_utils/foo/__init__.py',)).union(ONLY_BASIC_FILE) assert frozenset(finder_containers.zf.namelist()) == frozenset(('ansible/module_utils/foo/__init__.py',)).union(ONLY_BASIC_FILE)
def test_from_import_toplevel_module(self, finder_containers, mocker): def test_from_import_toplevel_module(self, finder_containers, mocker):
if PY2: module_utils_data = b'# License\ndef do_something():\n pass\n'
module_utils_data = BytesIO(b'# License\ndef do_something():\n pass\n') mi_mock = mocker.patch('ansible.executor.module_common.ModuleInfo')
else: mi_inst = mi_mock()
module_utils_data = StringIO(u'# License\ndef do_something():\n pass\n') mi_inst.pkg_dir = False
mocker.patch('imp.find_module', side_effect=partial(find_module_foo, module_utils_data)) mi_inst.py_src = True
mi_inst.path = '/path/to/ansible/module_utils/foo.py'
mi_inst.get_source.return_value = module_utils_data
name = 'ping' name = 'ping'
data = b'#!/usr/bin/python\nfrom ansible.module_utils import foo' data = b'#!/usr/bin/python\nfrom ansible.module_utils import foo'
recursive_finder(name, data, *finder_containers) recursive_finder(name, data, *finder_containers)
mocker.stopall() mocker.stopall()
assert finder_containers.py_module_names == set((('foo',),)).union(MODULE_UTILS_BASIC_IMPORTS) assert finder_containers.py_module_names == set((('foo',),)).union(ONLY_BASIC_IMPORT)
assert finder_containers.py_module_cache == {} assert finder_containers.py_module_cache == {}
assert frozenset(finder_containers.zf.namelist()) == frozenset(('ansible/module_utils/foo.py',)).union(MODULE_UTILS_BASIC_FILES) assert frozenset(finder_containers.zf.namelist()) == frozenset(('ansible/module_utils/foo.py',)).union(ONLY_BASIC_FILE)
# #
# Test importing six with many permutations because it is not a normal module # Test importing six with many permutations because it is not a normal module

Loading…
Cancel
Save