Merge pull request #2187 from dhozac/argument-type-check

Add type checking for module arguments, converting as much as possible
pull/2207/merge
Daniel Hokka Zakrisson 12 years ago
commit 93f02d614b

@ -182,6 +182,7 @@ class AnsibleModule(object):
if not bypass_checks:
self._check_required_arguments()
self._check_argument_values()
self._check_argument_types()
self._check_mutually_exclusive(mutually_exclusive)
self._check_required_together(required_together)
@ -542,7 +543,7 @@ class AnsibleModule(object):
if len(missing) > 0:
self.fail_json(msg="missing required arguments: %s" % ",".join(missing))
def _check_argument_types(self):
def _check_argument_values(self):
''' ensure all arguments have the requested values, and there are no stray arguments '''
for (k,v) in self.argument_spec.iteritems():
choices = v.get('choices',None)
@ -557,6 +558,45 @@ class AnsibleModule(object):
else:
self.fail_json(msg="internal error: do not know how to interpret argument_spec")
def _check_argument_types(self):
''' ensure all arguments have the requested type '''
for (k, v) in self.argument_spec.iteritems():
wanted = v.get('type', None)
if wanted is None:
continue
if k not in self.params:
continue
value = self.params[k]
is_invalid = False
if wanted == 'str':
if not isinstance(value, basestring):
self.params[k] = str(value)
elif wanted == 'list':
if not isinstance(value, list):
if isinstance(value, basestring):
self.params[k] = value.split(",")
else:
is_invalid = True
elif wanted == 'dict':
if not isinstance(value, dict):
if isinstance(value, basestring):
self.params[k] = dict([x.split("=", 1) for x in value.split(",")])
else:
is_invalid = True
elif wanted == 'bool':
if not isinstance(value, bool):
if isinstance(value, basestring):
self.params[k] = self.boolean(value)
else:
is_invalid = True
else:
self.fail_json(msg="implementation error: unknown type %s requested for %s" % (wanted, k))
if is_invalid:
self.fail_json(msg="argument %s is of invalid type: %s, required: %s" % (k, type(value), wanted))
def _set_defaults(self, pre=True):
for (k,v) in self.argument_spec.iteritems():
default = v.get('default', None)

@ -181,12 +181,12 @@ def main():
module = AnsibleModule(
argument_spec = dict(
state = dict(default='installed', choices=['installed', 'latest', 'removed', 'absent', 'present']),
update_cache = dict(default='no', choices=['yes', 'no'], aliases=['update-cache']),
purge = dict(default='no', choices=['yes', 'no']),
update_cache = dict(default='no', aliases=['update-cache'], type='bool'),
purge = dict(default='no', type='bool'),
package = dict(default=None, aliases=['pkg', 'name']),
default_release = dict(default=None, aliases=['default-release']),
install_recommends = dict(default='yes', aliases=['install-recommends'], choices=['yes', 'no']),
force = dict(default='no', choices=['yes', 'no'])
install_recommends = dict(default='yes', aliases=['install-recommends'], type='bool'),
force = dict(default='no', type='bool')
),
supports_check_mode = True
)
@ -204,7 +204,7 @@ def main():
if p['package'] is None and p['update_cache'] != 'yes':
module.fail_json(msg='pkg=name and/or update_cache=yes is required')
install_recommends = module.boolean(p['install_recommends'])
install_recommends = p['install_recommends']
cache = apt.Cache()
if p['default_release']:
@ -212,13 +212,13 @@ def main():
# reopen cache w/ modified config
cache.open(progress=None)
if module.boolean(p['update_cache']):
if p['update_cache']:
cache.update()
cache.open(progress=None)
if not p['package']:
module.exit_json(changed=False)
force_yes = module.boolean(p['force'])
force_yes = p['force']
packages = p['package'].split(',')
latest = p['state'] == 'latest'
@ -237,7 +237,7 @@ def main():
install(module, packages, cache, default_release=p['default_release'],
install_recommends=install_recommends,force=force_yes)
elif p['state'] in [ 'removed', 'absent' ]:
remove(module, packages, cache, purge = module.boolean(p['purge']))
remove(module, packages, cache, p['purge'])
# this is magic, see lib/ansible/module_common.py
#<<INCLUDE_ANSIBLE_MODULE_COMMON>>

@ -94,7 +94,7 @@ def main():
argument_spec = dict(
src = dict(required=True),
dest = dict(required=True),
backup=dict(default=False, choices=BOOLEANS),
backup=dict(default=False, type='bool'),
),
add_file_common_args=True
)
@ -104,7 +104,7 @@ def main():
destmd5 = None
src = os.path.expanduser(module.params['src'])
dest = os.path.expanduser(module.params['dest'])
backup = module.boolean(module.params.get('backup', False))
backup = module.params['backup']
if not os.path.exists(src):
module.fail_json(msg="Source (%s) does not exist" % src)

@ -78,16 +78,16 @@ def main():
argument_spec = dict(
src=dict(required=True),
dest=dict(required=True),
backup=dict(default=False, choices=BOOLEANS),
force = dict(default='yes', choices=BOOLEANS, aliases=['thirsty']),
backup=dict(default=False, type='bool'),
force = dict(default='yes', aliases=['thirsty'], type='bool'),
),
add_file_common_args=True,
)
src = os.path.expanduser(module.params['src'])
dest = os.path.expanduser(module.params['dest'])
backup = module.boolean(module.params.get('backup', False))
force = module.boolean(module.params['force'])
backup = module.params['backup']
force = module.params['force']
if not os.path.exists(src):
module.fail_json(msg="Source %s failed to transfer" % (src))

@ -259,17 +259,17 @@ def main():
job=dict(required=False),
cron_file=dict(required=False),
state=dict(default='present', choices=['present', 'absent']),
backup=dict(default=False, choices=BOOLEANS),
backup=dict(default=False, type='bool'),
minute=dict(default='*'),
hour=dict(default='*'),
day=dict(default='*'),
month=dict(default='*'),
weekday=dict(default='*'),
reboot=dict(required=False, default=False, choices=BOOLEANS)
reboot=dict(required=False, default=False, type='bool')
)
)
backup = module.boolean(module.params.get('backup', False))
backup = module.params['backup']
name = module.params['name']
user = module.params['user']
job = module.params['job']
@ -279,7 +279,7 @@ def main():
day = module.params['day']
month = module.params['month']
weekday = module.params['weekday']
reboot = module.boolean(module.params.get('reboot', False))
reboot = module.params['reboot']
state = module.params['state']
do_install = module.params['state'] == 'present'
changed = False

@ -82,7 +82,7 @@ def main():
arg_spec = dict(
name=dict(required=True),
virtualenv=dict(default=None, required=False),
virtualenv_site_packages=dict(default='no', choices=BOOLEANS),
virtualenv_site_packages=dict(default='no', type='bool'),
virtualenv_command=dict(default='virtualenv', required=False),
)
@ -91,7 +91,7 @@ def main():
name = module.params['name']
env = module.params['virtualenv']
easy_install = module.get_bin_path('easy_install', True, ['%s/bin' % env])
site_packages = module.boolean(module.params['virtualenv_site_packages'])
site_packages = module.params['virtualenv_site_packages']
virtualenv_command = module.params['virtualenv_command']
rc = 0

@ -140,7 +140,7 @@ def main():
argument_spec = dict(
state = dict(choices=['file','directory','link','absent'], default='file'),
path = dict(aliases=['dest', 'name'], required=True),
recurse = dict(default='no', choices=BOOLEANS)
recurse = dict(default='no', type='bool')
),
add_file_common_args=True,
supports_check_mode=True
@ -215,7 +215,7 @@ def main():
changed = True
changed = module.set_directory_attributes_if_different(file_args, changed)
recurse = module.boolean(params['recurse'])
recurse = params['recurse']
if recurse:
for root,dirs,files in os.walk( file_args['path'] ):
for dir in dirs:

@ -195,14 +195,14 @@ def main():
argument_spec = dict(
url = dict(required=True),
dest = dict(required=True),
force = dict(default='no', choices=BOOLEANS, aliases=['thirsty'])
force = dict(default='no', aliases=['thirsty'], type='bool')
),
add_file_common_args=True
)
url = module.params['url']
dest = os.path.expanduser(module.params['dest'])
force = module.boolean(module.params['force'])
force = module.params['force']
if os.path.isdir(dest):
dest = os.path.join(dest, url_filename(url))

@ -223,7 +223,7 @@ def main():
repo=dict(required=True, aliases=['name']),
version=dict(default='HEAD'),
remote=dict(default='origin'),
force=dict(default='yes', choices=['yes', 'no'], aliases=['force'])
force=dict(default='yes', type='bool')
)
)
@ -231,7 +231,7 @@ def main():
repo = module.params['repo']
version = module.params['version']
remote = module.params['remote']
force = module.boolean(module.params['force'])
force = module.params['force']
gitconfig = os.path.join(dest, '.git', 'config')

@ -204,15 +204,15 @@ def main():
repo = dict(required=True),
dest = dict(required=True),
revision = dict(default="default"),
force = dict(default='yes', choices=['yes', 'no']),
purge = dict(default='no', choices=['yes', 'no'])
force = dict(default='yes', type='bool'),
purge = dict(default='no', type='bool')
),
)
repo = module.params['repo']
dest = module.params['dest']
revision = module.params['revision']
force = module.boolean(module.params['force'])
purge = module.boolean(module.params['purge'])
force = module.params['force']
purge = module.params['purge']
hgrc = os.path.join(dest, '.hg/hgrc')
# initial states

@ -158,7 +158,7 @@ def main():
section = dict(required=True),
option = dict(required=False),
value = dict(required=False),
backup = dict(default='no', choices=BOOLEANS),
backup = dict(default='no', type='bool'),
state = dict(default='present', choices=['present', 'absent'])
),
add_file_common_args = True
@ -171,7 +171,7 @@ def main():
option = module.params['option']
value = module.params['value']
state = module.params['state']
backup = module.boolean(module.params['backup'])
backup = module.params['backup']
changed = do_ini(module, dest, section, option, value, state, backup)

@ -246,8 +246,8 @@ def main():
line=dict(aliases=['value']),
insertafter=dict(default=None),
insertbefore=dict(default=None),
create=dict(default=False, choices=BOOLEANS),
backup=dict(default=False, choices=BOOLEANS),
create=dict(default=False, type='bool'),
backup=dict(default=False, type='bool'),
),
mutually_exclusive = [['insertbefore', 'insertafter']],
add_file_common_args = True,
@ -255,8 +255,8 @@ def main():
)
params = module.params
create = module.boolean(module.params.get('create', False))
backup = module.boolean(module.params.get('backup', False))
create = module.params['create']
backup = module.params['backup']
dest = os.path.expanduser(params['dest'])
if params['state'] == 'present':

@ -134,7 +134,7 @@ def main():
module = AnsibleModule(
argument_spec = dict(
state = dict(default="installed", choices=["installed","absent"]),
update_cache = dict(default="no", choices=["yes","no"], aliases=["update-cache"]),
update_cache = dict(default="no", aliases=["update-cache"], type='bool'),
name = dict(aliases=["pkg"], required=True)))
@ -144,7 +144,7 @@ def main():
p = module.params
if module.boolean(p["update_cache"]):
if p["update_cache"]:
update_package_db(module)
pkgs = p["name"].split(",")

@ -158,9 +158,9 @@ def main():
version=dict(default=None, required=False),
requirements=dict(default=None, required=False),
virtualenv=dict(default=None, required=False),
virtualenv_site_packages=dict(default='no', choices=BOOLEANS),
virtualenv_site_packages=dict(default='no', type='bool'),
virtualenv_command=dict(default='virtualenv', required=False),
use_mirrors=dict(default='yes', choices=BOOLEANS),
use_mirrors=dict(default='yes', type='bool'),
extra_args=dict(default=None, required=False),
),
required_one_of=[['name', 'requirements']],
@ -171,7 +171,7 @@ def main():
name = module.params['name']
version = module.params['version']
requirements = module.params['requirements']
use_mirrors = module.boolean(module.params['use_mirrors'])
use_mirrors = module.params['use_mirrors']
extra_args = module.params['extra_args']
if state == 'latest' and version is not None:
@ -188,7 +188,7 @@ def main():
if env:
virtualenv = module.get_bin_path(virtualenv_command, True)
if not os.path.exists(os.path.join(env, 'bin', 'activate')):
if module.boolean(module.params['virtualenv_site_packages']):
if module.params['virtualenv_site_packages']:
cmd = '%s --system-site-packages %s' % (virtualenv, env)
else:
cmd = '%s %s' % (virtualenv, env)

@ -75,7 +75,7 @@ class RabbitMqPlugins(object):
def main():
arg_spec = dict(
names=dict(required=True, aliases=['name']),
new_only=dict(default='no', choices=BOOLEANS),
new_only=dict(default='no', type='bool'),
state=dict(default='enabled', choices=['enabled', 'disabled'])
)
module = AnsibleModule(
@ -84,7 +84,7 @@ def main():
)
names = module.params['names'].split(',')
new_only = module.boolean(module.params['new_only'])
new_only = module.params['new_only']
state = module.params['state']
rabbitmq_plugins = RabbitMqPlugins(module)

@ -178,7 +178,7 @@ def main():
configure_priv=dict(default='^$'),
write_priv=dict(default='^$'),
read_priv=dict(default='^$'),
force=dict(default='no', choices=BOOLEANS),
force=dict(default='no', type='bool'),
state=dict(default='present', choices=['present', 'absent'])
)
module = AnsibleModule(
@ -193,7 +193,7 @@ def main():
configure_priv = module.params['configure_priv']
write_priv = module.params['write_priv']
read_priv = module.params['read_priv']
force = module.boolean(module.params['force'])
force = module.params['force']
state = module.params['state']
rabbitmq_user = RabbitMqUser(module, username, password, tags, vhost, configure_priv, write_priv, read_priv)

@ -100,7 +100,7 @@ class RabbitMqVhost(object):
def main():
arg_spec = dict(
name=dict(required=True, aliases=['vhost']),
tracing=dict(default='off', choices=BOOLEANS, aliases=['trace']),
tracing=dict(default='off', aliases=['trace'], type='bool'),
state=dict(default='present', choices=['present', 'absent'])
)
@ -110,7 +110,7 @@ def main():
)
name = module.params['name']
tracing = module.boolean(module.params['tracing'])
tracing = module.params['tracing']
state = module.params['state']
rabbitmq_vhost = RabbitMqVhost(module, name, tracing)

@ -158,8 +158,8 @@ def main():
module = AnsibleModule(
argument_spec = dict(
name=dict(required=True),
persistent=dict(default='no', choices=BOOLEANS),
state=dict(required=True, choices=BOOLEANS)
persistent=dict(default='no', type='bool'),
state=dict(required=True, type='bool')
)
)
@ -173,8 +173,8 @@ def main():
module.fail_json(msg="SELinux is disabled on this host.")
name = module.params['name']
persistent = module.boolean(module.params['persistent'])
state = module.boolean(module.params['state'])
persistent = module.params['persistent']
state = module.params['state']
result = {}
result['name'] = name

@ -101,7 +101,7 @@ class Service(object):
self.name = module.params['name']
self.state = module.params['state']
self.pattern = module.params['pattern']
self.enable = module.boolean(module.params.get('enabled', None))
self.enable = module.params['enabled']
self.changed = False
self.running = None
self.action = None
@ -713,7 +713,7 @@ def main():
name = dict(required=True),
state = dict(choices=['running', 'started', 'stopped', 'restarted', 'reloaded']),
pattern = dict(required=False, default=None),
enabled = dict(choices=BOOLEANS),
enabled = dict(choices=BOOLEANS, type='bool'),
arguments = dict(aliases=['args'], default=''),
),
supports_check_mode=True

@ -129,7 +129,7 @@ def main():
dest=dict(required=True),
repo=dict(required=True, aliases=['name', 'repository']),
revision=dict(default='HEAD', aliases=['rev']),
force=dict(default='yes', choices=['yes', 'no']),
force=dict(default='yes', type='bool'),
username=dict(required=False),
password=dict(required=False),
)
@ -138,7 +138,7 @@ def main():
dest = os.path.expanduser(module.params['dest'])
repo = module.params['repo']
revision = module.params['revision']
force = module.boolean(module.params['force'])
force = module.params['force']
username = module.params['username']
password = module.params['password']

@ -208,7 +208,7 @@ def main():
value = dict(aliases=['val'], required=False),
state = dict(default='present', choices=['present', 'absent']),
checks = dict(default='both', choices=['none', 'before', 'after', 'both']),
reload = dict(default=True, choices=BOOLEANS),
reload = dict(default=True, type='bool'),
sysctl_file = dict(default='/etc/sysctl.conf')
)
)
@ -219,7 +219,7 @@ def main():
'name': module.params['name'],
'state': module.params['state'],
'checks': module.params['checks'],
'reload': module.boolean(module.params.get('reload', True)),
'reload': module.params['reload'],
'value': module.params.get('value'),
'sysctl_file': module.params['sysctl_file']
}

@ -267,9 +267,9 @@ def main():
password = dict(required=False, default=None),
body = dict(required=False, default=None),
method = dict(required=False, default='GET', choices=['GET', 'POST', 'PUT', 'HEAD', 'DELETE', 'OPTIONS']),
return_content = dict(required=False, default='no', choices=BOOLEANS),
force_basic_auth = dict(required=False, default='no', choices=BOOLEANS),
follow_redirects = dict(required=False, default='no', choices=BOOLEANS),
return_content = dict(required=False, default='no', type='bool'),
force_basic_auth = dict(required=False, default='no', type='bool'),
follow_redirects = dict(required=False, default='no', type='bool'),
creates = dict(required=False, default=None),
removes = dict(required=False, default=None),
status_code = dict(required=False, default="200"),
@ -289,9 +289,9 @@ def main():
body = module.params['body']
method = module.params['method']
dest = module.params['dest']
return_content = module.boolean(module.params['return_content'])
force_basic_auth = module.boolean(module.params['force_basic_auth'])
follow_redirects = module.boolean(module.params['follow_redirects'])
return_content = module.params['return_content']
force_basic_auth = module.params['force_basic_auth']
follow_redirects = module.params['follow_redirects']
creates = module.params['creates']
removes = module.params['removes']
status_code = module.params['status_code']

@ -202,12 +202,12 @@ class User(object):
self.home = module.params['home']
self.shell = module.params['shell']
self.password = module.params['password']
self.force = module.boolean(module.params['force'])
self.remove = module.boolean(module.params['remove'])
self.createhome = module.boolean(module.params['createhome'])
self.system = module.boolean(module.params['system'])
self.append = module.boolean(module.params['append'])
self.sshkeygen = module.boolean(module.params['generate_ssh_key'])
self.force = module.params['force']
self.remove = module.params['remove']
self.createhome = module.params['createhome']
self.system = module.params['system']
self.append = module.params['append']
self.sshkeygen = module.params['generate_ssh_key']
self.ssh_bits = module.params['ssh_key_bits']
self.ssh_type = module.params['ssh_key_type']
self.ssh_comment = module.params['ssh_key_comment']
@ -1002,15 +1002,15 @@ def main():
shell=dict(default=None),
password=dict(default=None),
# following options are specific to userdel
force=dict(default='no', choices=BOOLEANS),
remove=dict(default='no', choices=BOOLEANS),
force=dict(default='no', type='bool'),
remove=dict(default='no', type='bool'),
# following options are specific to useradd
createhome=dict(default='yes', choices=BOOLEANS),
system=dict(default='no', choices=BOOLEANS),
createhome=dict(default='yes', type='bool'),
system=dict(default='no', type='bool'),
# following options are specific to usermod
append=dict(default='no', choices=BOOLEANS),
append=dict(default='no', type='bool'),
# following are specific to ssh key generation
generate_ssh_key=dict(choices=BOOLEANS),
generate_ssh_key=dict(choices=BOOLEANS, type='bool'),
ssh_key_bits=dict(default=ssh_defaults['bits']),
ssh_key_type=dict(default=ssh_defaults['type']),
ssh_key_file=dict(default=None),

Loading…
Cancel
Save