updates vyos modules to use socket connection (#21228)

* updates all vyos modules to use socket connection
* adds vyos local action handler
* adds exec_command() to vyos
* updates vyos_config local action
* update unit test cases
* add base class for testing vyos modules
pull/21343/head
Peter Sprygada 8 years ago committed by Nathaniel Case
parent 85194234ba
commit 8adb108aa9

@ -25,16 +25,36 @@
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
from ansible.module_utils.basic import env_fallback
from ansible.module_utils.network_common import to_list
from ansible.module_utils.connection import exec_command
_DEVICE_CONFIGS = {}
vyos_argument_spec = {
'host': dict(),
'port': dict(type='int'),
'username': dict(fallback=(env_fallback, ['ANSIBLE_NET_USERNAME'])),
'password': dict(fallback=(env_fallback, ['ANSIBLE_NET_PASSWORD']), no_log=True),
'ssh_keyfile': dict(fallback=(env_fallback, ['ANSIBLE_NET_SSH_KEYFILE']), type='path'),
'timeout': dict(type='int', default=10),
'provider': dict(type='dict'),
}
def check_args(module, warnings):
provider = module.params['provider'] or {}
for key in vyos_argument_spec:
if key != 'provider' and module.params[key]:
warnings.append('argument %s has been deprecated and will be '
'removed in a future version' % key)
def get_config(module, target='commands'):
cmd = ' '.join(['show configuration', target])
try:
return _DEVICE_CONFIGS[cmd]
except KeyError:
rc, out, err = module.exec_command(cmd)
rc, out, err = exec_command(module, cmd)
if rc != 0:
module.fail_json(msg='unable to retrieve current config', stderr=err)
cfg = str(out).strip()
@ -43,46 +63,42 @@ def get_config(module, target='commands'):
def run_commands(module, commands, check_rc=True):
responses = list()
for cmd in to_list(commands):
rc, out, err = module.exec_command(cmd)
rc, out, err = exec_command(module, cmd)
if check_rc and rc != 0:
module.fail_json(msg=err, rc=rc)
responses.append(out)
return responses
def load_config(module, commands, commit=False, comment=None, save=False):
rc, out, err = module.exec_command('configure')
def load_config(module, commands, commit=False, comment=None):
rc, out, err = exec_command(module, 'configure')
if rc != 0:
module.fail_json(msg='unable to enter configuration mode', output=err)
for cmd in to_list(commands):
rc, out, err = module.exec_command(cmd, check_rc=False)
rc, out, err = exec_command(module, cmd, check_rc=False)
if rc != 0:
# discard any changes in case of failure
module.exec_command('exit discard')
exec_command(module, 'exit discard')
module.fail_json(msg='configuration failed')
diff = None
if module._diff:
rc, out, err = module.exec_command('compare')
rc, out, err = exec_command(module, 'compare')
if not out.startswith('No changes'):
rc, out, err = module.exec_command('show')
rc, out, err = exec_command(module, 'show')
diff = str(out).strip()
if commit:
cmd = 'commit'
if comment:
cmd += ' comment "%s"' % comment
module.exec_command(cmd)
if save:
module.exec_command(cmd)
exec_command(module, cmd)
if not commit:
module.exec_command('exit discard')
exec_command(module, 'exit discard')
else:
module.exec_command('exit')
exec_command(module, 'exit')
if diff:
return diff

@ -1,161 +0,0 @@
# This code is part of Ansible, but is an independent component.
# This particular file snippet, and this file snippet only, is BSD licensed.
# Modules you write using this snippet, which is embedded dynamically by Ansible
# still belong to the author of the module, and may assign their own license
# to the complete work.
#
# (c) 2017 Red Hat, Inc.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import re
from ansible.module_utils.shell import CliBase
from ansible.module_utils.basic import env_fallback, get_exception
from ansible.module_utils.network_common import to_list
from ansible.module_utils.netcli import Command
from ansible.module_utils.six import iteritems
from ansible.module_utils.network import NetworkError
_DEVICE_CONFIGS = {}
_DEVICE_CONNECTION = None
vyos_cli_argument_spec = {
'host': dict(),
'port': dict(type='int'),
'username': dict(fallback=(env_fallback, ['ANSIBLE_NET_USERNAME'])),
'password': dict(fallback=(env_fallback, ['ANSIBLE_NET_PASSWORD']), no_log=True),
'authorize': dict(default=False, fallback=(env_fallback, ['ANSIBLE_NET_AUTHORIZE']), type='bool'),
'auth_pass': dict(no_log=True, fallback=(env_fallback, ['ANSIBLE_NET_AUTH_PASS'])),
'timeout': dict(type='int', default=10),
'provider': dict(type='dict'),
# deprecated in Ansible 2.3
'transport': dict(),
}
def check_args(module):
provider = module.params['provider'] or {}
for key in ('host', 'username', 'password'):
if not module.params[key] and not provider.get(key):
module.fail_json(msg='missing required argument %s' % key)
class Cli(CliBase):
CLI_PROMPTS_RE = [
re.compile(r"[\r\n]?[\w+\-\.:\/\[\]]+(?:\([^\)]+\)){,3}(?:>|#) ?$"),
re.compile(r"\@[\w\-\.]+:\S+?[>#\$] ?$")
]
CLI_ERRORS_RE = [
re.compile(r"\n\s*Invalid command:"),
re.compile(r"\nCommit failed"),
re.compile(r"\n\s+Set failed"),
]
def __init__(self, module):
self._module = module
super(Cli, self).__init__()
provider = self._module.params.get('provider') or dict()
for key, value in iteritems(provider):
if key in nxos_cli_argument_spec:
if self._module.params.get(key) is None and value is not None:
self._module.params[key] = value
try:
self.connect()
except NetworkError:
exc = get_exception()
self._module.fail_json(msg=str(exc))
def connect(self, params, **kwargs):
super(Cli, self).connect(params, kickstart=False, **kwargs)
self.shell.send('set terminal length 0')
def connection(module):
global _DEVICE_CONNECTION
if not _DEVICE_CONNECTION:
cli = Cli(module)
_DEVICE_CONNECTION = cli
return _DEVICE_CONNECTION
def get_config(module, target='commands'):
cmd = ' '.join(['show configuration', target])
try:
return _DEVICE_CONFIGS[cmd]
except KeyError:
conn = connection(module)
rc, out, err = conn.exec_command(cmd)
if rc != 0:
module.fail_json(msg='unable to retrieve current config', stderr=err)
cfg = str(out).strip()
_DEVICE_CONFIGS[cmd] = cfg
return cfg
def run_commands(module, commands, check_rc=True):
responses = list()
for cmd in to_list(commands):
conn = connection(module)
rc, out, err = conn.exec_command(cmd)
if check_rc and rc != 0:
module.fail_json(msg=err, rc=rc)
responses.append(out)
return responses
def load_config(module, commands, commit=False, comment=None, save=False):
commands.insert(0, 'configure')
for cmd in to_list(commands):
conn = connection(module)
rc, out, err = conn.exec_command(cmd, check_rc=False)
if rc != 0:
# discard any changes in case of failure
conn.exec_command('exit discard')
module.fail_json(msg='configuration failed')
diff = None
if module._diff:
rc, out, err = conn.exec_command('compare')
if not out.startswith('No changes'):
rc, out, err = conn.exec_command('show')
diff = str(out).strip()
if commit:
cmd = 'commit'
if comment:
cmd += ' comment "%s"' % comment
conn.exec_command(cmd)
if save:
conn.exec_command(cmd)
if not commit:
conn.exec_command('exit discard')
else:
conn.exec_command('exit')
if diff:
return diff

@ -130,32 +130,15 @@ warnings:
returned: always
type: list
sample: ['...', '...']
start:
description: The time the job started
returned: always
type: str
sample: "2016-11-16 10:38:15.126146"
end:
description: The time the job ended
returned: always
type: str
sample: "2016-11-16 10:38:25.595612"
delta:
description: The time elapsed to perform all operations
returned: always
type: str
sample: "0:00:10.469466"
"""
import time
from ansible.module_utils.local import LocalAnsibleModule
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.netcli import Conditional
from ansible.module_utils.network_common import ComplexList
from ansible.module_utils.six import string_types
from ansible.module_utils.vyos import run_commands
VALID_KEYS = ['command', 'output', 'prompt', 'response']
from ansible.module_utils.vyos import vyos_argument_spec, check_args
def to_lines(stdout):
for item in stdout:
@ -170,17 +153,13 @@ def parse_commands(module, warnings):
prompt=dict(),
response=dict(),
))
commands = command(module.params['commands'])
for index, cmd in enumerate(commands):
if module.check_mode and not cmd['command'].startswith('show'):
warnings.append('only show commands are supported when using '
'check mode, not executing `%s`' % cmd['command'])
else:
if cmd['command'].startswith('conf'):
module.fail_json(msg='vyos_command does not support running '
'config mode commands. Please use '
'vyos_config instead')
commands[index] = module.jsonify(cmd)
return commands
@ -188,7 +167,6 @@ def parse_commands(module, warnings):
def main():
spec = dict(
# { command: <str>, output: <str>, prompt: <str>, response: <str> }
commands=dict(type='list', required=True),
wait_for=dict(type='list', aliases=['waitfor']),
@ -198,10 +176,13 @@ def main():
interval=dict(default=1, type='int')
)
module = LocalAnsibleModule(argument_spec=spec, supports_check_mode=True)
spec.update(vyos_argument_spec)
module = AnsibleModule(argument_spec=spec, supports_check_mode=True)
warnings = list()
check_args(module, warnings)
commands = parse_commands(module, warnings)
wait_for = module.params['wait_for'] or list()

@ -121,27 +121,13 @@ filtered:
returned: always
type: list
sample: ['...', '...']
start:
description: The time the job started
returned: always
type: str
sample: "2016-11-16 10:38:15.126146"
end:
description: The time the job ended
returned: always
type: str
sample: "2016-11-16 10:38:25.595612"
delta:
description: The time elapsed to perform all operations
returned: always
type: str
sample: "0:00:10.469466"
"""
import re
from ansible.module_utils.local import LocalAnsibleModule
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.netcfg import NetworkConfig
from ansible.module_utils.vyos import load_config, get_config, run_commands
from ansible.module_utils.vyos import vyos_argument_spec, check_args
DEFAULT_COMMENT = 'configured by vyos_config'
@ -262,15 +248,20 @@ def main():
save=dict(type='bool', default=False),
)
argument_spec.update(vyos_argument_spec)
mutually_exclusive = [('lines', 'src')]
module = LocalAnsibleModule(
module = AnsibleModule(
argument_spec=argument_spec,
mutually_exclusive=mutually_exclusive,
supports_check_mode=True
)
result = dict(changed=False, warnings=[])
warnings = list()
check_args(module, warnings)
result = dict(changed=False, warnings=warnings)
if module.params['backup']:
result['__backup__'] = get_config(module=module)

@ -96,9 +96,10 @@ ansible_net_gather_subset:
"""
import re
from ansible.module_utils.local import LocalAnsibleModule
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.six import iteritems
from ansible.module_utils.vyos import run_commands
from ansible.module_utils.vyos import vyos_argument_spec, check_args
class FactsBase(object):
@ -251,7 +252,13 @@ def main():
gather_subset=dict(default=['!config'], type='list')
)
module = LocalAnsibleModule(argument_spec=argument_spec, supports_check_mode=True)
argument_spec.update(vyos_argument_spec)
module = AnsibleModule(argument_spec=argument_spec,
supports_check_mode=True)
warnings = list()
check_args(module, warnings)
gather_subset = module.params['gather_subset']
@ -303,7 +310,7 @@ def main():
key = 'ansible_net_%s' % key
ansible_facts[key] = value
module.exit_json(ansible_facts=ansible_facts)
module.exit_json(ansible_facts=ansible_facts, warnings=warnings)
if __name__ == '__main__':

@ -72,21 +72,6 @@ commands:
sample:
- set system hostname vyos01
- set system domain-name foo.example.com
start:
description: The time the job started
returned: always
type: str
sample: "2016-11-16 10:38:15.126146"
end:
description: The time the job ended
returned: always
type: str
sample: "2016-11-16 10:38:25.595612"
delta:
description: The time elapsed to perform all operations
returned: always
type: str
sample: "0:00:10.469466"
"""
EXAMPLES = """
@ -112,8 +97,9 @@ EXAMPLES = """
- sub2.example.com
"""
from ansible.module_utils.local import LocalAnsibleModule
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.vyos import get_config, load_config
from ansible.module_utils.vyos import vyos_argument_spec, check_args
def spec_key_to_device_key(key):
@ -181,6 +167,15 @@ def spec_to_commands(want, have):
return commands
def map_param_to_obj(module):
return {
'host_name': module.params['host_name'],
'domain_name': module.params['domain_name'],
'domain_search': module.params['domain_search'],
'name_server': module.params['name_server'],
'state': module.params['state']
}
def main():
argument_spec = dict(
@ -191,14 +186,20 @@ def main():
state=dict(type='str', default='present', choices=['present', 'absent']),
)
module = LocalAnsibleModule(
argument_spec.update(vyos_argument_spec)
module = AnsibleModule(
argument_spec=argument_spec,
supports_check_mode=True,
mutually_exclusive=[('domain_name', 'domain_search')],
)
result = {'changed': False}
want = dict(module.params)
warnings = list()
check_args(module, warnings)
result = {'changed': False, 'warnings': warnings}
want = map_param_to_obj(module)
have = config_to_dict(module)
commands = spec_to_commands(want, have)

@ -0,0 +1,87 @@
#
# (c) 2016 Red Hat Inc.
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
#
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import os
import sys
import copy
from ansible.plugins.action.normal import ActionModule as _ActionModule
from ansible.utils.path import unfrackpath
from ansible.plugins import connection_loader
from ansible.compat.six import iteritems
from ansible.module_utils.vyos import vyos_argument_spec
from ansible.module_utils.basic import AnsibleFallbackNotFound
from ansible.module_utils._text import to_bytes
class ActionModule(_ActionModule):
def run(self, tmp=None, task_vars=None):
pc = copy.deepcopy(self._play_context)
pc.connection = 'network_cli'
pc.port = provider['port'] or self._play_context.port
pc.remote_user = provider['username'] or self._play_context.connection_user
pc.password = provider['password'] or self._play_context.password
socket_path = self._get_socket_path(pc)
if not os.path.exists(socket_path):
# start the connection if it isn't started
connection = self._shared_loader_obj.connection_loader.get('persistent', pc, sys.stdin)
connection.exec_command('EXEC: show version')
task_vars['ansible_socket'] = socket_path
return super(ActionModule, self).run(tmp, task_vars)
def _get_socket_path(self, play_context):
ssh = connection_loader.get('ssh', class_only=True)
cp = ssh._create_control_path(play_context.remote_addr, play_context.port, play_context.remote_user)
path = unfrackpath("$HOME/.ansible/pc")
return cp % dict(directory=path)
def load_provider(self):
provider = self._task.args.get('provider', {})
for key, value in iteritems(eos_argument_spec):
if key in self._task.args:
provider[key] = self._task.args[key]
elif 'fallback' in value:
provider[key] = self._fallback(value['fallback'])
elif key not in provider:
provider[key] = None
self._task.args['provider'] = provider
def _fallback(self, fallback):
strategy = fallback[0]
args = []
kwargs = {}
for item in fallback[1:]:
if isinstance(item, dict):
kwargs = item
else:
args = item
try:
return strategy(*args, **kwargs)
except AnsibleFallbackNotFound:
pass

@ -19,10 +19,114 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from ansible.plugins.action import ActionBase
from ansible.plugins.action.net_config import ActionModule as NetActionModule
import os
import re
import time
import glob
class ActionModule(NetActionModule, ActionBase):
pass
from ansible.plugins.action.vyos import ActionModule as _ActionModule
from ansible.module_utils._text import to_text
from ansible.module_utils.six.moves.urllib.parse import urlsplit
from ansible.utils.vars import merge_hash
PRIVATE_KEYS_RE = re.compile('__.+__')
class ActionModule(_ActionModule):
def run(self, tmp=None, task_vars=None):
if self._task.args.get('src'):
try:
self._handle_template()
except ValueError as exc:
return dict(failed=True, msg=exc.message)
if self._play_context.connection == 'local':
result = self.normal(tmp, task_vars)
else:
result = super(ActionModule, self).run(tmp, task_vars)
if self._task.args.get('backup') and result.get('__backup__'):
# User requested backup and no error occurred in module.
# NOTE: If there is a parameter error, _backup key may not be in results.
filepath = self._write_backup(task_vars['inventory_hostname'],
result['__backup__'])
result['backup_path'] = filepath
# strip out any keys that have two leading and two trailing
# underscore characters
for key in result.keys():
if PRIVATE_KEYS_RE.match(key):
del result[key]
return result
def normal(self, tmp=None, task_vars=None):
if task_vars is None:
task_vars = dict()
#results = super(ActionModule, self).run(tmp, task_vars)
# remove as modules might hide due to nolog
#del results['invocation']['module_args']
results = {}
results = merge_hash(results, self._execute_module(tmp=tmp, task_vars=task_vars))
# hack to keep --verbose from showing all the setup module results
if self._task.action == 'setup':
results['_ansible_verbose_override'] = True
return results
def _get_working_path(self):
cwd = self._loader.get_basedir()
if self._task._role is not None:
cwd = self._task._role._role_path
return cwd
def _write_backup(self, host, contents):
backup_path = self._get_working_path() + '/backup'
if not os.path.exists(backup_path):
os.mkdir(backup_path)
for fn in glob.glob('%s/%s*' % (backup_path, host)):
os.remove(fn)
tstamp = time.strftime("%Y-%m-%d@%H:%M:%S", time.localtime(time.time()))
filename = '%s/%s_config.%s' % (backup_path, host, tstamp)
open(filename, 'w').write(contents)
return filename
def _handle_template(self):
src = self._task.args.get('src')
working_path = self._get_working_path()
if os.path.isabs(src) or urlsplit('src').scheme:
source = src
else:
source = self._loader.path_dwim_relative(working_path, 'templates', src)
if not source:
source = self._loader.path_dwim_relative(working_path, src)
if not os.path.exists(source):
raise ValueError('path specified in src not found')
try:
with open(source, 'r') as f:
template_data = to_text(f.read())
except IOError:
return dict(failed=True, msg='unable to load src file')
# Create a template search path in the following order:
# [working_path, self_role_path, dependent_role_paths, dirname(source)]
searchpath = [working_path]
if self._task._role is not None:
searchpath.append(self._task._role._role_path)
if hasattr(self._task, "_block:"):
dep_chain = self._task._block.get_dep_chain()
if dep_chain is not None:
for role in dep_chain:
searchpath.append(role._role_path)
searchpath.append(os.path.dirname(source))
self._templar.environment.loader.searchpath = searchpath
self._task.args['src'] = self._templar.template(template_data)

@ -1,75 +0,0 @@
#
# (c) 2015, Peter Sprygada <psprygada@ansible.com>
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
class ModuleDocFragment(object):
# Standard files documentation fragment
DOCUMENTATION = """
options:
host:
description:
- Specifies the DNS host name or address for connecting to the remote
device over the specified transport. The value of host is used as
the destination address for the transport.
required: true
port:
description:
- Specifies the port to use when building the connection to the remote
device.
required: false
default: 22
username:
description:
- Configures the username to use to authenticate the connection to
the remote device. This value is used to authenticate
the SSH session. If the value is not specified in the task, the
value of environment variable C(ANSIBLE_NET_USERNAME) will be used instead.
required: false
password:
description:
- Specifies the password to use to authenticate the connection to
the remote device. This value is used to authenticate
the SSH session. If the value is not specified in the task, the
value of environment variable C(ANSIBLE_NET_PASSWORD) will be used instead.
required: false
default: null
timeout:
description:
- Specifies the timeout in seconds for communicating with the network device
for either connecting or sending commands. If the timeout is
exceeded before the operation is completed, the module will error.
require: false
default: 10
ssh_keyfile:
description:
- Specifies the SSH key to use to authenticate the connection to
the remote device. This value is the path to the
key used to authenticate the SSH session. If the value is not specified
in the task, the value of environment variable C(ANSIBLE_NET_SSH_KEYFILE)
will be used instead.
required: false
provider:
description:
- Convenience method that allows all I(vyos) arguments to be passed as
a dict object. All constraints (required, choices, etc) must be
met either by individual arguments or values in this dict.
required: false
default: null
"""

@ -19,45 +19,15 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import os
import json
from ansible.compat.tests import unittest
from ansible.compat.tests.mock import patch, MagicMock
from ansible.errors import AnsibleModuleExit
from ansible.compat.tests.mock import patch
from ansible.modules.network.vyos import vyos_command
from ansible.module_utils import basic
from ansible.module_utils._text import to_bytes
from .vyos_module import TestVyosModule, load_fixture, set_module_args
class TestVyosCommandModule(TestVyosModule):
fixture_path = os.path.join(os.path.dirname(__file__), 'fixtures')
fixture_data = {}
def set_module_args(args):
args = json.dumps({'ANSIBLE_MODULE_ARGS': args})
basic._ANSIBLE_ARGS = to_bytes(args)
def load_fixture(name):
path = os.path.join(fixture_path, name)
if path in fixture_data:
return fixture_data[path]
with open(path) as f:
data = f.read()
try:
data = json.loads(data)
except:
pass
fixture_data[path] = data
return data
class TestVyosCommandModule(unittest.TestCase):
module = vyos_command
def setUp(self):
self.mock_run_commands = patch('ansible.modules.network.vyos.vyos_command.run_commands')
@ -66,8 +36,7 @@ class TestVyosCommandModule(unittest.TestCase):
def tearDown(self):
self.mock_run_commands.stop()
def execute_module(self, failed=False, changed=False):
def load_fixtures(self, commands=None):
def load_from_file(*args, **kwargs):
module, commands = args
output = list()
@ -84,18 +53,6 @@ class TestVyosCommandModule(unittest.TestCase):
self.run_commands.side_effect = load_from_file
with self.assertRaises(AnsibleModuleExit) as exc:
vyos_command.main()
result = exc.exception.result
if failed:
self.assertTrue(result.get('failed'))
else:
self.assertEqual(result.get('changed'), changed, result)
return result
def test_vyos_command_simple(self):
set_module_args(dict(commands=['show version']))
result = self.execute_module()

@ -20,45 +20,16 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import os
import json
from ansible.compat.tests import unittest
from ansible.compat.tests.mock import patch, MagicMock
from ansible.errors import AnsibleModuleExit
from ansible.compat.tests.mock import patch
from ansible.modules.network.vyos import vyos_config
from ansible.module_utils import basic
from ansible.module_utils._text import to_bytes
from .vyos_module import TestVyosModule, load_fixture, set_module_args
fixture_path = os.path.join(os.path.dirname(__file__), 'fixtures')
fixture_data = {}
class TestVyosConfigModule(TestVyosModule):
def set_module_args(args):
args = json.dumps({'ANSIBLE_MODULE_ARGS': args})
basic._ANSIBLE_ARGS = to_bytes(args)
def load_fixture(name):
path = os.path.join(fixture_path, name)
if path in fixture_data:
return fixture_data[path]
with open(path) as f:
data = f.read()
try:
data = json.loads(data)
except:
pass
fixture_data[path] = data
return data
class TestVyosConfigModule(unittest.TestCase):
module = vyos_config
def setUp(self):
self.mock_get_config = patch('ansible.modules.network.vyos.vyos_config.get_config')
@ -75,30 +46,11 @@ class TestVyosConfigModule(unittest.TestCase):
self.mock_load_config.stop()
self.mock_run_commands.stop()
def execute_module(self, failed=False, changed=False, commands=None, sort=True, defaults=False):
config_file = 'vyos_config_defaults.cfg' if defaults else 'vyos_config_config.cfg'
def load_fixtures(self, commands=None):
config_file = 'vyos_config_config.cfg'
self.get_config.return_value = load_fixture(config_file)
self.load_config.return_value = None
with self.assertRaises(AnsibleModuleExit) as exc:
vyos_config.main()
result = exc.exception.result
if failed:
self.assertTrue(result['failed'], result)
else:
self.assertEqual(result.get('changed'), changed, result)
if commands:
if sort:
self.assertEqual(sorted(commands), sorted(result['commands']), result['commands'])
else:
self.assertEqual(commands, result['commands'], result['commands'])
return result
def test_vyos_config_unchanged(self):
src = load_fixture('vyos_config_config.cfg')
set_module_args(dict(src=src))

@ -19,45 +19,16 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import os
import json
from ansible.compat.tests import unittest
from ansible.compat.tests.mock import patch, MagicMock
from ansible.errors import AnsibleModuleExit
from ansible.compat.tests.mock import patch
from ansible.modules.network.vyos import vyos_facts
from ansible.module_utils import basic
from ansible.module_utils._text import to_bytes
from .vyos_module import TestVyosModule, load_fixture, set_module_args
fixture_path = os.path.join(os.path.dirname(__file__), 'fixtures')
fixture_data = {}
class TestVyosFactsModule(TestVyosModule):
def set_module_args(args):
args = json.dumps({'ANSIBLE_MODULE_ARGS': args})
basic._ANSIBLE_ARGS = to_bytes(args)
def load_fixture(name):
path = os.path.join(fixture_path, name)
if path in fixture_data:
return fixture_data[path]
with open(path) as f:
data = f.read()
try:
data = json.loads(data)
except:
pass
fixture_data[path] = data
return data
class TestVyosFactsModule(unittest.TestCase):
module = vyos_facts
def setUp(self):
self.mock_run_commands = patch('ansible.modules.network.vyos.vyos_facts.run_commands')
@ -66,8 +37,7 @@ class TestVyosFactsModule(unittest.TestCase):
def tearDown(self):
self.mock_run_commands.stop()
def execute_module(self, failed=False, changed=False):
def load_fixtures(self, commands=None):
def load_from_file(*args, **kwargs):
module, commands = args
output = list()
@ -84,18 +54,6 @@ class TestVyosFactsModule(unittest.TestCase):
self.run_commands.side_effect = load_from_file
with self.assertRaises(AnsibleModuleExit) as exc:
vyos_facts.main()
result = exc.exception.result
if failed:
self.assertTrue(result.get('failed'))
else:
self.assertEqual(result.get('changed'), changed, result)
return result
def test_vyos_facts_default(self):
set_module_args(dict(gather_subset='default'))
result = self.execute_module()

@ -21,46 +21,16 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import os
import json
import ansible.module_utils.basic
from ansible.compat.tests import unittest
from ansible.compat.tests.mock import patch, MagicMock
from ansible.errors import AnsibleModuleExit
from ansible.compat.tests.mock import patch
from ansible.modules.network.vyos import vyos_system
from ansible.module_utils._text import to_bytes
from ansible.module_utils import basic
from .vyos_module import TestVyosModule, load_fixture, set_module_args
fixture_path = os.path.join(os.path.dirname(__file__), 'fixtures')
fixture_data = {}
class TestVyosSystemModule(TestVyosModule):
def set_module_args(args):
json_args = json.dumps({'ANSIBLE_MODULE_ARGS': args})
basic._ANSIBLE_ARGS = to_bytes(json_args)
def load_fixture(name):
path = os.path.join(fixture_path, name)
if path in fixture_data:
return fixture_data[path]
with open(path) as f:
data = f.read()
try:
data = json.loads(data)
except:
pass
fixture_data[path] = data
return data
class TestVyosSystemModule(unittest.TestCase):
module = vyos_system
def setUp(self):
self.mock_get_config = patch('ansible.modules.network.vyos.vyos_system.get_config')
@ -73,69 +43,44 @@ class TestVyosSystemModule(unittest.TestCase):
self.mock_get_config.stop()
self.mock_load_config.stop()
def execute_module(self, failed=False, changed=False, commands=None, sort=True):
def load_fixtures(self, commands=None):
self.get_config.return_value = load_fixture('vyos_config_config.cfg')
with self.assertRaises(AnsibleModuleExit) as exc:
vyos_system.main()
result = exc.exception.result
if failed:
self.assertTrue(result['failed'], result)
else:
self.assertEqual(result.get('changed'), changed, result)
if commands:
if sort:
self.assertEqual(sorted(commands), sorted(result['commands']), result['commands'])
else:
self.assertEqual(commands, result['commands'], result['commands'])
return result
def test_vyos_system_hostname(self):
set_module_args(dict(host_name='foo'))
result = self.execute_module(changed=True)
self.assertIn("set system host-name 'foo'", result['commands'])
self.assertEqual(1, len(result['commands']))
commands = ["set system host-name 'foo'"]
self.execute_module(changed=True, commands=commands)
def test_vyos_system_clear_hostname(self):
set_module_args(dict(host_name='foo', state='absent'))
result = self.execute_module(changed=True)
self.assertIn('delete system host-name', result['commands'])
self.assertEqual(1, len(result['commands']))
commands = ["delete system host-name"]
self.execute_module(changed=True, commands=commands)
def test_vyos_remove_single_name_server(self):
set_module_args(dict(name_server=['8.8.4.4'], state='absent'))
result = self.execute_module(changed=True)
self.assertIn("delete system name-server '8.8.4.4'", result['commands'])
self.assertEqual(1, len(result['commands']))
commands = ["delete system name-server '8.8.4.4'"]
self.execute_module(changed=True, commands=commands)
def test_vyos_system_domain_name(self):
set_module_args(dict(domain_name='example2.com'))
result = self.execute_module(changed=True)
self.assertIn("set system domain-name 'example2.com'", result['commands'])
self.assertEqual(1, len(result['commands']))
commands = ["set system domain-name 'example2.com'"]
self.execute_module(changed=True, commands=commands)
def test_vyos_system_clear_domain_name(self):
set_module_args(dict(domain_name='example.com', state='absent'))
result = self.execute_module(changed=True)
self.assertIn('delete system domain-name', result['commands'])
self.assertEqual(1, len(result['commands']))
commands = ['delete system domain-name']
self.execute_module(changed=True, commands=commands)
def test_vyos_system_domain_search(self):
set_module_args(dict(domain_search=['foo.example.com', 'bar.example.com']))
result = self.execute_module(changed=True)
self.assertIn("set system domain-search domain 'foo.example.com'", result['commands'])
self.assertIn("set system domain-search domain 'bar.example.com'", result['commands'])
self.assertEqual(2, len(result['commands']))
commands = ["set system domain-search domain 'foo.example.com'",
"set system domain-search domain 'bar.example.com'"]
self.execute_module(changed=True, commands=commands)
def test_vyos_system_clear_domain_search(self):
set_module_args(dict(domain_search=[]))
result = self.execute_module(changed=True)
self.assertIn('delete system domain-search domain', result['commands'])
self.assertEqual(1, len(result['commands']))
commands = ['delete system domain-search domain']
self.execute_module(changed=True, commands=commands)
def test_vyos_system_no_change(self):
set_module_args(dict(host_name='router', domain_name='example.com', name_server=['8.8.8.8', '8.8.4.4']))
@ -144,5 +89,8 @@ class TestVyosSystemModule(unittest.TestCase):
def test_vyos_system_clear_all(self):
set_module_args(dict(state='absent'))
result = self.execute_module(changed=True)
self.assertEqual(4, len(result['commands']))
commands = ['delete system host-name',
'delete system domain-search domain',
'delete system domain-name',
'delete system name-server']
self.execute_module(changed=True, commands=commands)

@ -0,0 +1,113 @@
# (c) 2016 Red Hat Inc.
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
# Make coding more python3-ish
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
import os
import json
from ansible.compat.tests import unittest
from ansible.compat.tests.mock import patch
from ansible.module_utils import basic
from ansible.module_utils._text import to_bytes
def set_module_args(args):
args = json.dumps({'ANSIBLE_MODULE_ARGS': args})
basic._ANSIBLE_ARGS = to_bytes(args)
fixture_path = os.path.join(os.path.dirname(__file__), 'fixtures')
fixture_data = {}
def load_fixture(name):
path = os.path.join(fixture_path, name)
if path in fixture_data:
return fixture_data[path]
with open(path) as f:
data = f.read()
try:
data = json.loads(data)
except:
pass
fixture_data[path] = data
return data
class AnsibleExitJson(Exception):
pass
class AnsibleFailJson(Exception):
pass
class TestVyosModule(unittest.TestCase):
def execute_module(self, failed=False, changed=False, commands=None,
sort=True, defaults=False):
self.load_fixtures(commands)
if failed:
result = self.failed()
self.assertTrue(result['failed'], result)
else:
result = self.changed(changed)
self.assertEqual(result['changed'], changed, result)
if commands:
if sort:
self.assertEqual(sorted(commands), sorted(result['commands']), result['commands'])
else:
self.assertEqual(commands, result['commands'], result['commands'])
return result
def failed(self):
def fail_json(*args, **kwargs):
kwargs['failed'] = True
raise AnsibleFailJson(kwargs)
with patch.object(basic.AnsibleModule, 'fail_json', fail_json):
with self.assertRaises(AnsibleFailJson) as exc:
self.module.main()
result = exc.exception.args[0]
self.assertTrue(result['failed'], result)
return result
def changed(self, changed=False):
def exit_json(*args, **kwargs):
if 'changed' not in kwargs:
kwargs['changed'] = False
raise AnsibleExitJson(kwargs)
with patch.object(basic.AnsibleModule, 'exit_json', exit_json):
with self.assertRaises(AnsibleExitJson) as exc:
self.module.main()
result = exc.exception.args[0]
self.assertEqual(result['changed'], changed, result)
return result
def load_fixtures(self, commands=None):
pass
Loading…
Cancel
Save