From 5dccff29bf685bd8558eb991a4a08f3bedf7cd92 Mon Sep 17 00:00:00 2001 From: Nathaniel Case Date: Mon, 20 Jun 2016 12:11:48 -0400 Subject: [PATCH] Network Module: EOS (#16158) * add new module network * move EOS to NetworkModule * shell.py Python 3.x compatibility * implements the Command class through the connection for eos This implements a new Command class that specifies the cli command and output format. This removes the need to batch commands through the connection * initial add of netcmd module --- lib/ansible/module_utils/eos.py | 508 +++++++++++++++++----------- lib/ansible/module_utils/netcfg.py | 279 ++++++++------- lib/ansible/module_utils/netcmd.py | 202 +++++++++++ lib/ansible/module_utils/network.py | 282 +++++++++++++++ lib/ansible/module_utils/shell.py | 5 + 5 files changed, 937 insertions(+), 339 deletions(-) create mode 100644 lib/ansible/module_utils/netcmd.py create mode 100644 lib/ansible/module_utils/network.py diff --git a/lib/ansible/module_utils/eos.py b/lib/ansible/module_utils/eos.py index b89ad261796..430e521ed78 100644 --- a/lib/ansible/module_utils/eos.py +++ b/lib/ansible/module_utils/eos.py @@ -17,78 +17,224 @@ # along with Ansible. If not, see . # +import collections import re -from ansible.module_utils.basic import AnsibleModule, env_fallback, get_exception -from ansible.module_utils.shell import Shell, ShellError, Command, HAS_PARAMIKO -from ansible.module_utils.netcfg import parse -from ansible.module_utils.urls import fetch_url +from ansible.module_utils.basic import json +from ansible.module_utils.network import NetCli, NetworkError, get_module, Command +from ansible.module_utils.network import add_argument, register_transport, to_list +from ansible.module_utils.netcfg import NetworkConfig +from ansible.module_utils.urls import fetch_url, url_argument_spec NET_PASSWD_RE = re.compile(r"[\r\n]?password: $", re.I) -NET_COMMON_ARGS = dict( - host=dict(required=True), - port=dict(type='int'), - username=dict(fallback=(env_fallback, ['ANSIBLE_NET_USERNAME'])), - password=dict(no_log=True, fallback=(env_fallback, ['ANSIBLE_NET_PASSWORD'])), - ssh_keyfile=dict(fallback=(env_fallback, ['ANSIBLE_NET_SSH_KEYFILE']), type='path'), - authorize=dict(default=False, fallback=(env_fallback, ['ANSIBLE_NET_AUTHORIZE']), type='bool'), - auth_pass=dict(no_log=True, fallback=(env_fallback, ['ANSIBLE_NET_AUTH_PASS'])), - transport=dict(default='cli', choices=['cli', 'eapi']), - use_ssl=dict(default=True, type='bool'), - provider=dict(type='dict') -) - -CLI_PROMPTS_RE = [ - re.compile(r"[\r\n]?[\w+\-\.:\/\[\]]+(?:\([^\)]+\)){,3}(?:>|#) ?$"), - re.compile(r"\[\w+\@[\w\-\.]+(?: [^\]])\] ?[>#\$] ?$") -] - -CLI_ERRORS_RE = [ - re.compile(r"% ?Error"), - re.compile(r"^% \w+", re.M), - re.compile(r"% ?Bad secret"), - re.compile(r"invalid input", re.I), - re.compile(r"(?:incomplete|ambiguous) command", re.I), - re.compile(r"connection timed out", re.I), - re.compile(r"[^\r\n]+ not found", re.I), - re.compile(r"'[^']' +returned error code: ?\d+"), - re.compile(r"[^\r\n]\/bin\/(?:ba)?sh") -] - - -def to_list(val): - if isinstance(val, (list, tuple)): - return list(val) - elif val is not None: - return [val] +EAPI_FORMATS = ['json', 'text'] + +add_argument('use_ssl', dict(default=True, type='bool')) +add_argument('validate_certs', dict(default=True, type='bool')) + +ModuleStub = collections.namedtuple('ModuleStub', 'params fail_json') + +def argument_spec(): + return dict( + # config options + running_config=dict(aliases=['config']), + config_session=dict(default='ansible_session'), + save_config=dict(default=False, aliases=['save']), + force=dict(type='bool', default=False) + ) +eos_argument_spec = argument_spec() + +def get_config(module): + config = module.params['running_config'] + if not config: + config = module.config.get_config(include_defaults=False) + return NetworkConfig(indent=3, contents=config) + +def load_config(module, candidate): + + if not module.params['force']: + config = get_config(module) + commands = candidate.difference(config) else: - return list() + commands = str(candidate) + commands = [str(c).strip() for c in commands] -class Eapi(object): + session = module.params['config_session'] + save_config = module.params['save_config'] - def __init__(self, module): - self.module = module + result = dict(changed=False) - # sets the module_utils/urls.py req parameters - self.module.params['url_username'] = module.params['username'] - self.module.params['url_password'] = module.params['password'] + if commands: + if module._diff: + diff = module.config.load_config(commands, session_name=session) + if diff: + result['diff'] = dict(prepared=diff) + + if not module.check_mode: + module.config.commit_config(session) + if save_config: + module.config.save_config() + else: + module.config.abort_config(session_name=session) + + if not module.check_mode: + module.config(commands) + if save_config: + module.config.save_config() + + result['changed'] = True + result['updates'] = commands + + return result + +def expand_intf_range(interfaces): + match = re.match(r'([a-zA-Z]+)(.+)', interfaces) + if not match: + raise ValueError('could not parse interface range') + + name = match.group(1) + values = match.group(2).split(',') + + indicies = list() + + for val in values: + tokens = val.split('-') + + # single index value to handle + if len(tokens) == 1: + indicies.append(tokens[0]) + + elif len(tokens) == 2: + pairs = list() + mod = 0 + + for token in tokens: + parts = token.split('/') + + if len(parts) == 1: + port = parts[0] + if port == '$': + port = last_port + pairs.append((mod, int(port))) + + elif len(parts) == 2: + mod = int(parts[0]) + port = parts[1] + if port == '$': + port = last_port + pairs.append((mod, int(port))) + + else: + raise ValueError('unable to parse interface') + + if pairs[0][0] == pairs[1][0]: + # same module + mod = pairs[0][0] + start = pairs[0][1] + end = pairs[1][1] + 1 + + for i in range(start, end): + if mod == 0: + indicies.append(i) + else: + indicies.append('%s/%s' % (mod, i)) + else: + # span modules + start_mod, start_port = pairs[0] + end_mod, end_port = pairs[1] + end_port += 1 + + for i in range(start_port, last_port+1): + indicies.append('%s/%s' % (start_mod, i)) + for i in range(first_port, end_port): + indicies.append('%s/%s' % (end_mod, i)) + + return ['%s%s' % (name, index) for index in indicies] + +class EosConfigMixin(object): + + def configure(self, commands, **kwargs): + commands = prepare_config(commands) + responses = self.execute(commands) + responses.pop(0) + return responses + + def get_config(self, **kwargs): + cmd = 'show running-config' + if kwargs.get('include_defaults') is True: + cmd += ' all' + return self.execute([cmd])[0] + + def load_config(self, commands, session_name='ansible_temp_session', **kwargs): + commands = to_list(commands) + commands.insert(0, 'configure session %s' % session_name) + commands.append('show session-config diffs') + commands.append('end') + responses = self.execute(commands) + return responses[-2] + + def replace_config(self, contents, params, **kwargs): + remote_user = params['username'] + remote_path = '/home/%s/ansible-config' % remote_user + + commands = [ + 'bash echo "%s" > %s' % (contents, remote_path), + 'diff running-config file:/%s' % remote_path, + 'config replace file:/%s' % remote_path, + ] + + responses = self.run_commands(commands) + return responses[-2] + + def commit_config(self, session_name): + session = 'configure session %s' % session_name + commands = [session, 'commit', 'no %s' % session] + self.execute(commands) + + def abort_config(self, session_name): + command = 'no configure session %s' % session_name + self.execute([command]) + + def save_config(self): + self.execute(['copy running-config startup-config']) + +class Eapi(EosConfigMixin): + + def __init__(self): self.url = None + self.url_args = ModuleStub(url_argument_spec(), self._error) self.enable = None + self.default_output = 'json' + self._connected = False - def _get_body(self, commands, encoding, reqid=None): + def _error(self, msg): + raise NetworkError(msg, url=self.url) + + def _get_body(self, commands, format, reqid=None): """Create a valid eAPI JSON-RPC request message """ - params = dict(version=1, cmds=commands, format=encoding) + + if format not in EAPI_FORMATS: + msg = 'invalid format, received %s, expected one of %s' % \ + (format, ','.join(EAPI_FORMATS)) + self._error(msg=msg) + + params = dict(version=1, cmds=commands, format=format) return dict(jsonrpc='2.0', id=reqid, method='runCmds', params=params) - def connect(self): - host = self.module.params['host'] - port = self.module.params['port'] + def connect(self, params, **kwargs): + host = params['host'] + port = params['port'] - if self.module.params['use_ssl']: + # sets the module_utils/urls.py req parameters + self.url_args.params['url_username'] = params['username'] + self.url_args.params['url_password'] = params['password'] + self.url_args.params['validate_certs'] = params['validate_certs'] + + if params['use_ssl']: proto = 'https' if not port: port = 443 @@ -98,176 +244,146 @@ class Eapi(object): port = 80 self.url = '%s://%s:%s/command-api' % (proto, host, port) + self._connected = True + + def disconnect(self, **kwargs): + self.url = None + self._connected = False - def authorize(self): - if self.module.params['auth_pass']: - passwd = self.module.params['auth_pass'] + def authorize(self, params, **kwargs): + if params.get('auth_pass'): + passwd = params['auth_pass'] self.enable = dict(cmd='enable', input=passwd) else: self.enable = 'enable' - def send(self, commands, encoding='json'): - """Send commands to the device. - """ - clist = to_list(commands) - - if self.enable is not None: - clist.insert(0, self.enable) - data = self._get_body(clist, encoding) - data = self.module.jsonify(data) + ### implementation of network.Cli ### - headers = {'Content-Type': 'application/json-rpc'} + def run_commands(self, commands): + output = None + cmds = list() + responses = list() - response, headers = fetch_url(self.module, self.url, data=data, - headers=headers, method='POST') + for cmd in commands: + if output and output != cmd.output: + responses.extend(self.execute(cmds, format=output)) + cmds = list() - if headers['status'] != 200: - self.module.fail_json(**headers) + output = cmd.output + cmds.append(str(cmd)) - response = self.module.from_json(response.read()) - if 'error' in response: - err = response['error'] - self.module.fail_json(msg='json-rpc error', commands=commands, **err) + if cmds: + responses.extend(self.execute(cmds, format=output)) - if self.enable: - response['result'].pop(0) + for index, cmd in enumerate(commands): + if cmd.output == 'text': + responses[index] = responses[index].get('output') - return response['result'] + return responses + def execute(self, commands, format='json', **kwargs): + """Send commands to the device. + """ + if self.url is None: + raise NetworkError('Not connected to endpoint.') + if self.enable is not None: + commands.insert(0, self.enable) -class Cli(object): + data = self._get_body(commands, format) + data = json.dumps(data) - def __init__(self, module): - self.module = module - self.shell = None + headers = {'Content-Type': 'application/json-rpc'} - def connect(self, **kwargs): - host = self.module.params['host'] - port = self.module.params['port'] or 22 + response, headers = fetch_url( + self.url_args, self.url, data=data, headers=headers, + method='POST' + ) - username = self.module.params['username'] - password = self.module.params['password'] - key_filename = self.module.params['ssh_keyfile'] + if headers['status'] != 200: + raise NetworkError(**headers) try: - self.shell = Shell(prompts_re=CLI_PROMPTS_RE, errors_re=CLI_ERRORS_RE) - self.shell.open(host, port=port, username=username, password=password, key_filename=key_filename) - except ShellError: - e = get_exception() - msg = 'failed to connect to %s:%s - %s' % (host, port, str(e)) - self.module.fail_json(msg=msg) - - def authorize(self): - passwd = self.module.params['auth_pass'] - self.send(Command('enable', prompt=NET_PASSWD_RE, response=passwd)) - - def send(self, commands): - try: - return self.shell.send(commands) - except ShellError: - e = get_exception() - self.module.fail_json(msg=e.message, commands=commands) + response = json.loads(response.read()) + except ValueError: + raise NetworkError('unable to load response from device') + if 'error' in response: + err = response['error'] + raise NetworkError( + msg=err['message'], code=err['code'], data=err['data'], + commands=commands + ) -class NetworkModule(AnsibleModule): - - def __init__(self, *args, **kwargs): - super(NetworkModule, self).__init__(*args, **kwargs) - self.connection = None - self._config = None - self._connected = False - - @property - def connected(self): - return self._connected - - @property - def config(self): - if not self._config: - self._config = self.get_config() - return self._config - - def _load_params(self): - super(NetworkModule, self)._load_params() - provider = self.params.get('provider') or dict() - for key, value in provider.items(): - if key in NET_COMMON_ARGS: - if self.params.get(key) is None and value is not None: - self.params[key] = value - - def connect(self): - cls = globals().get(str(self.params['transport']).capitalize()) - try: - self.connection = cls(self) - except TypeError: - e = get_exception() - self.fail_json(msg=e.message) - - self.connection.connect() - self.connection.send('terminal length 0') + if self.enable: + response['result'].pop(0) - if self.params['authorize']: - self.connection.authorize() + return response['result'] + def get_config(self, **kwargs): + return self.run_commands(['show running-config'], format='text')[0] +Eapi = register_transport('eapi')(Eapi) + + +class Cli(NetCli, EosConfigMixin): + CLI_PROMPTS_RE = [ + re.compile(r"[\r\n]?[\w+\-\.:\/\[\]]+(?:\([^\)]+\)){,3}(?:>|#) ?$"), + re.compile(r"\[\w+\@[\w\-\.]+(?: [^\]])\] ?[>#\$] ?$") + ] + + CLI_ERRORS_RE = [ + re.compile(r"% ?Error"), + re.compile(r"^% \w+", re.M), + re.compile(r"% ?Bad secret"), + re.compile(r"invalid input", re.I), + re.compile(r"(?:incomplete|ambiguous) command", re.I), + re.compile(r"connection timed out", re.I), + re.compile(r"[^\r\n]+ not found", re.I), + re.compile(r"'[^']' +returned error code: ?\d+"), + re.compile(r"[^\r\n]\/bin\/(?:ba)?sh") + ] + + def __init__(self): + super(Cli, self).__init__() + + def connect(self, params, **kwargs): + super(Cli, self).connect(params, kickstart=True, **kwargs) + self.shell.send('terminal length 0') self._connected = True - def configure(self, commands, replace=False): - if replace: - responses = self.config_replace(commands) - else: - responses = self.config_terminal(commands) - return responses - - def config_terminal(self, commands): - commands = to_list(commands) - commands.insert(0, 'configure terminal') - responses = self.execute(commands) - responses.pop(0) + def authorize(self, params, **kwargs): + passwd = params['auth_pass'] + self.execute(Command('enable', prompt=NET_PASSWD_RE, response=passwd)) + + ### implementation of network.Cli ### + + def run_commands(self, commands): + cmds = list(prepare_commands(commands)) + responses = self.execute(cmds) + for index, cmd in enumerate(commands): + if cmd.output == 'json': + try: + responses[index] = json.loads(responses[index]) + except ValueError: + raise NetworkError( + msg='unable to load response from device', + response=responses[index] + ) return responses +Cli = register_transport('cli', default=True)(Cli) - def config_replace(self, commands): - if self.params['transport'] == 'cli': - self.fail_json(msg='config replace only supported over eapi') - cmd = 'configure replace terminal:' - commands = '\n'.join(to_list(commands)) - command = dict(cmd=cmd, input=commands) - return self.execute(command) - - def execute(self, commands, **kwargs): - if not self.connected: - self.connect() - return self.connection.send(commands, **kwargs) - - def disconnect(self): - self.connection.close() - self._connected = False +def prepare_config(commands): + commands = to_list(commands) + commands.insert(0, 'configure terminal') + commands.append('end') + return commands - def parse_config(self, cfg): - return parse(cfg, indent=3) - def get_config(self): - cmd = 'show running-config' - if self.params.get('include_defaults'): - cmd += ' all' - if self.params['transport'] == 'cli': - return self.execute(cmd)[0] +def prepare_commands(commands): + jsonify = lambda x: '%s | json' % x + for cmd in to_list(commands): + if cmd.output == 'json': + cmd = jsonify(cmd) else: - resp = self.execute(cmd, encoding='text') - return resp[0]['output'] - - -def get_module(**kwargs): - """Return instance of NetworkModule - """ - argument_spec = NET_COMMON_ARGS.copy() - if kwargs.get('argument_spec'): - argument_spec.update(kwargs['argument_spec']) - kwargs['argument_spec'] = argument_spec - - module = NetworkModule(**kwargs) - - if module.params['transport'] == 'cli' and not HAS_PARAMIKO: - module.fail_json(msg='paramiko is required but does not appear to be installed') - - return module + cmd = str(cmd) + yield cmd diff --git a/lib/ansible/module_utils/netcfg.py b/lib/ansible/module_utils/netcfg.py index 6f6bbee6e14..71cb57ea651 100644 --- a/lib/ansible/module_utils/netcfg.py +++ b/lib/ansible/module_utils/netcfg.py @@ -18,11 +18,14 @@ # import re +import time import collections import itertools import shlex +import itertools from ansible.module_utils.basic import BOOLEANS_TRUE, BOOLEANS_FALSE +from ansible.module_utils.network import to_list DEFAULT_COMMENT_TOKENS = ['#', '!'] @@ -34,6 +37,13 @@ class ConfigLine(object): self.parents = list() self.raw = None + @property + def line(self): + line = ['set'] + line.extend([p.text for p in self.parents]) + line.append(self.text) + return ' '.join(line) + def __str__(self): return self.raw @@ -49,16 +59,20 @@ def ignore_line(text, tokens=None): if text.startswith(item): return True +def get_next(iterable): + item, next_item = itertools.tee(iterable, 2) + next_item = itertools.islice(next_item, 1, None) + return itertools.izip_longest(item, next_item) + def parse(lines, indent, comment_tokens=None): toplevel = re.compile(r'\S') childline = re.compile(r'^\s*(.+)$') - repl = r'([{|}|;])' ancestors = list() config = list() for line in str(lines).split('\n'): - text = str(re.sub(repl, '', line)).strip() + text = str(re.sub(r'([{};])', '', line)).strip() cfg = ConfigLine(text) cfg.raw = line @@ -108,11 +122,23 @@ class NetworkConfig(object): def items(self): return self._config + @property + def lines(self): + lines = list() + for item, next_item in get_next(self.items): + if next_item is None: + lines.append(item.line) + elif not next_item.line.startswith(item.line): + lines.append(item.line) + return lines + def __str__(self): - config = collections.OrderedDict() - for item in self._config: - self.expand(item, config) - return '\n'.join(self.flatten(config)) + text = '' + for item in self.items: + if not item.parents: + expand = self.get_section(item.text) + text += '%s\n' % self.get_section(item.text) + return str(text).strip() def load(self, contents): self._config = parse(contents, indent=self.indent) @@ -167,6 +193,45 @@ class NetworkConfig(object): if c.raw not in current_level: current_level[c.raw] = collections.OrderedDict() + def to_lines(self, section): + lines = list() + for entry in section[1:]: + line = ['set'] + line.extend([p.text for p in entry.parents]) + line.append(entry.text) + lines.append(' '.join(line)) + return lines + + def to_block(self, section): + return '\n'.join([item.raw for item in section]) + + def get_section(self, path): + try: + section = self.get_section_objects(path) + if self._device_os == 'junos': + return self.to_lines(section) + return self.to_block(section) + except ValueError: + return list() + + def get_section_objects(self, path): + if not isinstance(path, list): + path = [path] + obj = self.get_object(path) + if not obj: + raise ValueError('path does not exist in config') + return self.expand_section(obj) + + def expand_section(self, configobj, S=None): + if S is None: + S = list() + S.append(configobj) + for child in configobj.children: + if child in S: + continue + self.expand_section(child, S) + return S + def flatten(self, data, obj=None): if obj is None: obj = list() @@ -237,155 +302,83 @@ class NetworkConfig(object): return self.flatten(diffs) - def _build_children(self, children, parents=None, offset=0): - for item in children: - line = ConfigLine(item) - line.raw = item.rjust(len(item) + offset) - if parents: - line.parents = parents - parents[-1].children.append(line) - yield line - - def add(self, lines, parents=None): - offset = 0 + def replace(self, replace, text=None, regex=None, parents=None, + add_if_missing=False, ignore_whitespace=False): + match = None - config = list() - parent = None parents = parents or list() + if text is None and regex is None: + raise ValueError('missing required arguments') - for item in parents: - line = ConfigLine(item) - line.raw = item.rjust(len(item) + offset) - config.append(line) - if parent: - parent.children.append(line) - if parent.parents: - line.parents.append(*parent.parents) - line.parents.append(parent) - parent = line - offset += self.indent - - self._config.extend(config) - self._config.extend(list(self._build_children(lines, config, offset))) - - - -class Conditional(object): - """Used in command modules to evaluate waitfor conditions - """ - - OPERATORS = { - 'eq': ['eq', '=='], - 'neq': ['neq', 'ne', '!='], - 'gt': ['gt', '>'], - 'ge': ['ge', '>='], - 'lt': ['lt', '<'], - 'le': ['le', '<='], - 'contains': ['contains'] - } - - def __init__(self, conditional, encoding='json'): - self.raw = conditional - self.encoding = encoding - - key, op, val = shlex.split(conditional) - self.key = key - self.func = self.func(op) - self.value = self._cast_value(val) - - def __call__(self, data): - value = self.get_value(dict(result=data)) - return self.func(value) - - def _cast_value(self, value): - if value in BOOLEANS_TRUE: - return True - elif value in BOOLEANS_FALSE: - return False - elif re.match(r'^\d+\.d+$', value): - return float(value) - elif re.match(r'^\d+$', value): - return int(value) - else: - return unicode(value) + if not regex: + regex = ['^%s$' % text] - def func(self, oper): - for func, operators in self.OPERATORS.items(): - if oper in operators: - return getattr(self, func) - raise AttributeError('unknown operator: %s' % oper) + patterns = [re.compile(r, re.I) for r in to_list(regex)] - def get_value(self, result): - if self.encoding in ['json', 'text']: - return self.get_json(result) - elif self.encoding == 'xml': - return self.get_xml(result.get('result')) - - def get_xml(self, result): - parts = self.key.split('.') + for item in self.items: + for regexp in patterns: + string = ignore_whitespace is True and item.text or item.raw + if regexp.search(item.text): + if item.text != replace: + if parents == [p.text for p in item.parents]: + match = item + break - value_index = None - match = re.match(r'^\S+(\[)(\d+)\]', parts[-1]) if match: - start, end = match.regs[1] - parts[-1] = parts[-1][0:start] - value_index = int(match.group(2)) - - path = '/'.join(parts[1:]) - path = '/%s' % path - path += '/text()' - - index = int(re.match(r'result\[(\d+)\]', parts[0]).group(1)) - values = result[index].xpath(path) - - if value_index is not None: - return values[value_index].strip() - return [v.strip() for v in values] - - def get_json(self, result): - parts = re.split(r'\.(?=[^\]]*(?:\[|$))', self.key) - for part in parts: - match = re.findall(r'\[(\S+?)\]', part) - if match: - key = part[:part.find('[')] - result = result[key] - for m in match: - try: - m = int(m) - except ValueError: - m = str(m) - result = result[m] - else: - result = result.get(part) - return result - - def number(self, value): - if '.' in str(value): - return float(value) - else: - return int(value) - - def eq(self, value): - return value == self.value + match.text = replace + indent = len(match.raw) - len(match.raw.lstrip()) + match.raw = replace.rjust(len(replace) + indent) - def neq(self, value): - return value != self.value + elif add_if_missing: + self.add(replace, parents=parents) - def gt(self, value): - return self.number(value) > self.value - def ge(self, value): - return self.number(value) >= self.value - - def lt(self, value): - return self.number(value) < self.value - - def le(self, value): - return self.number(value) <= self.value + def add(self, lines, parents=None): + """Adds one or lines of configuration + """ - def contains(self, value): - return str(self.value) in value + ancestors = list() + offset = 0 + obj = None + ## global config command + if not parents: + for line in to_list(lines): + item = ConfigLine(line) + item.raw = line + if item not in self.items: + self.items.append(item) + else: + for index, p in enumerate(parents): + try: + i = index + 1 + obj = self.get_section_objects(parents[:i])[0] + ancestors.append(obj) + + except ValueError: + # add parent to config + offset = index * self.indent + obj = ConfigLine(p) + obj.raw = p.rjust(len(p) + offset) + if ancestors: + obj.parents = list(ancestors) + ancestors[-1].children.append(obj) + self.items.append(obj) + ancestors.append(obj) + + # add child objects + for line in to_list(lines): + # check if child already exists + for child in ancestors[-1].children: + if child.text == line: + break + else: + offset = len(parents) * self.indent + item = ConfigLine(line) + item.raw = line.rjust(len(line) + offset) + item.parents = ancestors + ancestors[-1].children.append(item) + self.items.append(item) diff --git a/lib/ansible/module_utils/netcmd.py b/lib/ansible/module_utils/netcmd.py new file mode 100644 index 00000000000..11254b78c98 --- /dev/null +++ b/lib/ansible/module_utils/netcmd.py @@ -0,0 +1,202 @@ +# +# (c) 2015 Peter Sprygada, +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . +# + +import re +import time +import collections +import itertools +import shlex + +from ansible.module_utils.basic import BOOLEANS_TRUE, BOOLEANS_FALSE + +class Conditional(object): + """Used in command modules to evaluate waitfor conditions + """ + + OPERATORS = { + 'eq': ['eq', '=='], + 'neq': ['neq', 'ne', '!='], + 'gt': ['gt', '>'], + 'ge': ['ge', '>='], + 'lt': ['lt', '<'], + 'le': ['le', '<='], + 'contains': ['contains'], + 'matches': ['matches'] + } + + def __init__(self, conditional, encoding='json'): + self.raw = conditional + self.encoding = encoding + + key, op, val = shlex.split(conditional) + self.key = key + self.func = self.func(op) + self.value = self._cast_value(val) + + def __call__(self, data): + value = self.get_value(dict(result=data)) + return self.func(value) + + def _cast_value(self, value): + if value in BOOLEANS_TRUE: + return True + elif value in BOOLEANS_FALSE: + return False + elif re.match(r'^\d+\.d+$', value): + return float(value) + elif re.match(r'^\d+$', value): + return int(value) + else: + return unicode(value) + + def func(self, oper): + for func, operators in self.OPERATORS.items(): + if oper in operators: + return getattr(self, func) + raise AttributeError('unknown operator: %s' % oper) + + def get_value(self, result): + if self.encoding in ['json', 'text']: + return self.get_json(result) + elif self.encoding == 'xml': + return self.get_xml(result.get('result')) + + def get_xml(self, result): + parts = self.key.split('.') + + value_index = None + match = re.match(r'^\S+(\[)(\d+)\]', parts[-1]) + if match: + start, end = match.regs[1] + parts[-1] = parts[-1][0:start] + value_index = int(match.group(2)) + + path = '/'.join(parts[1:]) + path = '/%s' % path + path += '/text()' + + index = int(re.match(r'result\[(\d+)\]', parts[0]).group(1)) + values = result[index].xpath(path) + + if value_index is not None: + return values[value_index].strip() + return [v.strip() for v in values] + + def get_json(self, result): + parts = re.split(r'\.(?=[^\]]*(?:\[|$))', self.key) + for part in parts: + match = re.findall(r'\[(\S+?)\]', part) + if match: + key = part[:part.find('[')] + result = result[key] + for m in match: + try: + m = int(m) + except ValueError: + m = str(m) + result = result[m] + else: + result = result.get(part) + return result + + def number(self, value): + if '.' in str(value): + return float(value) + else: + return int(value) + + def eq(self, value): + return value == self.value + + def neq(self, value): + return value != self.value + + def gt(self, value): + return self.number(value) > self.value + + def ge(self, value): + return self.number(value) >= self.value + + def lt(self, value): + return self.number(value) < self.value + + def le(self, value): + return self.number(value) <= self.value + + def contains(self, value): + return str(self.value) in value + + def matches(self, value): + match = re.search(value, self.value, re.M) + return match is not None + + +class FailedConditionsError(Exception): + + def __init__(self, msg, failed_conditions): + super(FailedConditionsError, self).__init__(msg) + self.failed_conditions = failed_conditions + +class CommandRunner(collections.Mapping): + + def __init__(self, module): + self.module = module + + self.items = dict() + self.conditionals = set() + + self.retries = 10 + self.interval = 1 + + def __getitem__(self, key): + return self.items[key] + + def __len__(self): + return len(self.items) + + def __iter__(self): + return iter(self.items) + + def add_command(self, command, output=None): + self.module.cli.add_commands(command, output=output) + + def add_conditional(self, condition): + self.conditionals.add(Conditional(condition)) + + def run_commands(self): + responses = self.module.cli.run_commands() + for cmd, resp in itertools.izip(self.module.cli.commands, responses): + self.items[str(cmd)] = resp + + def run(self): + while self.retries > 0: + self.run_commands() + for item in list(self.conditionals): + if item(self.items.values()): + self.conditionals.remove(item) + + if not self.conditionals: + break + + time.sleep(self.interval) + self.retries -= 1 + else: + failed_conditions = [item.raw for item in self.conditionals] + raise FailedConditionsError('timeout waiting for value', failed_conditions) + diff --git a/lib/ansible/module_utils/network.py b/lib/ansible/module_utils/network.py new file mode 100644 index 00000000000..19c0c2c5320 --- /dev/null +++ b/lib/ansible/module_utils/network.py @@ -0,0 +1,282 @@ +# +# (c) 2015 Peter Sprygada, +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . +# + +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils.basic import env_fallback, get_exception +from ansible.module_utils.shell import Shell, ShellError, HAS_PARAMIKO + +NET_TRANSPORT_ARGS = dict( + host=dict(required=True), + port=dict(type='int'), + username=dict(fallback=(env_fallback, ['ANSIBLE_NET_USERNAME'])), + password=dict(no_log=True, fallback=(env_fallback, ['ANSIBLE_NET_PASSWORD'])), + ssh_keyfile=dict(fallback=(env_fallback, ['ANSIBLE_NET_SSH_KEYFILE']), type='path'), + authorize=dict(default=False, fallback=(env_fallback, ['ANSIBLE_NET_AUTHORIZE']), type='bool'), + auth_pass=dict(no_log=True, fallback=(env_fallback, ['ANSIBLE_NET_AUTH_PASS'])), + provider=dict(type='dict'), + transport=dict(choices=list()), + timeout=dict(default=10, type='int') +) + +NET_CONNECTION_ARGS = dict() + +NET_CONNECTIONS = dict() + + +def to_list(val): + if isinstance(val, (list, tuple)): + return list(val) + elif val is not None: + return [val] + else: + return list() + +def connect(module): + try: + if not module.connected: + module.connection.connect(module.params) + if module.params['authorize']: + module.connection.authorize(module.params) + except NetworkError: + exc = get_exception() + module.fail_json(msg=exc.message) + +def disconnect(module): + try: + if module.connected: + module.connection.disconnect() + except NetworkError: + exc = get_exception() + module.fail_json(msg=exc.message) + + +class Command(object): + + def __init__(self, command, output=None, prompt=None, response=None): + self.command = command + self.output = output + self.prompt = prompt + self.response = response + self.conditions = set() + + def __str__(self): + return self.command + +class Cli(object): + + def __init__(self, connection): + self.connection = connection + self.default_output = connection.default_output or 'text' + self.commands = list() + + def __call__(self, commands, output=None): + commands = self.to_command(commands, output) + return self.connection.run_commands(commands) + + def to_command(self, commands, output=None): + output = output or self.default_output + objects = list() + for cmd in to_list(commands): + if not isinstance(cmd, Command): + cmd = Command(cmd, output) + objects.append(cmd) + return objects + + def add_commands(self, commands, output=None): + commands = self.to_command(commands, output) + self.commands.extend(commands) + + def run_commands(self): + return self.connection.run_commands(self.commands) + +class Config(object): + + def __init__(self, connection): + self.connection = connection + + def invoke(self, method, *args, **kwargs): + try: + return method(*args, **kwargs) + except AttributeError: + exc = get_exception() + raise NetworkError('undefined method "%s"' % method.__name__, exc=str(exc)) + except NetworkError: + if raise_exc: + raise + exc = get_exception() + self.fail_json(msg=exc.message, **exc.kwargs) + except NotImplementedError: + raise NetworkError('method not supported "%s"' % method.__name__) + + def __call__(self, commands): + lines = to_list(commands) + return self.invoke(self.connection.configure, commands) + + def load_config(self, commands, **kwargs): + commands = to_list(commands) + return self.invoke(self.connection.load_config, commands, **kwargs) + + def get_config(self, **kwargs): + return self.invoke(self.connection.get_config, **kwargs) + + def commit_config(self, **kwargs): + return self.invoke(self.connection.commit_config, **kwargs) + + def abort_config(self, **kwargs): + return self.invoke(self.connection.abort_config, **kwargs) + + def save_config(self): + return self.invoke(self.connection.save_config) + + +class NetworkError(Exception): + + def __init__(self, msg, **kwargs): + super(NetworkError, self).__init__(msg) + self.kwargs = kwargs + + +class NetworkModule(AnsibleModule): + + def __init__(self, *args, **kwargs): + super(NetworkModule, self).__init__(*args, **kwargs) + self.connection = None + self._cli = None + self._config = None + + @property + def cli(self): + if not self.connected: + connect(self) + if self._cli: + return self._cli + self._cli = Cli(self.connection) + return self._cli + + @property + def config(self): + if not self.connected: + connect(self) + if self._config: + return self._config + self._config = Config(self.connection) + return self._config + + @property + def connected(self): + return self.connection._connected + + def _load_params(self): + super(NetworkModule, self)._load_params() + provider = self.params.get('provider') or dict() + for key, value in provider.items(): + for args in [NET_TRANSPORT_ARGS, NET_CONNECTION_ARGS]: + if key in args: + if self.params.get(key) is None and value is not None: + self.params[key] = value + + +class NetCli(object): + """Basic paramiko-based ssh transport any NetworkModule can use.""" + + def __init__(self): + if not HAS_PARAMIKO: + raise NetworkError( + msg='paramiko is required but does not appear to be installed. ' + 'It can be installed using `pip install paramiko`' + ) + + self.shell = None + self._connected = False + self.default_output = 'text' + + def connect(self, params, kickstart, **kwargs): + host = params['host'] + port = params.get('port') or 22 + + username = params['username'] + password = params.get('password') + key_file = params.get('ssh_keyfile') + timeout = params['timeout'] + + try: + self.shell = Shell( + kickstart=kickstart, + prompts_re=self.CLI_PROMPTS_RE, + errors_re=self.CLI_ERRORS_RE, + ) + self.shell.open( + host, port=port, username=username, password=password, + key_filename=key_file, timeout=timeout, + ) + except ShellError: + exc = get_exception() + raise NetworkError( + msg='failed to connect to %s:%s' % (host, port), exc=str(exc) + ) + + def disconnect(self, **kwargs): + self._connected = False + self.shell.close() + + def execute(self, commands, **kwargs): + try: + return self.shell.send(commands) + except ShellError: + exc = get_exception() + raise NetworkError(exc.message, commands=commands) + + +def get_module(connect_on_load=True, **kwargs): + argument_spec = NET_TRANSPORT_ARGS.copy() + argument_spec['transport']['choices'] = NET_CONNECTIONS.keys() + argument_spec.update(NET_CONNECTION_ARGS.copy()) + + if kwargs.get('argument_spec'): + argument_spec.update(kwargs['argument_spec']) + kwargs['argument_spec'] = argument_spec + + module = NetworkModule(**kwargs) + + try: + transport = module.params['transport'] or '__default__' + cls = NET_CONNECTIONS[transport] + module.connection = cls() + except KeyError: + module.fail_json(msg='Unknown transport or no default transport specified') + except (TypeError, NetworkError): + exc = get_exception() + module.fail_json(msg=exc.message) + + if connect_on_load: + connect(module) + + return module + +def register_transport(transport, default=False): + def register(cls): + NET_CONNECTIONS[transport] = cls + if default: + NET_CONNECTIONS['__default__'] = cls + return cls + return register + +def add_argument(key, value): + NET_CONNECTION_ARGS[key] = value + diff --git a/lib/ansible/module_utils/shell.py b/lib/ansible/module_utils/shell.py index 641f6927ab9..5e17df25731 100644 --- a/lib/ansible/module_utils/shell.py +++ b/lib/ansible/module_utils/shell.py @@ -19,6 +19,8 @@ import re import socket +from ansible.module_utils.basic import get_exception + # py2 vs py3; replace with six via ziploader try: from StringIO import StringIO @@ -156,6 +158,9 @@ class Shell(object): responses.append(self.receive(command)) except socket.timeout: raise ShellError("timeout trying to send command", cmd) + except socket.error: + exc = get_exception() + raise ShellError("problem sending command to host: %s" % exc.message) return responses def close(self):