diff --git a/changelogs/fragments/galaxy-reauth-error-handling.yml b/changelogs/fragments/galaxy-reauth-error-handling.yml new file mode 100644 index 00000000000..35c169b8e0b --- /dev/null +++ b/changelogs/fragments/galaxy-reauth-error-handling.yml @@ -0,0 +1,2 @@ +minor_changes: +- ansible-galaxy - Handle authentication errors and token expiration diff --git a/lib/ansible/galaxy/token.py b/lib/ansible/galaxy/token.py index 183e2af109e..573d1b3a56c 100644 --- a/lib/ansible/galaxy/token.py +++ b/lib/ansible/galaxy/token.py @@ -21,11 +21,14 @@ from __future__ import annotations import base64 -import os import json +import os +import time from stat import S_IRUSR, S_IWUSR +from urllib.error import HTTPError from ansible import constants as C +from ansible.galaxy.api import GalaxyError from ansible.galaxy.user_agent import user_agent from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text from ansible.module_utils.common.yaml import yaml_dump, yaml_load @@ -57,12 +60,16 @@ class KeycloakToken(object): self.client_id = client_id if self.client_id is None: self.client_id = 'cloud-services' + self._expiration = None def _form_payload(self): return 'grant_type=refresh_token&client_id=%s&refresh_token=%s' % (self.client_id, self.access_token) def get(self): + if self._expiration and time.time() >= self._expiration: + self._token = None + if self._token: return self._token @@ -76,15 +83,20 @@ class KeycloakToken(object): # or 'azp' (Authorized party - the party to which the ID Token was issued) payload = self._form_payload() - resp = open_url(to_native(self.auth_url), - data=payload, - validate_certs=self.validate_certs, - method='POST', - http_agent=user_agent()) + try: + resp = open_url(to_native(self.auth_url), + data=payload, + validate_certs=self.validate_certs, + method='POST', + http_agent=user_agent()) + except HTTPError as e: + raise GalaxyError(e, 'Unable to get access token') - # TODO: handle auth errors + data = json.load(resp) - data = json.loads(to_text(resp.read(), errors='surrogate_or_strict')) + # So that we have a buffer, expire the token in ~2/3 the given value + expires_in = data['expires_in'] // 3 * 2 + self._expiration = time.time() + expires_in # - extract 'access_token' self._token = data.get('access_token')