"""Import the given python module(s) and report error(s) encountered."""
from __future__ import annotations
# Entry point: all importer state is created inside main() rather than at module level.
# NOTE(review): the two lines following the def read like docstring text whose triple quotes are missing - confirm.
def main():
Main program function used to isolate globals from imported code.
Changes to globals in imported modules on Python 2.x will overwrite our own globals.
import os
import sys
import types
# preload an empty ansible._vendor module to prevent use of any embedded modules during the import test
vendor_module_name = 'ansible._vendor'
vendor_module = types.ModuleType(vendor_module_name)
# the placeholder's __file__ is built by dropping the last eight components of this file's absolute path
# NOTE(review): the trailing path literal ends with a slash; a file name may have been lost here - confirm.
vendor_module.__file__ = os.path.join(os.path.sep.join(os.path.abspath(__file__).split(os.path.sep)[:-8]), 'lib/ansible/_vendor/')
vendor_module.__path__ = []
vendor_module.__package__ = vendor_module_name
# register the placeholder before ansible is imported below
sys.modules[vendor_module_name] = vendor_module
import ansible
import contextlib
import datetime
import json
import re
import runpy
import subprocess
import traceback
import warnings
# the directory two levels above ansible/__init__.py, used to map ansible module names to files
ansible_path = os.path.dirname(os.path.dirname(ansible.__file__))
# settings come from environment variables; SANITY_TEMP_PATH is required (KeyError if unset), the rest default to None
temp_path = os.environ['SANITY_TEMP_PATH'] + os.path.sep
external_python = os.environ.get('SANITY_EXTERNAL_PYTHON')
yaml_to_json_path = os.environ.get('SANITY_YAML_TO_JSON')
collection_full_name = os.environ.get('SANITY_COLLECTION_FULL_NAME')
collection_root = os.environ.get('ANSIBLE_COLLECTIONS_PATH')
import_type = os.environ.get('SANITY_IMPORTER_TYPE')
# NOTE(review): the except clause below has no visible matching try - presumably it guards this import; confirm.
# noinspection PyCompatibility
from importlib import import_module
except ImportError:
# fallback used when importlib.import_module is unavailable
# NOTE(review): returns the sys.modules entry without importing the name first - a statement may be missing; confirm.
def import_module(name, package=None): # type: (str, str | None) -> types.ModuleType
assert package is None
return sys.modules[name]
from io import BytesIO, TextIOWrapper
# NOTE(review): has_py3_loader is assigned False and then True with no visible try/else around the imports,
# so as written it always ends up True - presumably a try/except/else was lost; confirm.
from importlib.util import spec_from_loader, module_from_spec
from importlib.machinery import SourceFileLoader, ModuleSpec # pylint: disable=unused-import
except ImportError:
has_py3_loader = False
has_py3_loader = True
if collection_full_name:
# allow importing code from collections when testing a collection
from ansible.module_utils.common.text.converters import to_bytes, to_text, to_native, text_type
# noinspection PyProtectedMember
from ansible.utils.collection_loader._collection_finder import _AnsibleCollectionFinder
from ansible.utils.collection_loader import _collection_finder
# cache of parsed YAML keyed by content_id, filled by yaml_to_dict
yaml_to_dict_cache = {}
# unique ISO date marker; presumably it matches the marker emitted by the YAML-to-JSON helper script - TODO confirm
iso_date_marker = 'isodate:f23983df-f3df-453c-9904-bcd08af468cc:'
# matches the marker followed by YYYY-MM-DD, capturing year, month and day
iso_date_re = re.compile('^%s([0-9]{4})-([0-9]{2})-([0-9]{2})$' % iso_date_marker)
def parse_value(value):
    """Custom value parser for JSON deserialization that recognizes our internal ISO date format.

    Text values consisting of the ISO date marker followed by ``YYYY-MM-DD`` are converted to
    a ``datetime.date``. Every other value is returned unchanged.
    """
    if not isinstance(value, text_type):
        return value

    match = iso_date_re.search(value)

    if not match:
        return value

    # the pattern captures exactly three numeric groups: year, month and day
    year, month, day = (int(part) for part in match.groups())

    return datetime.date(year, month, day)
def object_hook(data):
    """JSON object hook which passes every value of a decoded object through parse_value."""
    converted = {}

    for key, value in data.items():
        converted[key] = parse_value(value)

    return converted
def yaml_to_dict(yaml, content_id):
    """
    Return a Python dict version of the provided YAML.
    Conversion is done in a subprocess since the current Python interpreter does not have access to PyYAML.

    Results are cached by content_id, so each distinct document is converted at most once.
    Any failure is re-raised as a generic Exception describing an internal importer error.
    """
    if content_id in yaml_to_dict_cache:
        return yaml_to_dict_cache[content_id]

    try:
        cmd = [external_python, yaml_to_json_path]
        proc = subprocess.Popen([to_bytes(c) for c in cmd],  # pylint: disable=consider-using-with
                                stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        # the YAML is sent on stdin and the JSON result is read back from stdout
        stdout_bytes, stderr_bytes = proc.communicate(to_bytes(yaml))

        if proc.returncode != 0:
            raise Exception('command %s failed with return code %d: %s' % ([to_native(c) for c in cmd], proc.returncode, to_native(stderr_bytes)))

        # object_hook restores dates which were serialized using the ISO date marker
        data = yaml_to_dict_cache[content_id] = json.loads(to_text(stdout_bytes), object_hook=object_hook)

        return data
    except Exception as ex:
        raise Exception('internal importer error - failed to parse yaml: %s' % to_native(ex))
# replace the collection finder's meta YAML parser with the subprocess-based yaml_to_dict defined above
_collection_finder._meta_yml_to_dict = yaml_to_dict # pylint: disable=protected-access
# create a collection finder rooted at ANSIBLE_COLLECTIONS_PATH and install it
collection_loader = _AnsibleCollectionFinder(paths=[collection_root])
# noinspection PyProtectedMember
collection_loader._install() # pylint: disable=protected-access
# NOTE(review): as written, the assignment below always discards the loader that was just installed;
# presumably it belongs to a lost else branch of the collection_full_name check - confirm.
# do not support collection loading when not testing a collection
collection_loader = None
# NOTE(review): the if below has only comments as its visible body, and the unload statement that follows
# presumably belongs to a lost else branch - confirm.
if collection_loader and import_type == 'plugin':
# do not unload ansible code for collection plugin (not module) tests
# doing so could result in the collection loader being initialized multiple times
# remove all modules under the ansible package, except the preloaded vendor module
# list() consumes the map so that every matching sys.modules entry is actually popped
list(map(sys.modules.pop, [m for m in sys.modules if m.partition('.')[0] == ansible.__name__ and m != vendor_module_name]))
if import_type == 'module':
# pre-load an empty ansible package to prevent unwanted code in the real package from loading
# this more accurately reflects the environment that AnsiballZ runs modules under
# it also avoids issues with imports in the ansible package that are not allowed
# the placeholder copies __file__, __path__ and __package__ from the real ansible package
ansible_module = types.ModuleType(ansible.__name__)
ansible_module.__file__ = ansible.__file__
ansible_module.__path__ = ansible.__path__
ansible_module.__package__ = ansible.__package__
sys.modules[ansible.__name__] = ansible_module
class ImporterAnsibleModuleException(Exception):
    """Raised when ImporterAnsibleModule is instantiated, signalling that a module tried to create an AnsibleModule."""
class ImporterAnsibleModule:
    """Stand-in for AnsibleModule used during import testing; instantiation always fails."""

    def __init__(self, *args, **kwargs):
        # any arguments are accepted and ignored; only the attempt to instantiate matters
        raise ImporterAnsibleModuleException()
class RestrictedModuleLoader:
    """Python module loader that restricts inappropriate imports.

    The same object acts as finder (find_spec / find_module) and as loader
    (create_module / exec_module / load_module). Returning ``self`` from the finder methods
    means the import is intercepted: ``ansible.module_utils.basic`` is loaded with a patched
    AnsibleModule, every other intercepted name fails with ImportError.
    """

    def __init__(self, path, name, restrict_to_module_paths):
        """
        :type path: str
        :type name: str
        :type restrict_to_module_paths: bool
        """
        self.path = path
        self.name = name  # name of the module under test, which is always allowed to be imported
        self.loaded_modules = set()  # names this loader is currently loading itself, see _get_loader
        self.restrict_to_module_paths = restrict_to_module_paths

    def find_spec(self, fullname, path=None, target=None):  # pylint: disable=unused-argument
        # type: (RestrictedModuleLoader, str, list[str], types.ModuleType | None ) -> ModuleSpec | None | ImportError
        """Return the spec from the loader or None"""
        loader = self._get_loader(fullname, path=path)

        if loader is None:
            return None

        if has_py3_loader:
            # loader is expected to be Optional[importlib.abc.Loader], but RestrictedModuleLoader does not inherit from importlib.abc.Loader
            return spec_from_loader(fullname, loader)  # type: ignore[arg-type]

        raise ImportError("Failed to import '%s' due to a bug in ansible-test. Check importlib imports for typos." % fullname)

    def find_module(self, fullname, path=None):
        # type: (RestrictedModuleLoader, str, list[str]) -> RestrictedModuleLoader | None
        """Return self if the given fullname is restricted, otherwise return None."""
        return self._get_loader(fullname, path=path)

    def _get_loader(self, fullname, path=None):
        # type: (RestrictedModuleLoader, str, list[str]) -> RestrictedModuleLoader | None
        """Return self if the given fullname is restricted, otherwise return None."""
        if fullname in self.loaded_modules:
            return None  # ignore modules that are already being loaded

        if is_name_in_namepace(fullname, ['ansible']):
            if not self.restrict_to_module_paths:
                return None  # for non-modules, everything in the ansible namespace is allowed

            if fullname in ('ansible.module_utils.basic',):
                return self  # intercept loading so we can modify the result

            if is_name_in_namepace(fullname, ['ansible.module_utils', self.name]):
                return None  # module_utils and module under test are always allowed

            if any(os.path.exists(candidate_path) for candidate_path in convert_ansible_name_to_absolute_paths(fullname)):
                return self  # restrict access to ansible files that exist

            return None  # ansible file does not exist, do not restrict access

        if is_name_in_namepace(fullname, ['ansible_collections']):
            if not collection_loader:
                return self  # restrict access to collections when we are not testing a collection

            if not self.restrict_to_module_paths:
                return None  # for non-modules, everything in the ansible namespace is allowed

            if is_name_in_namepace(fullname, ['ansible_collections...plugins.module_utils', self.name]):
                return None  # module_utils and module under test are always allowed

            if collection_loader.find_module(fullname, path):
                return self  # restrict access to collection files that exist

            return None  # collection file does not exist, do not restrict access

        # not a namespace we care about
        return None

    def create_module(self, spec):  # pylint: disable=unused-argument
        # type: (RestrictedModuleLoader, ModuleSpec) -> None
        """Return None to use default module creation."""
        return None

    def exec_module(self, module):
        # type: (RestrictedModuleLoader, types.ModuleType) -> None | ImportError
        """Execute the module if the name is ansible.module_utils.basic and otherwise raise an ImportError"""
        fullname = module.__spec__.name

        if fullname != 'ansible.module_utils.basic':
            raise ImportError('import of "%s" is not allowed in this context' % fullname)

        # mark the name as being loaded so imports triggered while executing it are not intercepted again
        self.loaded_modules.add(fullname)

        for path in convert_ansible_name_to_absolute_paths(fullname):
            if not os.path.exists(path):
                continue  # only a candidate which exists on disk can be loaded

            loader = SourceFileLoader(fullname, path)
            spec = spec_from_loader(fullname, loader)
            real_module = module_from_spec(spec)
            # run the real source first so the replacements below are not overwritten by its definitions
            loader.exec_module(real_module)
            # stop Ansible module execution during AnsibleModule instantiation
            real_module.AnsibleModule = ImporterAnsibleModule  # type: ignore[attr-defined]
            # no-op for _load_params since it may be called before instantiating AnsibleModule
            real_module._load_params = lambda *args, **kwargs: {}  # type: ignore[attr-defined]  # pylint: disable=protected-access
            sys.modules[fullname] = real_module

            return None

        raise ImportError('could not find "%s"' % fullname)

    def load_module(self, fullname):
        # type: (RestrictedModuleLoader, str) -> types.ModuleType | ImportError
        """Return the module if the name is ansible.module_utils.basic and otherwise raise an ImportError."""
        if fullname != 'ansible.module_utils.basic':
            raise ImportError('import of "%s" is not allowed in this context' % fullname)

        module = self.__load_module(fullname)

        # stop Ansible module execution during AnsibleModule instantiation
        module.AnsibleModule = ImporterAnsibleModule  # type: ignore[attr-defined]
        # no-op for _load_params since it may be called before instantiating AnsibleModule
        module._load_params = lambda *args, **kwargs: {}  # type: ignore[attr-defined]  # pylint: disable=protected-access

        return module

    def __load_module(self, fullname):
        # type: (RestrictedModuleLoader, str) -> types.ModuleType
        """Load the requested module while avoiding infinite recursion."""
        # once recorded, _get_loader ignores the name, so the nested import goes to the regular finders
        self.loaded_modules.add(fullname)

        return import_module(fullname)
def run(restrict_to_module_paths):
    """Main program function.

    Tests every path given on the command line and exits with a non-zero status when any
    message was reported.

    :type restrict_to_module_paths: bool
    """
    base_dir = os.getcwd()
    messages = set()

    # NOTE(review): the fallback operand of `or` was truncated; reading newline separated paths
    # from stdin when no arguments are given is the assumed intent - confirm.
    for path in sys.argv[1:] or sys.stdin.read().splitlines():
        name = convert_relative_path_to_name(path)
        test_python_module(path, name, base_dir, messages, restrict_to_module_paths)

    if messages:
        # NOTE(review): the body of this branch was missing; exit status 10 is assumed - confirm.
        sys.exit(10)
def test_python_module(path, name, base_dir, messages, restrict_to_module_paths):
    """Test the given python module by importing it.

    The module is imported once and, when it is an Ansible module, additionally run as
    ``__main__``. Problems are recorded through report_message instead of being raised.

    :type path: str
    :type name: str
    :type base_dir: str
    :type messages: set[str]
    :type restrict_to_module_paths: bool
    """
    if name in sys.modules:
        return  # cannot be tested because it has already been loaded

    is_ansible_module = (path.startswith('lib/ansible/modules/') or path.startswith('plugins/modules/')) and os.path.basename(path) != '__init__.py'
    run_main = is_ansible_module

    if path == 'lib/ansible/modules/async_wrapper.py':
        # async_wrapper is a non-standard Ansible module (does not use AnsibleModule) so we cannot test the main function
        run_main = False

    capture_normal = Capture()
    capture_main = Capture()

    run_module_ok = False

    try:
        with monitor_sys_modules(path, messages):
            with restrict_imports(path, name, messages, restrict_to_module_paths):
                with capture_output(capture_normal):
                    import_module(name)

        if run_main:
            # from here on, instantiating AnsibleModule is the expected outcome for an Ansible module
            run_module_ok = is_ansible_module

            with monitor_sys_modules(path, messages):
                with restrict_imports(path, name, messages, restrict_to_module_paths):
                    with capture_output(capture_main):
                        runpy.run_module(name, run_name='__main__', alter_sys=True)
    except ImporterAnsibleModuleException:
        # module instantiated AnsibleModule without raising an exception
        if not run_module_ok:
            if is_ansible_module:
                report_message(path, 0, 0, 'module-guard', "AnsibleModule instantiation not guarded by `if __name__ == '__main__'`", messages)
            else:
                report_message(path, 0, 0, 'non-module', "AnsibleModule instantiated by import of non-module", messages)
    except BaseException as ex:  # pylint: disable=locally-disabled, broad-except
        # intentionally catch all exceptions, including calls to sys.exit
        exc_type, _exc, exc_tb = sys.exc_info()
        message = str(ex)
        # innermost frame first
        results = list(reversed(traceback.extract_tb(exc_tb)))
        line = 0
        offset = 0
        full_path = os.path.join(base_dir, path)
        base_path = base_dir + os.path.sep
        source = None

        # avoid line wraps in messages
        message = re.sub(r'\n *', ': ', message)

        for result in results:
            if result[0] == full_path:
                # save the line number for the file under test
                line = result[1] or 0

            if not source and result[0].startswith(base_path) and not result[0].startswith(temp_path):
                # save the first path and line number in the traceback which is in our source tree
                source = (os.path.relpath(result[0], base_path), result[1] or 0, 0)

        if isinstance(ex, SyntaxError):
            # SyntaxError has better information than the traceback
            if ex.filename == full_path:  # pylint: disable=locally-disabled, no-member
                # syntax error was reported in the file under test
                line = ex.lineno or 0  # pylint: disable=locally-disabled, no-member
                offset = ex.offset or 0  # pylint: disable=locally-disabled, no-member
            elif ex.filename.startswith(base_path) and not ex.filename.startswith(temp_path):  # pylint: disable=locally-disabled, no-member
                # syntax error was reported in our source tree
                source = (os.path.relpath(ex.filename, base_path), ex.lineno or 0, ex.offset or 0)  # pylint: disable=locally-disabled, no-member

            # remove the filename and line number from the message
            # either it was extracted above, or it's not really useful information
            message = re.sub(r' \(.*?, line [0-9]+\)$', '', message)

        if source and source[0] != path:
            message += ' (at %s:%d:%d)' % (source[0], source[1], source[2])

        report_message(path, line, offset, 'traceback', '%s: %s' % (exc_type.__name__, message), messages)
    finally:
        # captured output is reported whether or not the import succeeded
        capture_report(path, capture_normal, messages)
        capture_report(path, capture_main, messages)
def is_name_in_namepace(name, namespaces):
    """Returns True if the given name is one of the given namespaces, otherwise returns False.

    Name and namespace are compared part by part over the length of the shorter of the two,
    so a name matches when it is inside the namespace or is a parent of it.
    Empty parts in a namespace are wildcards which match any part of the name.
    """
    name_segments = name.split('.')

    def matches(namespace):
        # zip stops at the shorter sequence, which limits the comparison to the common length
        # example: name=ansible, namespace=ansible.module_utils
        # example: name=ansible.module_utils.basic, namespace=ansible.module_utils
        return all(
            not namespace_segment or namespace_segment == name_segment
            for name_segment, namespace_segment in zip(name_segments, namespace.split('.'))
        )

    return any(matches(namespace) for namespace in namespaces)
def check_sys_modules(path, before, messages):
    """Check for unwanted changes to sys.modules.

    Reports an 'unload' message for every name in `before` which is no longer in sys.modules
    and a 'reload' message for every name whose module object differs from the one in `before`.

    :type path: str
    :type before: dict[str, module]
    :type messages: set[str]
    """
    after = sys.modules

    removed = set(before) - set(after)
    changed = {key for key, value in before.items() if key in after and value != after[key]}

    # additions are checked by our custom PEP 302 loader, so we don't need to check them again here

    # sorted so messages are reported in a stable order
    for module in sorted(removed):
        report_message(path, 0, 0, 'unload', 'unloading of "%s" in sys.modules is not supported' % module, messages)

    for module in sorted(changed):
        report_message(path, 0, 0, 'reload', 'reloading of "%s" in sys.modules is not supported' % module, messages)
def convert_ansible_name_to_absolute_paths(name):
    """Calculate the module path from the given name.

    Two candidates are returned: the dotted name mapped to a path under ansible_path
    (the form a package directory would take) and the same path with a '.py' suffix.

    :type name: str
    :rtype: list[str]
    """
    base_path = os.path.join(ansible_path, name.replace('.', os.path.sep))

    return [
        base_path,
        base_path + '.py',
    ]
def convert_relative_path_to_name(path):
    """Calculate the module name from the given path.

    :type path: str
    :rtype: str
    """
    # NOTE(review): the suffix literal was truncated to '/'; '/__init__.py' is assumed, so that a
    # package initializer maps to the name of its package - confirm.
    if path.endswith('/__init__.py'):
        clean_path = os.path.dirname(path)
    else:
        clean_path = path

    clean_path = os.path.splitext(clean_path)[0]

    name = clean_path.replace(os.path.sep, '.')

    if collection_loader:
        # when testing collections the relative paths (and names) being tested are within the collection under test
        name = 'ansible_collections.%s.%s' % (collection_full_name, name)
    else:
        # when testing ansible all files being imported reside under the lib directory
        # the separator has already become a dot, but 'lib.' has the same length as 'lib/'
        name = name[len('lib/'):]

    return name
class Capture:
    """Container for the text streams which stand in for stdout and stderr while output is captured."""

    def __init__(self):
        # a TextIOWrapper over BytesIO is used in place of StringIO so the streams expose a buffer;
        # this allows Ansible's stream patching to behave without warnings
        stdout_buffer = BytesIO()
        stderr_buffer = BytesIO()

        self.stdout = TextIOWrapper(stdout_buffer)
        self.stderr = TextIOWrapper(stderr_buffer)
def capture_report(path, capture, messages):
    """Report on captured output.

    For each of stdout and stderr which received any data, the first line of that data is
    reported under the code 'stdout' or 'stderr'.

    :type path: str
    :type capture: Capture
    :type messages: set[str]
    """
    for code, stream in (('stdout', capture.stdout), ('stderr', capture.stderr)):
        # since we're using buffered IO, flush before checking for data
        stream.flush()

        value = stream.buffer.getvalue()

        if not value:
            continue

        lines = value.decode().strip().splitlines()
        # output consisting only of whitespace leaves no lines; report it with an empty summary instead of failing
        first = lines[0].strip() if lines else ''

        report_message(path, 0, 0, code, first, messages)
def report_message(path, line, column, code, message, messages):
    """Report message if not already reported.

    The formatted message is added to `messages` and printed the first time it is seen.

    :type path: str
    :type line: int
    :type column: int
    :type code: str
    :type message: str
    :type messages: set[str]
    """
    message = '%s:%d:%d: %s: %s' % (path, line, column, code, message)

    if message in messages:
        return

    messages.add(message)
    # NOTE(review): the body of the original branch was missing; printing the message is the assumed way of reporting - confirm.
    print(message)
@contextlib.contextmanager
def restrict_imports(path, name, messages, restrict_to_module_paths):
    """Restrict available imports.

    Context manager which puts a RestrictedModuleLoader at the front of sys.meta_path for the
    duration of the with block and removes it again afterwards.

    :type path: str
    :type name: str
    :type messages: set[str]
    :type restrict_to_module_paths: bool
    """
    restricted_loader = RestrictedModuleLoader(path, name, restrict_to_module_paths)

    # noinspection PyTypeChecker
    sys.meta_path.insert(0, restricted_loader)
    # NOTE(review): clearing the path importer cache here and on exit is an assumption, not visible in the damaged source - confirm.
    sys.path_importer_cache.clear()

    try:
        yield
    finally:
        if import_type == 'plugin' and not collection_loader:
            from ansible.utils.collection_loader._collection_finder import _AnsibleCollectionFinder
            _AnsibleCollectionFinder._remove()  # pylint: disable=protected-access

        # the code under test must leave our loader in first position
        if sys.meta_path[0] != restricted_loader:
            report_message(path, 0, 0, 'metapath', 'changes to sys.meta_path[0] are not permitted', messages)

        # remove every occurrence, in case the loader was inserted more than once
        while restricted_loader in sys.meta_path:
            # noinspection PyTypeChecker
            sys.meta_path.remove(restricted_loader)

        sys.path_importer_cache.clear()
@contextlib.contextmanager
def monitor_sys_modules(path, messages):
    """Monitor sys.modules for unwanted changes, reverting any additions made to our own namespaces.

    :type path: str
    :type messages: set[str]
    """
    snapshot = sys.modules.copy()

    try:
        yield
    finally:
        # removals and replacements are reported, they are not reverted
        check_sys_modules(path, snapshot, messages)

        for key in set(sys.modules.keys()) - set(snapshot.keys()):
            if is_name_in_namepace(key, ('ansible', 'ansible_collections')):
                del sys.modules[key]  # only unload our own code since we know it's native Python
@contextlib.contextmanager
def capture_output(capture):
    """Capture sys.stdout and sys.stderr.

    Context manager which points sys.stdout and sys.stderr at the streams of `capture` and
    restores the previous streams on exit.

    :type capture: Capture
    """
    old_stdout = sys.stdout
    old_stderr = sys.stderr

    sys.stdout = capture.stdout
    sys.stderr = capture.stderr

    # clear all warnings registries to make all warnings available
    for module in sys.modules.values():
        try:
            # noinspection PyUnresolvedReferences
            module.__warningregistry__.clear()
        except AttributeError:
            pass  # the module has no warnings registry

    with warnings.catch_warnings():
        # NOTE(review): escalating warnings to errors is an assumption, not visible in the damaged source;
        # the targeted "ignore" filter below suggests it - confirm.
        warnings.simplefilter('error')

        if collection_loader and import_type == 'plugin':
            warnings.filterwarnings(
                "ignore",
                "AnsibleCollectionFinder has already been configured")

        try:
            yield
        finally:
            sys.stdout = old_stdout
            sys.stderr = old_stderr
# start the import test; imports are restricted to module paths only when the importer type is 'module'
run(import_type == 'module')
# NOTE(review): the body guarded by the check below (presumably a call to main()) is not visible here - confirm.
if __name__ == '__main__':