diff --git a/lib/ansible/module_utils/junos.py b/lib/ansible/module_utils/junos.py index 27243904e2f..72e323b721b 100644 --- a/lib/ansible/module_utils/junos.py +++ b/lib/ansible/module_utils/junos.py @@ -16,20 +16,49 @@ # You should have received a copy of the GNU General Public License # along with Ansible. If not, see . # +from distutils.version import LooseVersion from ansible.module_utils.basic import AnsibleModule, env_fallback from ansible.module_utils.shell import Shell, HAS_PARAMIKO from ansible.module_utils.netcfg import parse +try: + from jnpr.junos import Device + from jnpr.junos.utils.config import Config + from jnpr.junos.version import VERSION + from jnpr.junos.exception import RpcError, ConfigLoadError, CommitError + from jnpr.junos.exception import LockError, UnlockError + if not LooseVersion(VERSION) >= LooseVersion('1.2.2'): + HAS_PYEZ = False + else: + HAS_PYEZ = True +except ImportError: + HAS_PYEZ = False + +try: + import jxmlease + HAS_JXMLEASE = True +except ImportError: + HAS_JXMLEASE = False + +try: + from lxml import etree +except ImportError: + import xml.etree.ElementTree as etree + + NET_COMMON_ARGS = dict( host=dict(required=True), - port=dict(default=22, type='int'), + port=dict(type='int'), username=dict(fallback=(env_fallback, ['ANSIBLE_NET_USERNAME'])), password=dict(no_log=True, fallback=(env_fallback, ['ANSIBLE_NET_PASSWORD'])), ssh_keyfile=dict(fallback=(env_fallback, ['ANSIBLE_NET_SSH_KEYFILE']), type='path'), - provider=dict() + timeout=dict(default=0, type='int'), + transport=dict(default='netconf', choices=['cli', 'netconf']), + provider=dict(type='dict') ) + def to_list(val): if isinstance(val, (list, tuple)): return list(val) @@ -38,6 +67,16 @@ def to_list(val): else: return list() +def xml_to_json(val): + if isinstance(val, basestring): + return jxmlease.parse(val) + else: + return jxmlease.parse_etree(val) + +def xml_to_string(val): + return etree.tostring(val) + + class Cli(object): def __init__(self, module): @@ -55,27 +94,181 @@ class Cli(object): self.shell = Shell() try: - self.shell.open(host, port=port, username=username, password=password, key_filename=key_filename) + self.shell.open(host, port=port, username=username, + password=password, key_filename=key_filename) except Exception, exc: - msg = 'failed to connecto to %s:%s - %s' % (host, port, str(exc)) + msg = 'failed to connect to %s:%s - %s' % (host, port, str(exc)) self.module.fail_json(msg=msg) - def send(self, commands): + if self.shell._matched_prompt.strip().endswith('%'): + self.shell.send('cli') + self.shell.send('set cli screen-length 0') + + def run_commands(self, commands, **kwargs): return self.shell.send(commands) + def configure(self, commands, **kwargs): + commands = to_list(commands) + commands.insert(0, 'configure') + + if kwargs.get('comment'): + commands.append('commit comment "%s"' % kwargs.get('comment')) + else: + commands.append('commit and-quit') + + responses = self.shell.send(commands) + responses.pop(0) + responses.pop() + return responses + + def disconnect(self): + self.shell.close() + + +class Netconf(object): + + def __init__(self, module): + self.module = module + self.device = None + self.config = None + self._locked = False + + def _fail(self, msg): + if self.device: + if self._locked: + self.config.unlock() + self.disconnect() + self.module.fail_json(msg=msg) + + def connect(self, **kwargs): + try: + host = self.module.params['host'] + port = self.module.params['port'] or 830 + + user = self.module.params['username'] + passwd = self.module.params['password'] + + self.device = Device(host, user=user, passwd=passwd, port=port, + gather_facts=False).open() + + self.config = Config(self.device) + + except Exception, exc: + self._fail('unable to connect to %s: %s' % (host, str(exc))) + + def run_commands(self, commands, **kwargs): + response = list() + fmt = kwargs.get('format') or 'xml' + + for cmd in to_list(commands): + try: + resp = self.device.cli(command=cmd, format=fmt) + response.append(resp) + except (ValueError, RpcError), exc: + self._fail('Unable to get cli output: %s' % str(exc)) + except Exception, exc: + self._fail('Uncaught exception - please report: %s' % str(exc)) + + return response + + def unlock_config(self): + try: + self.config.unlock() + self._locked = False + except UnlockError, exc: + self.module.log('unable to unlock config: {0}'.format(str(exc))) + + def lock_config(self): + try: + self.config.lock() + self._locked = True + except LockError, exc: + self.module.log('unable to lock config: {0}'.format(str(exc))) + + def check_config(self): + if not self.config.commit_check(): + self._fail(msg='Commit check failed') + + def commit_config(self, comment=None, confirm=None): + try: + kwargs = dict(comment=comment) + if confirm and confirm > 0: + kwargs['confirm'] = confirm + return self.config.commit(**kwargs) + except CommitError, exc: + msg = 'Unable to commit configuration: {0}'.format(str(exc)) + self._fail(msg=msg) + + def load_config(self, candidate, action='replace', comment=None, + confirm=None, format='text', commit=True): + + merge = action == 'merge' + overwrite = action == 'overwrite' + + self.lock_config() + + try: + self.config.load(candidate, format=format, merge=merge, + overwrite=overwrite) + except ConfigLoadError, exc: + msg = 'Unable to load config: {0}'.format(str(exc)) + self._fail(msg=msg) + + diff = self.config.diff() + self.check_config() + if commit and diff: + self.commit_config(comment=comment, confirm=confirm) + + self.unlock_config() + + return diff + + def rollback_config(self, identifier, commit=True, comment=None): + + self.lock_config() + + try: + result = self.config.rollback(identifier) + except Exception, exc: + msg = 'Unable to rollback config: {0}'.format(str(exc)) + self._fail(msg=msg) + + diff = self.config.diff() + if commit: + self.commit_config(comment=comment) + + self.unlock_config() + return diff + + def disconnect(self): + if self.device: + self.device.close() + + def get_facts(self, refresh=True): + if refresh: + self.device.facts_refresh() + return self.device.facts + + def get_config(self): + ele = self.rpc('get_configuration', format='text') + return str(ele.text).strip() + + def rpc(self, name, format='xml', **kwargs): + meth = getattr(self.device.rpc, name) + reply = meth({'format': format}, **kwargs) + return reply + class NetworkModule(AnsibleModule): def __init__(self, *args, **kwargs): super(NetworkModule, self).__init__(*args, **kwargs) self.connection = None - self._config = None + self._connected = False @property - def config(self): - if not self._config: - self._config = self.get_config() - return self._config + def connected(self): + return self._connected def _load_params(self): super(NetworkModule, self)._load_params() @@ -86,33 +279,44 @@ class NetworkModule(AnsibleModule): self.params[key] = value def connect(self): - self.connection = Cli(self) + cls = globals().get(str(self.params['transport']).capitalize()) + self.connection = cls(self) self.connection.connect() - if self.connection.shell._matched_prompt.strip().endswith('%'): - self.execute('cli') - self.execute('set cli screen-length 0') - def configure(self, commands): - commands = to_list(commands) - commands.insert(0, 'configure') - commands.append('commit and-quit') - responses = self.execute(commands) - responses.pop(0) - responses.pop() - return responses + msg = 'connecting to host: {username}@{host}:{port}'.format(**self.params) + self.log(msg) - def execute(self, commands, **kwargs): - return self.connection.send(commands) + self._connected = True + + def load_config(self, commands, **kwargs): + if not self.connected: + self.connect() + return self.connection.load_config(commands, **kwargs) + + def rollback_config(self, identifier, commit=True): + if not self.connected: + self.connect() + return self.connection.rollback_config(identifier) + + def run_commands(self, commands, **kwargs): + if not self.connected: + self.connect() + return self.connection.run_commands(commands, **kwargs) def disconnect(self): - self.connection.close() + if self.connected: + self.connection.disconnect() + self._connected = False - def parse_config(self, cfg): - return parse(cfg, indent=4) + def get_config(self, **kwargs): + if not self.connected: + self.connect() + return self.connection.get_config(**kwargs) - def get_config(self): - cmd = 'show configuration' - return self.execute(cmd)[0] + def get_facts(self, **kwargs): + if not self.connected: + self.connect() + return self.connection.get_facts(**kwargs) def get_module(**kwargs): """Return instance of NetworkModule @@ -126,8 +330,10 @@ def get_module(**kwargs): module = NetworkModule(**kwargs) # HAS_PARAMIKO is set by module_utils/shell.py - if not HAS_PARAMIKO: + if module.params['transport'] == 'cli' and not HAS_PARAMIKO: module.fail_json(msg='paramiko is required but does not appear to be installed') + elif module.params['transport'] == 'netconf' and not HAS_PYEZ: + module.fail_json(msg='junos-eznc >= 1.2.2 is required but does not appear to be installed') module.connect() return module diff --git a/lib/ansible/module_utils/netcfg.py b/lib/ansible/module_utils/netcfg.py index 15011865df5..ff1d138c23e 100644 --- a/lib/ansible/module_utils/netcfg.py +++ b/lib/ansible/module_utils/netcfg.py @@ -96,9 +96,10 @@ def parse(lines, indent, comment_tokens=None): class NetworkConfig(object): - def __init__(self, indent=None, contents=None): + def __init__(self, indent=None, contents=None, device_os=None): self.indent = indent or 1 self._config = list() + self._device_os = device_os if contents: self.load(contents) @@ -225,6 +226,9 @@ class NetworkConfig(object): updates.extend(config) break + if self._device_os == 'junos': + return updates + diffs = dict() for update in updates: if replace == 'block' and update.parents: @@ -278,8 +282,9 @@ class Conditional(object): 'contains': ['contains'] } - def __init__(self, conditional): + def __init__(self, conditional, encoding='json'): self.raw = conditional + self.encoding = encoding key, op, val = shlex.split(conditional) self.key = key @@ -287,11 +292,8 @@ class Conditional(object): self.value = self._cast_value(val) def __call__(self, data): - try: - value = self.get_value(dict(result=data)) - return self.func(value) - except Exception: - raise ValueError(self.key) + value = self.get_value(dict(result=data)) + return self.func(value) def _cast_value(self, value): if value in BOOLEANS_TRUE: @@ -312,6 +314,33 @@ class Conditional(object): raise AttributeError('unknown operator: %s' % oper) def get_value(self, result): + if self.encoding in ['json', 'text']: + return self.get_json(result) + elif self.encoding == 'xml': + return self.get_xml(result.get('result')) + + def get_xml(self, result): + parts = self.key.split('.') + + value_index = None + match = re.match(r'^\S+(\[)(\d+)\]', parts[-1]) + if match: + start, end = match.regs[1] + parts[-1] = parts[-1][0:start] + value_index = int(match.group(2)) + + path = '/'.join(parts[1:]) + path = '/%s' % path + path += '/text()' + + index = int(re.match(r'result\[(\d+)\]', parts[0]).group(1)) + values = result[index].xpath(path) + + if value_index is not None: + return values[value_index].strip() + return [v.strip() for v in values] + + def get_json(self, result): parts = re.split(r'\.(?=[^\]]*(?:\[|$))', self.key) for part in parts: match = re.findall(r'\[(\S+?)\]', part)