From 1a00e2635e67ee82411714c52f24f89fdfe0d353 Mon Sep 17 00:00:00 2001 From: Michael DeHaan Date: Sun, 6 May 2012 17:03:17 -0400 Subject: [PATCH] Further work on making the YAML inventory parser use the new inventory objects. --- lib/ansible/inventory.py | 15 ++-- lib/ansible/inventory_parser_yaml.py | 103 +++++++++++++++++++++++++++ test/TestInventory.py | 41 ++++++----- 3 files changed, 132 insertions(+), 27 deletions(-) create mode 100644 lib/ansible/inventory_parser_yaml.py diff --git a/lib/ansible/inventory.py b/lib/ansible/inventory.py index 38a094e9842..7372c8b9b0d 100644 --- a/lib/ansible/inventory.py +++ b/lib/ansible/inventory.py @@ -23,6 +23,7 @@ import os import constants as C import subprocess from ansible.inventory_parser import InventoryParser +from ansible.inventory_parser_yaml import InventoryParserYaml from ansible.inventory_script import InventoryScript from ansible.group import Group from ansible.host import Host @@ -36,9 +37,6 @@ class Inventory(object): def __init__(self, host_list=C.DEFAULT_HOST_LIST): - # FIXME: re-support YAML inventory format - # FIXME: re-support external inventory script (?) - self.host_list = host_list self.groups = [] self._restriction = None @@ -52,9 +50,14 @@ class Inventory(object): self.parser = InventoryScript(filename=host_list) self.groups = self.parser.groups.values() else: - self.parser = InventoryParser(filename=host_list) - self.groups = self.parser.groups.values() - + data = file(host_list).read() + if not data.startswith("---"): + self.parser = InventoryParser(filename=host_list) + self.groups = self.parser.groups.values() + else: + self.parser = InventoryParserYaml(filename=host_list) + self.groups = self.parser.groups.values() + def _groups_from_override_hosts(self, list): # support for playbook's --override-hosts only all = Group(name='all') diff --git a/lib/ansible/inventory_parser_yaml.py b/lib/ansible/inventory_parser_yaml.py new file mode 100644 index 00000000000..0a8c282b018 --- /dev/null +++ b/lib/ansible/inventory_parser_yaml.py @@ -0,0 +1,103 @@ +# (c) 2012, Michael DeHaan +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +############################################# + +import constants as C +from ansible.host import Host +from ansible.group import Group +from ansible import errors +from ansible import utils + +class InventoryParserYaml(object): + """ + Host inventory for ansible. + """ + + def __init__(self, filename=C.DEFAULT_HOST_LIST): + + fh = open(filename) + data = fh.read() + fh.close() + self._hosts = {} + self._parse(data) + + def _make_host(self, hostname): + if hostname in self._hosts: + return self._hosts[hostname] + else: + host = Host(hostname) + self._hosts[hostname] = host + return host + + # see file 'test/yaml_hosts' for syntax + + def _parse(self, data): + + all = Group('all') + ungrouped = Group('ungrouped') + all.add_child_group(ungrouped) + + self.groups = dict(all=all, ungrouped=ungrouped) + + yaml = utils.parse_yaml(data) + for item in yaml: + + if type(item) in [ str, unicode ]: + host = self._make_host(item) + ungrouped.add_host(host) + + elif type(item) == dict and 'host' in item: + host = self._make_host(item['host']) + for (k,v) in item.get('vars',{}).items(): + host.set_variable(k,v) + + elif type(item) == dict and 'group' in item: + group = Group(item['group']) + + for subresult in item.get('hosts',[]): + + if type(subresult) in [ str, unicode ]: + host = self._make_host(subresult) + group.add_host(host) + elif type(subresult) == dict: + host = self._make_host(subresult['host']) + vars = subresult.get('vars',{}) + if type(vars) == list: + for subitem in vars: + for (k,v) in subitem.items(): + host.set_variable(k,v) + elif type(vars) == dict: + for (k,v) in subresult.get('vars',{}).items(): + host.set_variable(k,v) + else: + raise errors.AnsibleError("unexpected type for variable") + group.add_host(host) + + vars = item.get('vars',{}) + if type(vars) == dict: + for (k,v) in item.get('vars',{}).items(): + group.set_variable(k,v) + elif type(vars) == list: + for subitem in vars: + if type(subitem) != dict: + raise errors.AnsibleError("expected a dictionary") + for (k,v) in subitem.items(): + group.set_variable(k,v) + + self.groups[group.name] = group + all.add_child_group(group) diff --git a/test/TestInventory.py b/test/TestInventory.py index caaaf41d158..58650cecd17 100644 --- a/test/TestInventory.py +++ b/test/TestInventory.py @@ -3,7 +3,7 @@ import unittest from ansible.inventory import Inventory from ansible.runner import Runner -from nose.plugins.skip import SkipTest +# from nose.plugins.skip import SkipTest class TestInventory(unittest.TestCase): @@ -20,6 +20,14 @@ class TestInventory(unittest.TestCase): def tearDown(self): os.chmod(self.inventory_script, 0644) + def compare(self, left, right): + left = sorted(left) + right = sorted(right) + print left + print right + assert left == right + + ### Simple inventory format tests def simple_inventory(self): @@ -167,47 +175,41 @@ class TestInventory(unittest.TestCase): ### Tests for yaml inventory file def test_yaml(self): - raise SkipTest inventory = self.yaml_inventory() hosts = inventory.list_hosts() print hosts expected_hosts=['jupiter', 'saturn', 'zeus', 'hera', 'poseidon', 'thor', 'odin', 'loki'] - assert hosts == expected_hosts + self.compare(hosts, expected_hosts) def test_yaml_all(self): - raise SkipTest inventory = self.yaml_inventory() hosts = inventory.list_hosts('all') expected_hosts=['jupiter', 'saturn', 'zeus', 'hera', 'poseidon', 'thor', 'odin', 'loki'] - assert hosts == expected_hosts + self.compare(hosts, expected_hosts) def test_yaml_norse(self): - raise SkipTest inventory = self.yaml_inventory() hosts = inventory.list_hosts("norse") expected_hosts=['thor', 'odin', 'loki'] - assert hosts == expected_hosts + self.compare(hosts, expected_hosts) def test_simple_ungrouped(self): - raise SkipTest inventory = self.yaml_inventory() hosts = inventory.list_hosts("ungrouped") - expected_hosts=['jupiter'] - assert hosts == expected_hosts + expected_hosts=['jupiter','zeus'] + self.compare(hosts, expected_hosts) def test_yaml_combined(self): - raise SkipTest inventory = self.yaml_inventory() hosts = inventory.list_hosts("norse:greek") expected_hosts=['zeus', 'hera', 'poseidon', 'thor', 'odin', 'loki'] - assert hosts == expected_hosts + self.compare(hosts, expected_hosts) def test_yaml_restrict(self): - raise SkipTest inventory = self.yaml_inventory() restricted_hosts = ['hera', 'poseidon', 'thor'] @@ -216,56 +218,53 @@ class TestInventory(unittest.TestCase): inventory.restrict_to(restricted_hosts) hosts = inventory.list_hosts("norse:greek") - assert hosts == restricted_hosts + self.compare(hosts, restricted_hosts) inventory.lift_restriction() hosts = inventory.list_hosts("norse:greek") - assert hosts == expected_hosts + self.compare(hosts, expected_hosts) def test_yaml_vars(self): - raise SkipTest inventory = self.yaml_inventory() vars = inventory.get_variables('thor') - print vars assert vars == {'group_names': ['norse'], 'hammer':True, 'inventory_hostname': 'thor'} def test_yaml_change_vars(self): - raise SkipTest inventory = self.yaml_inventory() vars = inventory.get_variables('thor') vars["hammer"] = False vars = inventory.get_variables('thor') + print vars assert vars == {'hammer':True, 'inventory_hostname': 'thor', 'group_names': ['norse']} def test_yaml_host_vars(self): - raise SkipTest inventory = self.yaml_inventory() vars = inventory.get_variables('saturn') + print vars assert vars == {'inventory_hostname': 'saturn', 'moon': 'titan', 'moon2': 'enceladus', 'group_names': ['multiple']} def test_yaml_port(self): - raise SkipTest inventory = self.yaml_inventory() vars = inventory.get_variables('hera') + print vars assert vars == {'ansible_ssh_port': 3000, 'inventory_hostname': 'hera', 'ntp_server': 'olympus.example.com', 'group_names': ['greek']} def test_yaml_multiple_groups(self): - raise SkipTest inventory = self.yaml_inventory() vars = inventory.get_variables('odin')