Remove uses of assert in production code (#32079)

* Remove uses of assert in production code

* Fix assertion

* Add code smell test for assertions, currently limited to lib/ansible

* Fix assertion

* Add docs for no-assert

* Remove new assert from enos

* Fix assert in module_utils.connection
pull/14049/merge
Matt Martz 7 years ago committed by ansibot
parent 464ded80f5
commit 99d4f5bab4

@ -0,0 +1,16 @@
Sanity Tests » no-assert
========================
Do not use ``assert`` in production Ansible python code. When running Python
with optimizations, Python will remove ``assert`` statements, potentially
allowing for unexpected behavior throughout the Ansible code base.
Instead of using ``assert`` you should utilize simple ``if`` statements,
that result in raising an exception. There is a new exception called
``AnsibleAssertionError`` that inherits from ``AnsibleError`` and
``AssertionError``. When possible, utilize a more specific exception
than ``AnsibleAssertionError``.
Modules will not have access to ``AnsibleAssertionError`` and should instead
raise ``AssertionError``, a more specific exception, or just use
``module.fail_json`` at the failure point.

@ -172,6 +172,11 @@ class AnsibleError(Exception):
return error_message
class AnsibleAssertionError(AnsibleError, AssertionError):
'''Invalid assertion'''
pass
class AnsibleOptionsError(AnsibleError):
''' bad or incomplete options passed '''
pass

@ -98,7 +98,8 @@ def get_connection(module):
def to_commands(module, commands):
assert isinstance(commands, list), 'argument must be of type <list>'
if not isinstance(commands, list):
raise AssertionError('argument must be of type <list>')
transform = EntityCollection(module, command_spec)
commands = transform(commands)

@ -2248,7 +2248,8 @@ class AnsibleModule(object):
def fail_json(self, **kwargs):
''' return from the module, with an error message '''
assert 'msg' in kwargs, "implementation error -- msg to explain the error is required"
if 'msg' not in kwargs:
raise AssertionError("implementation error -- msg to explain the error is required")
kwargs['failed'] = True
# add traceback if debug or high verbosity and it is missing

@ -95,7 +95,8 @@ class ConnectionError(Exception):
class Connection:
def __init__(self, socket_path):
assert socket_path is not None, 'socket_path must be a value'
if socket_path is None:
raise AssertionError('socket_path must be a value')
self.socket_path = socket_path
def __getattr__(self, name):

@ -115,7 +115,8 @@ def get_config(module, flags=None):
def to_commands(module, commands):
assert isinstance(commands, list), 'argument must be of type <list>'
if not isinstance(commands, list):
raise AssertionError('argument must be of type <list>')
transform = EntityCollection(module, command_spec)
commands = transform(commands)

@ -67,7 +67,8 @@ def get_connection(module):
def to_commands(module, commands):
assert isinstance(commands, list), 'argument must be of type <list>'
if not isinstance(commands, list):
raise AssertionError('argument must be of type <list>')
transform = EntityCollection(module, command_spec)
commands = transform(commands)

@ -97,7 +97,8 @@ class ConfigLine(object):
return len(self._parents) > 0
def add_child(self, obj):
assert isinstance(obj, ConfigLine), 'child must be of type `ConfigLine`'
if not isinstance(obj, ConfigLine):
raise AssertionError('child must be of type `ConfigLine`')
self._children.append(obj)
@ -263,7 +264,8 @@ class NetworkConfig(object):
return item
def get_block(self, path):
assert isinstance(path, list), 'path argument must be a list object'
if not isinstance(path, list):
raise AssertionError('path argument must be a list object')
obj = self.get_object(path)
if not obj:
raise ValueError('path does not exist in config')

@ -222,8 +222,10 @@ def dict_diff(base, comparable):
:returns: new dict object with differences
"""
assert isinstance(base, dict), "`base` must be of type <dict>"
assert isinstance(comparable, dict), "`comparable` must be of type <dict>"
if not isinstance(base, dict):
raise AssertionError("`base` must be of type <dict>")
if not isinstance(comparable, dict):
raise AssertionError("`comparable` must be of type <dict>")
updates = dict()
@ -257,8 +259,10 @@ def dict_merge(base, other):
:returns: new combined dict object
"""
assert isinstance(base, dict), "`base` must be of type <dict>"
assert isinstance(other, dict), "`other` must be of type <dict>"
if not isinstance(base, dict):
raise AssertionError("`base` must be of type <dict>")
if not isinstance(other, dict):
raise AssertionError("`other` must be of type <dict>")
combined = dict()
@ -306,7 +310,8 @@ def conditional(expr, val, cast=None):
op, arg = match.groups()
else:
op = 'eq'
assert (' ' not in str(expr)), 'invalid expression: cannot contain spaces'
if ' ' in str(expr):
raise AssertionError('invalid expression: cannot contain spaces')
arg = expr
if cast is None and val is not None:

@ -273,7 +273,8 @@ def umc_module_for_edit(module, object_dn, superordinate=None):
def create_containers_and_parents(container_dn):
"""Create a container and if needed the parents containers"""
import univention.admin.uexceptions as uexcp
assert container_dn.startswith("cn=")
if not container_dn.startswith("cn="):
raise AssertionError()
try:
parent = ldap_dn_tree_parent(container_dn)
obj = umc_module_for_add(

@ -285,7 +285,8 @@ def check_dp_status(client, dp_id, status):
:returns: True or False
"""
assert isinstance(status, list)
if not isinstance(status, list):
raise AssertionError()
if pipeline_field(client, dp_id, field="@pipelineState") in status:
return True
else:

@ -380,7 +380,8 @@ class ClcGroup(object):
changed: Boolean- whether a change was made,
group: A clc group object for the group
"""
assert self.root_group, "Implementation Error: Root Group not set"
if not self.root_group:
raise AssertionError("Implementation Error: Root Group not set")
parent = parent_name if parent_name is not None else self.root_group.name
description = group_description
changed = False

@ -237,7 +237,8 @@ class Droplet(JsonfyMixIn):
self.update_attr(json)
def power_on(self):
assert self.status == 'off', 'Can only power on a closed one.'
if self.status != 'off':
raise AssertionError('Can only power on a closed one.')
json = self.manager.power_on_droplet(self.id)
self.update_attr(json)

@ -424,8 +424,10 @@ class PyVmomiDeviceHelper(object):
diskspec.device.backing.diskMode = 'persistent'
diskspec.device.controllerKey = scsi_ctl.device.key
assert self.next_disk_unit_number != 7
assert disk_index != 7
if self.next_disk_unit_number == 7:
raise AssertionError()
if disk_index == 7:
raise AssertionError()
"""
Configure disk unit number.
"""
@ -1127,7 +1129,8 @@ class PyVmomiHelper(PyVmomi):
return datastore, datastore_name
def obj_has_parent(self, obj, parent):
assert obj is not None and parent is not None
if obj is None and parent is None:
raise AssertionError()
current_parent = obj
while True:
@ -1573,7 +1576,7 @@ def main():
result["failed"] = False
else:
# This should not happen
assert False
raise AssertionError()
# VM doesn't exist
else:
if module.params['state'] in ['poweredon', 'poweredoff', 'present', 'restarted', 'suspended']:

@ -342,7 +342,7 @@ class PyVmomiHelper(object):
task = vm.RemoveAllSnapshots()
else:
# This should not happen
assert False
raise AssertionError()
if task:
self.wait_for_task(task)

@ -236,8 +236,9 @@ def set_acl(consul_client, configuration):
acls_as_json = decode_acls_as_json(consul_client.acl.list())
existing_acls_mapped_by_name = dict((acl.name, acl) for acl in acls_as_json if acl.name is not None)
existing_acls_mapped_by_token = dict((acl.token, acl) for acl in acls_as_json)
assert None not in existing_acls_mapped_by_token, "expecting ACL list to be associated to a token: %s" \
% existing_acls_mapped_by_token[None]
if None in existing_acls_mapped_by_token:
raise AssertionError("expecting ACL list to be associated to a token: %s" %
existing_acls_mapped_by_token[None])
if configuration.token is None and configuration.name and configuration.name in existing_acls_mapped_by_name:
# No token but name given so can get token from name
@ -246,8 +247,10 @@ def set_acl(consul_client, configuration):
if configuration.token and configuration.token in existing_acls_mapped_by_token:
return update_acl(consul_client, configuration)
else:
assert configuration.token not in existing_acls_mapped_by_token
assert configuration.name not in existing_acls_mapped_by_name
if configuration.token in existing_acls_mapped_by_token:
raise AssertionError()
if configuration.name in existing_acls_mapped_by_name:
raise AssertionError()
return create_acl(consul_client, configuration)
@ -266,7 +269,8 @@ def update_acl(consul_client, configuration):
rules_as_hcl = encode_rules_as_hcl_string(configuration.rules)
updated_token = consul_client.acl.update(
configuration.token, name=name, type=configuration.token_type, rules=rules_as_hcl)
assert updated_token == configuration.token
if updated_token != configuration.token:
raise AssertionError()
return Output(changed=changed, token=configuration.token, rules=configuration.rules, operation=UPDATE_OPERATION)
@ -379,12 +383,14 @@ def encode_rules_as_json(rules):
rules_as_json = defaultdict(dict)
for rule in rules:
if rule.pattern is not None:
assert rule.pattern not in rules_as_json[rule.scope]
if rule.pattern in rules_as_json[rule.scope]:
raise AssertionError()
rules_as_json[rule.scope][rule.pattern] = {
_POLICY_JSON_PROPERTY: rule.policy
}
else:
assert rule.scope not in rules_as_json
if rule.scope in rules_as_json:
raise AssertionError()
rules_as_json[rule.scope] = rule.policy
return rules_as_json
@ -577,7 +583,8 @@ def get_consul_client(configuration):
token = configuration.management_token
if token is None:
token = configuration.token
assert token is not None, "Expecting the management token to always be set"
if token is None:
raise AssertionError("Expecting the management token to always be set")
return consul.Consul(host=configuration.host, port=configuration.port, scheme=configuration.scheme,
verify=configuration.validate_certs, token=token)

@ -225,16 +225,14 @@ except:
# Optional, only used for XML payload
try:
import lxml.etree
assert lxml.etree # silence pyflakes
import lxml.etree # noqa
HAS_LXML_ETREE = True
except ImportError:
HAS_LXML_ETREE = False
# Optional, only used for XML payload
try:
from xmljson import cobra
assert cobra # silence pyflakes
from xmljson import cobra # noqa
HAS_XMLJSON_COBRA = True
except ImportError:
HAS_XMLJSON_COBRA = False

@ -249,9 +249,7 @@ class BalancerMember(object):
balancer_member_page = fetch_url(self.module, self.management_url)
try:
assert balancer_member_page[1]['status'] == 200
except AssertionError:
if balancer_member_page[1]['status'] != 200:
self.module.fail_json(msg="Could not get balancer_member_page, check for connectivity! " + balancer_member_page[1])
else:
try:
@ -296,9 +294,7 @@ class BalancerMember(object):
request_body = request_body + str(values_mapping[k]) + '=0'
response = fetch_url(self.module, self.management_url, data=str(request_body))
try:
assert response[1]['status'] == 200
except AssertionError:
if response[1]['status'] != 200:
self.module.fail_json(msg="Could not set the member status! " + self.host + " " + response[1]['status'])
attributes = property(get_member_attributes)
@ -323,9 +319,7 @@ class Balancer(object):
def fetch_balancer_page(self):
""" Returns the balancer management html page as a string for later parsing."""
page = fetch_url(self.module, str(self.url))
try:
assert page[1]['status'] == 200
except AssertionError:
if page[1]['status'] != 200:
self.module.fail_json(msg="Could not get balancer page! HTTP status response: " + str(page[1]['status']))
else:
content = page[0].read()
@ -343,9 +337,7 @@ class Balancer(object):
else:
for element in soup.findAll('a')[1::1]:
balancer_member_suffix = str(element.get('href'))
try:
assert balancer_member_suffix is not ''
except AssertionError:
if not balancer_member_suffix:
self.module.fail_json(msg="Argument 'balancer_member_suffix' is empty!")
else:
yield BalancerMember(str(self.base_url + balancer_member_suffix), str(self.url), self.module)

@ -19,7 +19,7 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from ansible.errors import AnsibleParserError, AnsibleError
from ansible.errors import AnsibleParserError, AnsibleError, AnsibleAssertionError
from ansible.module_utils.six import iteritems, string_types
from ansible.module_utils._text import to_text
from ansible.parsing.splitter import parse_kv, split_args
@ -98,7 +98,8 @@ class ModuleArgsParser:
def __init__(self, task_ds=None):
task_ds = {} if task_ds is None else task_ds
assert isinstance(task_ds, dict), "the type of 'task_ds' should be a dict, but is a %s" % type(task_ds)
if not isinstance(task_ds, dict):
raise AnsibleAssertionError("the type of 'task_ds' should be a dict, but is a %s" % type(task_ds))
self._task_ds = task_ds
def _split_module_string(self, module_string):

@ -72,7 +72,7 @@ try:
except ImportError:
pass
from ansible.errors import AnsibleError
from ansible.errors import AnsibleError, AnsibleAssertionError
from ansible import constants as C
from ansible.module_utils.six import PY3, binary_type
# Note: on py2, this zip is izip not the list based zip() builtin
@ -787,7 +787,10 @@ class VaultEditor:
fh.write(data)
fh.write(data[:file_len % chunk_len])
assert fh.tell() == file_len # FIXME remove this assert once we have unittests to check its accuracy
# FIXME remove this assert once we have unittests to check its accuracy
if fh.tell() != file_len:
raise AnsibleAssertionError()
os.fsync(fh)
def _shred_file(self, tmp_path):

@ -16,7 +16,7 @@ from jinja2.exceptions import UndefinedError
from ansible import constants as C
from ansible.module_utils.six import iteritems, string_types, with_metaclass
from ansible.module_utils.parsing.convert_bool import boolean
from ansible.errors import AnsibleParserError, AnsibleUndefinedVariable
from ansible.errors import AnsibleParserError, AnsibleUndefinedVariable, AnsibleAssertionError
from ansible.module_utils._text import to_text, to_native
from ansible.playbook.attribute import Attribute, FieldAttribute
from ansible.parsing.dataloader import DataLoader
@ -209,7 +209,8 @@ class Base(with_metaclass(BaseMeta, object)):
def load_data(self, ds, variable_manager=None, loader=None):
''' walk the input datastructure and assign any values '''
assert ds is not None, 'ds (%s) should not be None but it is.' % ds
if ds is None:
raise AnsibleAssertionError('ds (%s) should not be None but it is.' % ds)
# cache the datastructure internally
setattr(self, '_ds', ds)
@ -547,7 +548,8 @@ class Base(with_metaclass(BaseMeta, object)):
and extended.
'''
assert isinstance(data, dict), 'data (%s) should be a dict but is a %s' % (data, type(data))
if not isinstance(data, dict):
raise AnsibleAssertionError('data (%s) should be a dict but is a %s' % (data, type(data)))
for (name, attribute) in iteritems(self._valid_attrs):
if name in data:

@ -21,7 +21,7 @@ __metaclass__ = type
import os
from ansible import constants as C
from ansible.errors import AnsibleParserError, AnsibleUndefinedVariable, AnsibleFileNotFound
from ansible.errors import AnsibleParserError, AnsibleUndefinedVariable, AnsibleFileNotFound, AnsibleAssertionError
from ansible.module_utils.six import string_types
try:
@ -43,7 +43,8 @@ def load_list_of_blocks(ds, play, parent_block=None, role=None, task_include=Non
from ansible.playbook.task_include import TaskInclude
from ansible.playbook.role_include import IncludeRole
assert isinstance(ds, (list, type(None))), '%s should be a list or None but is %s' % (ds, type(ds))
if not isinstance(ds, (list, type(None))):
raise AnsibleAssertionError('%s should be a list or None but is %s' % (ds, type(ds)))
block_list = []
if ds:
@ -89,11 +90,13 @@ def load_list_of_tasks(ds, play, block=None, role=None, task_include=None, use_h
from ansible.playbook.handler_task_include import HandlerTaskInclude
from ansible.template import Templar
assert isinstance(ds, list), 'The ds (%s) should be a list but was a %s' % (ds, type(ds))
if not isinstance(ds, list):
raise AnsibleAssertionError('The ds (%s) should be a list but was a %s' % (ds, type(ds)))
task_list = []
for task_ds in ds:
assert isinstance(task_ds, dict), 'The ds (%s) should be a dict but was a %s' % (ds, type(ds))
if not isinstance(task_ds, dict):
AnsibleAssertionError('The ds (%s) should be a dict but was a %s' % (ds, type(ds)))
if 'block' in task_ds:
t = Block.load(
@ -345,7 +348,8 @@ def load_list_of_roles(ds, play, current_role_path=None, variable_manager=None,
# we import here to prevent a circular dependency with imports
from ansible.playbook.role.include import RoleInclude
assert isinstance(ds, list), 'ds (%s) should be a list but was a %s' % (ds, type(ds))
if not isinstance(ds, list):
raise AnsibleAssertionError('ds (%s) should be a list but was a %s' % (ds, type(ds)))
roles = []
for role_def in ds:

@ -20,7 +20,7 @@ from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from ansible import constants as C
from ansible.errors import AnsibleParserError
from ansible.errors import AnsibleParserError, AnsibleAssertionError
from ansible.module_utils.six import string_types
from ansible.playbook.attribute import FieldAttribute
from ansible.playbook.base import Base
@ -116,7 +116,8 @@ class Play(Base, Taggable, Become):
Adjusts play datastructure to cleanup old/legacy items
'''
assert isinstance(ds, dict), 'while preprocessing data (%s), ds should be a dict but was a %s' % (ds, type(ds))
if not isinstance(ds, dict):
raise AnsibleAssertionError('while preprocessing data (%s), ds should be a dict but was a %s' % (ds, type(ds)))
# The use of 'user' in the Play datastructure was deprecated to
# line up with the same change for Tasks, due to the fact that

@ -21,7 +21,7 @@ __metaclass__ = type
import os
from ansible.errors import AnsibleParserError, AnsibleError
from ansible.errors import AnsibleParserError, AnsibleError, AnsibleAssertionError
from ansible.module_utils.six import iteritems
from ansible.parsing.splitter import split_args, parse_kv
from ansible.parsing.yaml.objects import AnsibleBaseYAMLObject, AnsibleMapping
@ -105,7 +105,8 @@ class PlaybookInclude(Base, Conditional, Taggable):
up with what we expect the proper attributes to be
'''
assert isinstance(ds, dict), 'ds (%s) should be a dict but was a %s' % (ds, type(ds))
if not isinstance(ds, dict):
raise AnsibleAssertionError('ds (%s) should be a dict but was a %s' % (ds, type(ds)))
# the new, cleaned datastructure, which will have legacy
# items reduced to a standard structure

@ -22,7 +22,7 @@ __metaclass__ = type
import collections
import os
from ansible.errors import AnsibleError, AnsibleParserError
from ansible.errors import AnsibleError, AnsibleParserError, AnsibleAssertionError
from ansible.module_utils.six import iteritems, binary_type, text_type
from ansible.playbook.attribute import FieldAttribute
from ansible.playbook.base import Base
@ -293,7 +293,8 @@ class Role(Base, Become, Conditional, Taggable):
def add_parent(self, parent_role):
''' adds a role to the list of this roles parents '''
assert isinstance(parent_role, Role)
if not isinstance(parent_role, Role):
raise AnsibleAssertionError()
if parent_role not in self._parents:
self._parents.append(parent_role)

@ -22,7 +22,7 @@ __metaclass__ = type
import os
from ansible import constants as C
from ansible.errors import AnsibleError
from ansible.errors import AnsibleError, AnsibleAssertionError
from ansible.module_utils.six import iteritems, string_types
from ansible.parsing.yaml.objects import AnsibleBaseYAMLObject, AnsibleMapping
from ansible.playbook.attribute import Attribute, FieldAttribute
@ -72,7 +72,8 @@ class RoleDefinition(Base, Become, Conditional, Taggable):
if isinstance(ds, int):
ds = "%s" % ds
assert isinstance(ds, dict) or isinstance(ds, string_types) or isinstance(ds, AnsibleBaseYAMLObject)
if not isinstance(ds, dict) and not isinstance(ds, string_types) and not isinstance(ds, AnsibleBaseYAMLObject):
raise AnsibleAssertionError()
if isinstance(ds, dict):
ds = super(RoleDefinition, self).preprocess_data(ds)

@ -22,7 +22,7 @@ __metaclass__ = type
import os
from ansible import constants as C
from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable
from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleAssertionError
from ansible.module_utils.six import iteritems, string_types
from ansible.module_utils._text import to_native
from ansible.parsing.mod_args import ModuleArgsParser
@ -167,7 +167,8 @@ class Task(Base, Conditional, Taggable, Become):
keep it short.
'''
assert isinstance(ds, dict), 'ds (%s) should be a dict but was a %s' % (ds, type(ds))
if not isinstance(ds, dict):
raise AnsibleAssertionError('ds (%s) should be a dict but was a %s' % (ds, type(ds)))
# the new, cleaned datastructure, which will have legacy
# items reduced to a standard structure suitable for the

@ -64,7 +64,7 @@ RETURN = """
import os
import sys
from ansible.module_utils.six.moves.urllib.parse import urlparse
from ansible.errors import AnsibleError
from ansible.errors import AnsibleError, AnsibleAssertionError
from ansible.plugins.lookup import LookupBase
try:
@ -131,7 +131,8 @@ class LookupModule(LookupBase):
for param in params[1:]:
if param and len(param) > 0:
name, value = param.split('=')
assert name in paramvals, "%s not a valid consul lookup parameter" % name
if name not in paramvals:
raise AnsibleAssertionError("%s not a valid consul lookup parameter" % name)
paramvals[name] = value
except (ValueError, AssertionError) as e:
raise AnsibleError(e)

@ -51,7 +51,7 @@ import codecs
import csv
from collections import MutableSequence
from ansible.errors import AnsibleError
from ansible.errors import AnsibleError, AnsibleAssertionError
from ansible.plugins.lookup import LookupBase
from ansible.module_utils._text import to_bytes, to_native, to_text
@ -124,7 +124,8 @@ class LookupModule(LookupBase):
try:
for param in params[1:]:
name, value = param.split('=')
assert(name in paramvals)
if name not in paramvals:
raise AnsibleAssertionError('%s not in paramvals' % name)
paramvals[name] = value
except (ValueError, AssertionError) as e:
raise AnsibleError(e)

@ -65,7 +65,7 @@ import re
from collections import MutableSequence
from io import StringIO
from ansible.errors import AnsibleError
from ansible.errors import AnsibleError, AnsibleAssertionError
from ansible.module_utils.six.moves import configparser
from ansible.module_utils._text import to_bytes, to_text
from ansible.plugins.lookup import LookupBase
@ -129,7 +129,8 @@ class LookupModule(LookupBase):
try:
for param in params[1:]:
name, value = param.split('=')
assert(name in paramvals)
if name not in paramvals:
raise AnsibleAssertionError('%s not in paramvals' % name)
paramvals[name] = value
except (ValueError, AssertionError) as e:
raise AnsibleError(e)

@ -92,7 +92,7 @@ _raw:
import os
import string
from ansible.errors import AnsibleError
from ansible.errors import AnsibleError, AnsibleAssertionError
from ansible.module_utils._text import to_bytes, to_native, to_text
from ansible.parsing.splitter import parse_kv
from ansible.plugins.lookup import LookupBase
@ -250,7 +250,8 @@ def _format_content(password, salt, encrypt=True):
return password
# At this point, the calling code should have assured us that there is a salt value.
assert salt, '_format_content was called with encryption requested but no salt value'
if not salt:
raise AnsibleAssertionError('_format_content was called with encryption requested but no salt value')
return u'%s salt=%s' % (password, salt)

@ -82,7 +82,7 @@ import subprocess
import time
from distutils import util
from ansible.errors import AnsibleError
from ansible.errors import AnsibleError, AnsibleAssertionError
from ansible.module_utils._text import to_bytes, to_native, to_text
from ansible.utils.encrypt import random_password
from ansible.plugins.lookup import LookupBase
@ -138,7 +138,8 @@ class LookupModule(LookupBase):
try:
for param in params[1:]:
name, value = param.split('=')
assert(name in self.paramvals)
if name not in self.paramvals:
raise AnsibleAssertionError('%s not in paramvals' % name)
self.paramvals[name] = value
except (ValueError, AssertionError) as e:
raise AnsibleError(e)

@ -33,7 +33,7 @@ _list:
"""
import shelve
from ansible.errors import AnsibleError
from ansible.errors import AnsibleError, AnsibleAssertionError
from ansible.plugins.lookup import LookupBase
from ansible.module_utils._text import to_bytes, to_text
@ -63,7 +63,8 @@ class LookupModule(LookupBase):
try:
for param in params:
name, value = param.split('=')
assert(name in paramvals)
if name not in paramvals:
raise AnsibleAssertionError('%s not in paramvals' % name)
paramvals[name] = value
except (ValueError, AssertionError) as e:

@ -42,7 +42,7 @@ from jinja2.runtime import Context, StrictUndefined
from jinja2.utils import concat as j2_concat
from ansible import constants as C
from ansible.errors import AnsibleError, AnsibleFilterError, AnsibleUndefinedVariable
from ansible.errors import AnsibleError, AnsibleFilterError, AnsibleUndefinedVariable, AnsibleAssertionError
from ansible.module_utils.six import string_types, text_type
from ansible.module_utils._text import to_native, to_text, to_bytes
from ansible.plugins.loader import filter_loader, lookup_loader, test_loader
@ -387,7 +387,8 @@ class Templar:
are being changed.
'''
assert isinstance(variables, dict), "the type of 'variables' should be a dict but was a %s" % (type(variables))
if not isinstance(variables, dict):
raise AnsibleAssertionError("the type of 'variables' should be a dict but was a %s" % (type(variables)))
self._available_variables = variables
self._cached_result = {}

@ -8,7 +8,7 @@ import multiprocessing
import random
from ansible import constants as C
from ansible.errors import AnsibleError
from ansible.errors import AnsibleError, AnsibleAssertionError
from ansible.module_utils.six import text_type
from ansible.module_utils._text import to_text, to_bytes
@ -67,7 +67,8 @@ def random_password(length=DEFAULT_PASSWORD_LENGTH, chars=C.DEFAULT_PASSWORD_CHA
:kwarg chars: The characters to choose from. The default is all ascii
letters, ascii digits, and these symbols ``.,:-_``
'''
assert isinstance(chars, text_type), '%s (%s) is not a text_type' % (chars, type(chars))
if not isinstance(chars, text_type):
raise AnsibleAssertionError('%s (%s) is not a text_type' % (chars, type(chars)))
random_generator = random.SystemRandom()
return u''.join(random_generator.choice(chars) for dummy in range(length))

@ -22,6 +22,7 @@ __metaclass__ = type
from collections import MutableMapping, MutableSet, MutableSequence
from ansible.errors import AnsibleAssertionError
from ansible.module_utils.six import string_types
from ansible.parsing.plugin_docs import read_docstring
from ansible.parsing.yaml.loader import AnsibleLoader
@ -59,7 +60,8 @@ def add_fragments(doc, filename):
fragment_name, fragment_var = fragment_slug, 'DOCUMENTATION'
fragment_class = fragment_loader.get(fragment_name)
assert fragment_class is not None
if fragment_class is None:
raise AnsibleAssertionError('fragment_class is None')
fragment_yaml = getattr(fragment_class, fragment_var, '{}')
fragment = AnsibleLoader(fragment_yaml, file_name=filename).get_single_data()

@ -32,7 +32,7 @@ except ImportError:
from jinja2.exceptions import UndefinedError
from ansible import constants as C
from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleFileNotFound
from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleFileNotFound, AnsibleAssertionError
from ansible.inventory.host import Host
from ansible.inventory.helpers import sort_groups, get_group_vars
from ansible.module_utils._text import to_native
@ -132,7 +132,8 @@ class VariableManager:
@extra_vars.setter
def extra_vars(self, value):
''' ensures a clean copy of the extra_vars are used to set the value '''
assert isinstance(value, MutableMapping), "the type of 'value' for extra_vars should be a MutableMapping, but is a %s" % type(value)
if not isinstance(value, MutableMapping):
raise AnsibleAssertionError("the type of 'value' for extra_vars should be a MutableMapping, but is a %s" % type(value))
self._extra_vars = value.copy()
def set_inventory(self, inventory):
@ -146,7 +147,8 @@ class VariableManager:
@options_vars.setter
def options_vars(self, value):
''' ensures a clean copy of the options_vars are used to set the value '''
assert isinstance(value, dict), "the type of 'value' for options_vars should be a dict, but is a %s" % type(value)
if not isinstance(value, dict):
raise AnsibleAssertionError("the type of 'value' for options_vars should be a dict, but is a %s" % type(value))
self._options_vars = value.copy()
def _preprocess_vars(self, a):
@ -592,7 +594,8 @@ class VariableManager:
Sets or updates the given facts for a host in the fact cache.
'''
assert isinstance(facts, dict), "the type of 'facts' to set for host_facts should be a dict but is a %s" % type(facts)
if not isinstance(facts, dict):
raise AnsibleAssertionError("the type of 'facts' to set for host_facts should be a dict but is a %s" % type(facts))
if host.name not in self._fact_cache:
self._fact_cache[host.name] = facts
@ -607,7 +610,8 @@ class VariableManager:
Sets or updates the given facts for a host in the fact cache.
'''
assert isinstance(facts, dict), "the type of 'facts' to set for nonpersistent_facts should be a dict but is a %s" % type(facts)
if not isinstance(facts, dict):
raise AnsibleAssertionError("the type of 'facts' to set for nonpersistent_facts should be a dict but is a %s" % type(facts))
if host.name not in self._nonpersistent_fact_cache:
self._nonpersistent_fact_cache[host.name] = facts

@ -0,0 +1,40 @@
#!/usr/bin/env python
from __future__ import print_function
import os
import re
import sys
from collections import defaultdict
PATH = 'lib/ansible'
ASSERT_RE = re.compile(r'.*(?<![-:a-zA-Z#][ -])\bassert\b(?!:).*')
all_matches = defaultdict(list)
for dirpath, dirnames, filenames in os.walk(PATH):
for filename in filenames:
path = os.path.join(dirpath, filename)
if not os.path.isfile(path) or not path.endswith('.py'):
continue
with open(path, 'r') as f:
for i, line in enumerate(f.readlines()):
matches = ASSERT_RE.findall(line)
if matches:
all_matches[path].append((i + 1, line.index('assert') + 1, matches))
if all_matches:
print('Use of assert in production code is not recommended.')
print('Python will remove all assert statements if run with optimizations')
print('Alternatives:')
print(' if not isinstance(value, dict):')
print(' raise AssertionError("Expected a dict for value")')
for path, matches in all_matches.items():
for line_matches in matches:
for match in line_matches[2]:
print('%s:%d:%d: %s' % ((path,) + line_matches[:2] + (match,)))
sys.exit(1)
Loading…
Cancel
Save