diff --git a/lib/ansible/plugins/inventory/scaleway.py b/lib/ansible/plugins/inventory/scaleway.py index c2be621b2e7..37cb3940284 100644 --- a/lib/ansible/plugins/inventory/scaleway.py +++ b/lib/ansible/plugins/inventory/scaleway.py @@ -35,6 +35,17 @@ DOCUMENTATION = ''' - name: SCW_TOKEN - name: SCW_API_KEY - name: SCW_OAUTH_TOKEN + hostnames: + description: List of preference about what to use as an hostname. + type: list + default: + - public_ipv4 + choices: + - public_ipv4 + - private_ipv4 + - public_ipv6 + - hostname + - id ''' EXAMPLES = ''' @@ -47,10 +58,11 @@ regions: - par1 tags: - foobar +hostnames: + - public_ipv4 ''' import json -import os from ansible.errors import AnsibleError from ansible.plugins.inventory import BaseInventoryPlugin @@ -82,26 +94,68 @@ def _build_server_url(api_endpoint): return "/".join([api_endpoint, "servers"]) +def extract_public_ipv4(server_info): + try: + return server_info["public_ip"]["address"] + except (KeyError, TypeError): + return None + + +def extract_private_ipv4(server_info): + return server_info["private_ip"] + + +def extract_hostname(server_info): + return server_info["hostname"] + + +def extract_server_id(server_info): + return server_info["id"] + + +def extract_public_ipv6(server_info): + try: + return server_info["ipv6"]["address"] + except (KeyError, TypeError): + return None + + +extractors = { + "public_ipv4": extract_public_ipv4, + "private_ipv4": extract_private_ipv4, + "public_ipv6": extract_public_ipv6, + "hostname": extract_hostname, + "id": extract_server_id +} + + class InventoryModule(BaseInventoryPlugin): NAME = 'scaleway' def verify_file(self, path): return "scaleway" in path - def _fill_host_variables(self, server_id, server_info): + def _fill_host_variables(self, host, server_info): targeted_attributes = ( "arch", "commercial_type", + "id", "organization", "state", "hostname", "state" ) for attribute in targeted_attributes: - self.inventory.set_variable(server_id, attribute, server_info[attribute]) + self.inventory.set_variable(host, attribute, server_info[attribute]) + + self.inventory.set_variable(host, "tags", server_info["tags"]) + + if extract_public_ipv6(server_info=server_info): + self.inventory.set_variable(host, "public_ipv6", extract_public_ipv6(server_info=server_info)) - self.inventory.set_variable(server_id, "tags", server_info["tags"]) - self.inventory.set_variable(server_id, "ipv4", server_info["public_ip"]["address"]) + if extract_public_ipv4(server_info=server_info): + self.inventory.set_variable(host, "public_ipv4", extract_public_ipv4(server_info=server_info)) + self.inventory.set_variable(host, "ansible_host", extract_public_ipv4(server_info=server_info)) def _get_zones(self, config_zones): return set(SCALEWAY_LOCATION.keys()).intersection(config_zones) @@ -121,22 +175,36 @@ class InventoryModule(BaseInventoryPlugin): else: return matching_tags.union((server_zone,)) - def do_zone_inventory(self, zone, token, tags): + def _filter_host(self, host_infos, hostname_preferences): + + for pref in hostname_preferences: + if extractors[pref](host_infos): + return extractors[pref](host_infos) + + return None + + def do_zone_inventory(self, zone, token, tags, hostname_preferences): self.inventory.add_group(zone) zone_info = SCALEWAY_LOCATION[zone] url = _build_server_url(zone_info["api_endpoint"]) - all_servers = _fetch_information(url=url, token=token) + raw_zone_hosts_infos = _fetch_information(url=url, token=token) + + for host_infos in raw_zone_hosts_infos: + + hostname = self._filter_host(host_infos=host_infos, + hostname_preferences=hostname_preferences) - for server_info in all_servers: + # No suitable hostname were found in the attributes and the host won't be in the inventory + if not hostname: + continue - groups = self.match_groups(server_info, tags) - server_id = server_info["id"] + groups = self.match_groups(host_infos, tags) for group in groups: self.inventory.add_group(group=group) - self.inventory.add_host(group=group, host=server_id) - self._fill_host_variables(server_id=server_id, server_info=server_info) + self.inventory.add_host(group=group, host=hostname) + self._fill_host_variables(host=hostname, server_info=host_infos) def parse(self, inventory, loader, path, cache=True): super(InventoryModule, self).parse(inventory, loader, path) @@ -145,6 +213,7 @@ class InventoryModule(BaseInventoryPlugin): config_zones = self.get_option("regions") tags = self.get_option("tags") token = self.get_option("oauth_token") + hostname_preference = self.get_option("hostnames") for zone in self._get_zones(config_zones): - self.do_zone_inventory(zone=zone, token=token, tags=tags) + self.do_zone_inventory(zone=zone, token=token, tags=tags, hostname_preferences=hostname_preference)