You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ansible/test/lib/ansible_test/_internal/core_ci.py

556 lines
17 KiB
Python

"""Access Ansible Core CI remote services."""
from __future__ import annotations
import abc
import dataclasses
import json
import os
import re
import stat
import traceback
import uuid
import time
import typing as t
from .http import (
HttpClient,
HttpResponse,
HttpError,
)
from .io import (
make_dirs,
read_text_file,
write_json_file,
write_text_file,
)
from .util import (
ApplicationError,
display,
mutex,
)
from .util_common import (
run_command,
ResultType,
)
from .config import (
EnvironmentConfig,
)
from .ci import (
get_ci_provider,
)
from .data import (
data_context,
PayloadConfig,
)
@dataclasses.dataclass(frozen=True)
class Resource(metaclass=abc.ABCMeta):
    """Abstract description of something that can be requested from Ansible Core CI."""

    @abc.abstractmethod
    def as_tuple(self) -> tuple[str, str, str, str]:
        """Describe the resource as a (platform, version, architecture, provider) tuple."""
        ...

    @abc.abstractmethod
    def get_label(self) -> str:
        """Describe the resource with a short, human-readable label."""
        ...

    @property
    @abc.abstractmethod
    def persist(self) -> bool:
        """Report whether the resource is persistent (True) or not (False)."""
        ...
@dataclasses.dataclass(frozen=True)
class VmResource(Resource):
    """Parameters used to ask Ansible Core CI for a virtual machine."""

    platform: str
    version: str
    architecture: str
    provider: str
    tag: str

    def as_tuple(self) -> tuple[str, str, str, str]:
        """Describe the VM as a (platform, version, architecture, provider) tuple."""
        return (self.platform, self.version, self.architecture, self.provider)

    def get_label(self) -> str:
        """Build a label of the form ``platform version (architecture) [tag] @provider``."""
        details = f'{self.platform} {self.version} ({self.architecture})'
        return f'{details} [{self.tag}] @{self.provider}'

    @property
    def persist(self) -> bool:
        """VM resources are always persistent."""
        return True
@dataclasses.dataclass(frozen=True)
class CloudResource(Resource):
    """Parameters used to ask Ansible Core CI for cloud credentials."""

    platform: str

    def as_tuple(self) -> tuple[str, str, str, str]:
        """Describe the resource as a tuple; version and architecture are blank and the platform doubles as the provider."""
        blank = ''
        return (self.platform, blank, blank, self.platform)

    def get_label(self) -> str:
        """Use the platform name as the label."""
        return self.platform

    @property
    def persist(self) -> bool:
        """Cloud credential resources are never persistent."""
        return False
class AnsibleCoreCI:
    """Client for Ansible Core CI services."""

    DEFAULT_ENDPOINT = 'https://ansible-core-ci.testing.ansible.com'

    def __init__(
        self,
        args: EnvironmentConfig,
        resource: Resource,
        load: bool = True,
    ) -> None:
        """
        Create a client for the given resource.

        For a persistent resource (when ``load`` is true), previously saved instance details are loaded from disk and
        checked against the service; details the service reports as missing (404) are discarded.
        For a non-persistent resource any saved details are always cleared.
        """
        self.args = args
        self.resource = resource
        self.platform, self.version, self.arch, self.provider = self.resource.as_tuple()
        self.stage = args.remote_stage
        self.client = HttpClient(args)
        self.connection = None
        self.instance_id = None
        self.endpoint = None
        self.default_endpoint = args.remote_endpoint or self.DEFAULT_ENDPOINT
        self.retries = 3
        self.ci_provider = get_ci_provider()
        self.label = self.resource.get_label()

        # turn the label into something suitable for use as a filename
        stripped_label = re.sub('[^A-Za-z0-9_.]+', '-', self.label).strip('-')
        self.name = f"{stripped_label}-{self.stage}"

        self.path = os.path.expanduser(f'~/.ansible/test/instances/{self.name}')
        self.ssh_key = SshKey(args)

        if self.resource.persist and load and self._load():
            try:
                display.info(f'Checking existing {self.label} instance using: {self._uri}', verbosity=1)

                self.connection = self.get(always_raise_on=[404])

                display.info(f'Loaded existing {self.label} instance.', verbosity=1)
            except HttpError as ex:
                if ex.status != 404:
                    raise

                self._clear()

                display.info(f'Cleared stale {self.label} instance.', verbosity=1)

                self.instance_id = None
                self.endpoint = None
        elif not self.resource.persist:
            self.instance_id = None
            self.endpoint = None
            self._clear()

        if self.instance_id:
            self.started: bool = True
        else:
            self.started = False
            self.instance_id = str(uuid.uuid4())
            self.endpoint = None

            display.sensitive.add(self.instance_id)

        if not self.endpoint:
            self.endpoint = self.default_endpoint

    @property
    def available(self) -> bool:
        """Return True if Ansible Core CI is supported."""
        return self.ci_provider.supports_core_ci_auth()

    def start(self) -> t.Optional[dict[str, t.Any]]:
        """Start instance, returning the service response, or None if the instance was already started."""
        if self.started:
            display.info(f'Skipping started {self.label} instance.', verbosity=1)
            return None

        return self._start(self.ci_provider.prepare_core_ci_auth())

    def stop(self) -> None:
        """
        Stop instance.

        Saved instance details are cleared when the service reports success (200) or that the instance is missing (404);
        any other status is raised as an error.
        """
        if not self.started:
            display.info(f'Skipping invalid {self.label} instance.', verbosity=1)
            return

        response = self.client.delete(self._uri)

        if response.status_code == 404:
            self._clear()
            display.info(f'Cleared invalid {self.label} instance.', verbosity=1)
            return

        if response.status_code == 200:
            self._clear()
            display.info(f'Stopped running {self.label} instance.', verbosity=1)
            return

        raise self._create_http_error(response)

    def get(self, tries: int = 3, sleep: int = 15, always_raise_on: t.Optional[list[int]] = None) -> t.Optional[InstanceConnection]:
        """
        Get instance connection information.

        Up to ``tries`` requests are made, waiting ``sleep`` seconds after each failed one.
        Failures whose status code is in ``always_raise_on`` are raised immediately without retrying.
        Returns None (without contacting the service) when the instance has not been started,
        and the cached connection when it is already known to be running.
        """
        if not self.started:
            display.info(f'Skipping invalid {self.label} instance.', verbosity=1)
            return None

        if not always_raise_on:
            always_raise_on = []

        if self.connection and self.connection.running:
            return self.connection

        while True:
            tries -= 1
            response = self.client.get(self._uri)

            if response.status_code == 200:
                break

            error = self._create_http_error(response)

            # ``<= 0`` rather than ``not tries``: a caller passing tries=0 (or less) gets a single attempt instead of an endless retry loop
            if tries <= 0 or response.status_code in always_raise_on:
                raise error

            display.warning(f'{error}. Trying again after {sleep} seconds.')
            time.sleep(sleep)

        if self.args.explain:
            # placeholder connection details, since explain mode does not use the real service response
            self.connection = InstanceConnection(
                running=True,
                hostname='cloud.example.com',
                port=12345,
                username='root',
                password='password' if self.platform == 'windows' else None,
            )
        else:
            response_json = response.json()
            status = response_json['status']
            con = response_json.get('connection')

            if con:
                self.connection = InstanceConnection(
                    running=status == 'running',
                    hostname=con['hostname'],
                    port=int(con['port']),
                    username=con['username'],
                    password=con.get('password'),
                    response_json=response_json,
                )
            else:
                self.connection = InstanceConnection(
                    running=status == 'running',
                    response_json=response_json,
                )

        if self.connection.password:
            # register the password as sensitive so display output masks it
            display.sensitive.add(str(self.connection.password))

        status = 'running' if self.connection.running else 'starting'

        display.info(f'The {self.label} instance is {status}.', verbosity=1)

        return self.connection

    def wait(self, iterations: t.Optional[int] = 90) -> None:
        """
        Wait for the instance to become ready.

        The instance status is polled up to ``iterations`` times, sleeping 10 seconds between polls.

        :raises ApplicationError: if the instance has not been started, or does not become ready in time
        """
        for iteration in range(iterations):
            if iteration:
                time.sleep(10)

            connection = self.get()

            # get() returns None for an instance that was never started; waiting on it can never succeed
            if connection is None:
                raise ApplicationError(f'Cannot wait for {self.label} instance which has not been started.')

            if connection.running:
                return

        raise ApplicationError(f'Timeout waiting for {self.label} instance.')

    @property
    def _uri(self) -> str:
        """Return the service URI for this instance, built from the endpoint, stage, provider and instance ID."""
        return f'{self.endpoint}/{self.stage}/{self.provider}/{self.instance_id}'

    def _start(self, auth) -> dict[str, t.Any]:
        """Start instance, save its details and return the service response (empty in explain mode)."""
        display.info(f'Initializing new {self.label} instance using: {self._uri}', verbosity=1)

        data = dict(
            config=dict(
                platform=self.platform,
                version=self.version,
                architecture=self.arch,
                public_key=self.ssh_key.pub_contents,
            )
        )

        data.update(auth=auth)

        headers = {
            'Content-Type': 'application/json',
        }

        response = self._start_endpoint(data, headers)

        self.started = True
        self._save()

        display.info(f'Started {self.label} instance.', verbosity=1)

        if self.args.explain:
            return {}

        return response.json()

    def _start_endpoint(self, data: dict[str, t.Any], headers: dict[str, str]) -> HttpResponse:
        """Send the start request, making up to ``self.retries`` attempts; a 503 response is raised without retrying."""
        tries = self.retries
        sleep = 15

        while True:
            tries -= 1
            response = self.client.put(self._uri, data=json.dumps(data), headers=headers)

            if response.status_code == 200:
                return response

            error = self._create_http_error(response)

            # ``<= 0`` guards against a non-positive retry count looping forever
            if response.status_code == 503 or tries <= 0:
                raise error

            display.warning(f'{error}. Trying again after {sleep} seconds.')
            time.sleep(sleep)

    def _clear(self) -> None:
        """Forget the current connection and delete the saved instance file, if any."""
        try:
            self.connection = None
            os.remove(self.path)
        except FileNotFoundError:
            pass

    def _load(self) -> bool:
        """Load saved instance details, returning False if there are none or they use the legacy (non-JSON) format."""
        try:
            data = read_text_file(self.path)
        except FileNotFoundError:
            return False

        if not data.startswith('{'):
            return False  # legacy format

        config = json.loads(data)

        return self.load(config)

    def load(self, config: dict[str, str]) -> bool:
        """Load the instance from the provided dictionary and mark it as started."""
        self.instance_id = str(config['instance_id'])
        self.endpoint = config['endpoint']
        self.started = True

        display.sensitive.add(self.instance_id)

        return True

    def _save(self) -> None:
        """Write instance details to ``self.path``; nothing is written in explain mode."""
        if self.args.explain:
            return

        config = self.save()

        write_json_file(self.path, config, create_directories=True)

    def save(self) -> dict[str, str]:
        """Save instance details and return as a dictionary."""
        return dict(
            label=self.resource.get_label(),
            instance_id=self.instance_id,
            endpoint=self.endpoint,
        )

    @staticmethod
    def _create_http_error(response: HttpResponse) -> ApplicationError:
        """
        Return an exception created from the given HTTP response.

        The message comes from the JSON body's ``message`` or ``errorMessage`` key (the latter optionally with a remote
        ``stackTrace``); any other body, including one that is not a JSON object, is used as-is via ``str()``.
        """
        response_json = response.json()
        stack_trace = ''

        # a non-object JSON body (e.g. a string) would turn ``in`` into a substring test and make the key lookup fail
        if not isinstance(response_json, dict):
            message = str(response_json)
        elif 'message' in response_json:
            message = response_json['message']
        elif 'errorMessage' in response_json:
            message = response_json['errorMessage'].strip()

            if 'stackTrace' in response_json:
                traceback_lines = response_json['stackTrace']

                # AWS Lambda on Python 2.7 returns a list of tuples
                # AWS Lambda on Python 3.7 returns a list of strings
                if traceback_lines and isinstance(traceback_lines[0], list):
                    traceback_lines = traceback.format_list(traceback_lines)

                trace = '\n'.join([x.rstrip() for x in traceback_lines])
                stack_trace = f'\nTraceback (from remote server):\n{trace}'
        else:
            message = str(response_json)

        return CoreHttpError(response.status_code, message, stack_trace)
class CoreHttpError(HttpError):
    """HttpError for a failed Ansible Core CI request that also keeps the remote message and stack trace as separate attributes."""

    def __init__(self, status: int, remote_message: str, remote_stack_trace: str) -> None:
        full_message = f'{remote_message}{remote_stack_trace}'
        super().__init__(status, full_message)

        self.remote_message = remote_message
        self.remote_stack_trace = remote_stack_trace
class SshKey:
    """Container for SSH key used to connect to remote instances."""

    KEY_TYPE = 'rsa'  # RSA is used to maintain compatibility with paramiko and EC2
    KEY_NAME = f'id_{KEY_TYPE}'
    PUB_NAME = f'{KEY_NAME}.pub'

    @mutex
    def __init__(self, args: EnvironmentConfig) -> None:
        """Locate (or generate) the SSH key pair and register a callback that adds it to the test payload."""
        key_pair = self.get_key_pair()

        if not key_pair:
            key_pair = self.generate_key_pair(args)

        key, pub = key_pair
        key_dst, pub_dst = self.get_in_tree_key_pair_paths()

        def ssh_key_callback(payload_config: PayloadConfig) -> None:
            """
            Add the SSH keys to the payload file list.
            They are either outside the source tree or in the cache dir which is ignored by default.
            """
            content_root = data_context().content.root
            key_rel = os.path.relpath(key_dst, content_root)
            pub_rel = os.path.relpath(pub_dst, content_root)

            files = payload_config.files
            permissions = payload_config.permissions

            files.append((key, key_rel))
            files.append((pub, pub_rel))

            # the private key is given owner read/write permissions only
            permissions[key_rel] = stat.S_IRUSR | stat.S_IWUSR

        data_context().register_payload_callback(ssh_key_callback)

        self.key, self.pub = key, pub

        if args.explain:
            self.pub_contents = None
            self.key_contents = None
        else:
            self.pub_contents = read_text_file(self.pub).strip()
            self.key_contents = read_text_file(self.key).strip()

    @staticmethod
    def get_relative_in_tree_private_key_path() -> str:
        """Return the ansible-test SSH private key path relative to the content tree."""
        temp_dir = ResultType.TMP.relative_path
        key = os.path.join(temp_dir, SshKey.KEY_NAME)
        return key

    def get_in_tree_key_pair_paths(self) -> tuple[str, str]:
        """Return the (private, public) ansible-test SSH key pair paths from the content tree."""
        temp_dir = ResultType.TMP.path

        key = os.path.join(temp_dir, self.KEY_NAME)
        pub = os.path.join(temp_dir, self.PUB_NAME)

        return key, pub

    def get_source_key_pair_paths(self) -> tuple[str, str]:
        """Return the (private, public) ansible-test SSH key pair paths for the current user."""
        base_dir = os.path.expanduser('~/.ansible/test/')

        key = os.path.join(base_dir, self.KEY_NAME)
        pub = os.path.join(base_dir, self.PUB_NAME)

        return key, pub

    def get_key_pair(self) -> t.Optional[tuple[str, str]]:
        """Return the ansible-test SSH key pair paths if present (in-tree first, then per-user), otherwise return None."""
        key, pub = self.get_in_tree_key_pair_paths()

        if os.path.isfile(key) and os.path.isfile(pub):
            return key, pub

        key, pub = self.get_source_key_pair_paths()

        if os.path.isfile(key) and os.path.isfile(pub):
            return key, pub

        return None

    def generate_key_pair(self, args: EnvironmentConfig) -> tuple[str, str]:
        """Generate an SSH key pair for use by all ansible-test invocations for the current user."""
        key, pub = self.get_source_key_pair_paths()

        if not args.explain:
            make_dirs(os.path.dirname(key))

        if not os.path.isfile(key) or not os.path.isfile(pub):
            run_command(args, ['ssh-keygen', '-m', 'PEM', '-q', '-t', self.KEY_TYPE, '-N', '', '-f', key], capture=True)

            if args.explain:
                return key, pub

            # newer ssh-keygen PEM output (such as on RHEL 8.1) is not recognized by paramiko
            key_contents = read_text_file(key)
            key_contents = re.sub(r'(BEGIN|END) PRIVATE KEY', r'\1 RSA PRIVATE KEY', key_contents)

            write_text_file(key, key_contents)

        return key, pub
class InstanceConnection:
    """Status of a remote instance together with the details needed to connect to it."""

    def __init__(
        self,
        running: bool,
        hostname: t.Optional[str] = None,
        port: t.Optional[int] = None,
        username: t.Optional[str] = None,
        password: t.Optional[str] = None,
        response_json: t.Optional[dict[str, t.Any]] = None,
    ) -> None:
        self.running = running
        self.hostname = hostname
        self.port = port
        self.username = username
        self.password = password
        # fall back to an empty dict when no (or an empty/falsy) response was supplied
        self.response_json = response_json if response_json else {}

    def __str__(self):
        # the password is only included in the bracketed credentials when one is set
        credentials = f'{self.username}:{self.password}' if self.password else f'{self.username}'
        return f'{self.hostname}:{self.port} [{credentials}]'