Scaleway inventory plugin: small improvements (#41642)

* scaleway inventory: token is mandatory

* scaleway inventory: mention exception in error

* scaleway inventory: remove print statement

* scaleway inventory: options don't need to be attr

* scaleway inventory: remove unused attr

(cherry picked from commit 3e6c76fc2e)
pull/41880/head
Pilou 7 years ago committed by Matt Clay
parent 8ad8db42e3
commit d95cf17dd4

@ -24,6 +24,7 @@ DOCUMENTATION = '''
description: Filter results on a specific tag description: Filter results on a specific tag
type: list type: list
oauth_token: oauth_token:
required: True
description: Scaleway OAuth token. description: Scaleway OAuth token.
env: env:
# in order of precedence # in order of precedence
@ -51,6 +52,7 @@ from ansible.errors import AnsibleError
from ansible.plugins.inventory import BaseInventoryPlugin from ansible.plugins.inventory import BaseInventoryPlugin
from ansible.module_utils.scaleway import SCALEWAY_LOCATION from ansible.module_utils.scaleway import SCALEWAY_LOCATION
from ansible.module_utils.urls import open_url from ansible.module_utils.urls import open_url
from ansible.module_utils._text import to_native
def _fetch_information(token, url): def _fetch_information(token, url):
@ -58,8 +60,8 @@ def _fetch_information(token, url):
response = open_url(url, response = open_url(url,
headers={'X-Auth-Token': token, headers={'X-Auth-Token': token,
'Content-type': 'application/json'}) 'Content-type': 'application/json'})
except Exception: except Exception as e:
raise AnsibleError("Error while fetching %s" % url) raise AnsibleError("Error while fetching %s: %s" % (url, to_native(e)))
try: try:
raw_json = json.loads(response.read()) raw_json = json.loads(response.read())
@ -79,12 +81,6 @@ def _build_server_url(api_endpoint):
class InventoryModule(BaseInventoryPlugin): class InventoryModule(BaseInventoryPlugin):
NAME = 'scaleway' NAME = 'scaleway'
def __init__(self):
super(InventoryModule, self).__init__()
self.token = None
self.config_data = None
def verify_file(self, path): def verify_file(self, path):
return "scaleway" in path return "scaleway" in path
@ -103,39 +99,34 @@ class InventoryModule(BaseInventoryPlugin):
self.inventory.set_variable(server_id, "tags", server_info["tags"]) self.inventory.set_variable(server_id, "tags", server_info["tags"])
self.inventory.set_variable(server_id, "ipv4", server_info["public_ip"]["address"]) self.inventory.set_variable(server_id, "ipv4", server_info["public_ip"]["address"])
def _get_zones(self): def _get_zones(self, config_zones):
config_zones = self.get_option("regions")
return set(SCALEWAY_LOCATION.keys()).intersection(config_zones) return set(SCALEWAY_LOCATION.keys()).intersection(config_zones)
def _get_tags(self): def match_groups(self, server_info, tags):
return self.get_option("tags")
def match_groups(self, server_info):
server_zone = server_info["location"]["zone_id"] server_zone = server_info["location"]["zone_id"]
server_tags = server_info["tags"] server_tags = server_info["tags"]
# If no filtering is defined, all tags are valid groups # If no filtering is defined, all tags are valid groups
if self._get_tags() is None: if tags is None:
return set(server_tags).union((server_zone,)) return set(server_tags).union((server_zone,))
matching_tags = set(server_tags).intersection(self._get_tags()) matching_tags = set(server_tags).intersection(tags)
if not matching_tags: if not matching_tags:
return set() return set()
else: else:
return matching_tags.union((server_zone,)) return matching_tags.union((server_zone,))
def do_zone_inventory(self, zone): def do_zone_inventory(self, zone, token, tags):
self.inventory.add_group(zone) self.inventory.add_group(zone)
zone_info = SCALEWAY_LOCATION[zone] zone_info = SCALEWAY_LOCATION[zone]
url = _build_server_url(zone_info["api_endpoint"]) url = _build_server_url(zone_info["api_endpoint"])
all_servers = _fetch_information(url=url, token=self.token) all_servers = _fetch_information(url=url, token=token)
for server_info in all_servers: for server_info in all_servers:
groups = self.match_groups(server_info) groups = self.match_groups(server_info, tags)
print(groups)
server_id = server_info["id"] server_id = server_info["id"]
for group in groups: for group in groups:
@ -145,8 +136,11 @@ class InventoryModule(BaseInventoryPlugin):
def parse(self, inventory, loader, path, cache=True): def parse(self, inventory, loader, path, cache=True):
super(InventoryModule, self).parse(inventory, loader, path) super(InventoryModule, self).parse(inventory, loader, path)
self.config_data = self._read_config_data(path=path) self._read_config_data(path=path)
self.token = self.get_option("oauth_token")
config_zones = self.get_option("regions")
tags = self.get_option("tags")
token = self.get_option("oauth_token")
for zone in self._get_zones(): for zone in self._get_zones(config_zones):
self.do_zone_inventory(zone=zone) self.do_zone_inventory(zone=zone, token=token, tags=tags)

Loading…
Cancel
Save