"""Access Ansible Core CI remote services.""" from __future__ import (absolute_import, division, print_function) __metaclass__ = type import json import os import re import traceback import uuid import errno import time from . import types as t from .http import ( HttpClient, HttpResponse, HttpError, ) from .io import ( make_dirs, read_text_file, write_json_file, write_text_file, ) from .util import ( ApplicationError, display, ANSIBLE_TEST_DATA_ROOT, ) from .util_common import ( run_command, ResultType, ) from .config import ( EnvironmentConfig, ) from .ci import ( AuthContext, get_ci_provider, ) from .data import ( data_context, ) class AnsibleCoreCI: """Client for Ansible Core CI services.""" DEFAULT_ENDPOINT = 'https://ansible-core-ci.testing.ansible.com' # Assign a default provider for each VM platform supported. # This is used to determine the provider from the platform when no provider is specified. # The keys here also serve as the list of providers which users can select from the command line. # # Entries can take one of two formats: # {platform} # {platform} arch={arch} # # Entries with an arch are only used as a default if the value for --remote-arch matches the {arch} specified. # This allows arch specific defaults to be distinct from the default used when no arch is specified. PROVIDERS = dict( aws=( 'freebsd', 'ios', 'rhel', 'vyos', 'windows', ), azure=( ), ibmps=( 'aix', ), parallels=( 'macos', 'osx', ), ) # Currently ansible-core-ci has no concept of arch selection. This effectively means each provider only supports one arch. # The list below identifies which platforms accept an arch, and which one. These platforms can only be used with the specified arch. PROVIDER_ARCHES = dict( ) def __init__(self, args, platform, version, stage='prod', persist=True, load=True, provider=None, arch=None, internal=False): """ :type args: EnvironmentConfig :type platform: str :type version: str :type stage: str :type persist: bool :type load: bool :type provider: str | None :type arch: str | None :type internal: bool """ self.args = args self.arch = arch self.platform = platform self.version = version self.stage = stage self.client = HttpClient(args) self.connection = None self.instance_id = None self.endpoint = None self.default_endpoint = args.remote_endpoint or self.DEFAULT_ENDPOINT self.retries = 3 self.ci_provider = get_ci_provider() self.auth_context = AuthContext() if self.arch: self.name = '%s-%s-%s' % (self.arch, self.platform, self.version) else: self.name = '%s-%s' % (self.platform, self.version) if provider: # override default provider selection (not all combinations are valid) self.provider = provider else: self.provider = None for candidate, platforms in self.PROVIDERS.items(): choices = [ platform, '%s arch=%s' % (platform, arch), ] if any(choice in platforms for choice in choices): # assign default provider based on platform self.provider = candidate break # If a provider has been selected, make sure the correct arch (or none) has been selected. if self.provider: required_arch = self.PROVIDER_ARCHES.get(self.provider) if self.arch != required_arch: if required_arch: if self.arch: raise ApplicationError('Provider "%s" requires the "%s" arch instead of "%s".' % (self.provider, required_arch, self.arch)) raise ApplicationError('Provider "%s" requires the "%s" arch.' % (self.provider, required_arch)) raise ApplicationError('Provider "%s" does not support specification of an arch.' % self.provider) self.path = os.path.expanduser('~/.ansible/test/instances/%s-%s-%s' % (self.name, self.provider, self.stage)) if self.provider not in self.PROVIDERS and not internal: if self.arch: raise ApplicationError('Provider not detected for platform "%s" on arch "%s".' % (self.platform, self.arch)) raise ApplicationError('Provider not detected for platform "%s" with no arch specified.' % self.platform) self.ssh_key = SshKey(args) if persist and load and self._load(): try: display.info('Checking existing %s/%s instance %s.' % (self.platform, self.version, self.instance_id), verbosity=1) self.connection = self.get(always_raise_on=[404]) display.info('Loaded existing %s/%s from: %s' % (self.platform, self.version, self._uri), verbosity=1) except HttpError as ex: if ex.status != 404: raise self._clear() display.info('Cleared stale %s/%s instance %s.' % (self.platform, self.version, self.instance_id), verbosity=1) self.instance_id = None self.endpoint = None elif not persist: self.instance_id = None self.endpoint = None self._clear() if self.instance_id: self.started = True else: self.started = False self.instance_id = str(uuid.uuid4()) self.endpoint = None display.sensitive.add(self.instance_id) if not self.endpoint: self.endpoint = self.default_endpoint @property def available(self): """Return True if Ansible Core CI is supported.""" return self.ci_provider.supports_core_ci_auth(self.auth_context) def start(self): """Start instance.""" if self.started: display.info('Skipping started %s/%s instance %s.' % (self.platform, self.version, self.instance_id), verbosity=1) return None return self._start(self.ci_provider.prepare_core_ci_auth(self.auth_context)) def stop(self): """Stop instance.""" if not self.started: display.info('Skipping invalid %s/%s instance %s.' % (self.platform, self.version, self.instance_id), verbosity=1) return response = self.client.delete(self._uri) if response.status_code == 404: self._clear() display.info('Cleared invalid %s/%s instance %s.' % (self.platform, self.version, self.instance_id), verbosity=1) return if response.status_code == 200: self._clear() display.info('Stopped running %s/%s instance %s.' % (self.platform, self.version, self.instance_id), verbosity=1) return raise self._create_http_error(response) def get(self, tries=3, sleep=15, always_raise_on=None): """ Get instance connection information. :type tries: int :type sleep: int :type always_raise_on: list[int] | None :rtype: InstanceConnection """ if not self.started: display.info('Skipping invalid %s/%s instance %s.' % (self.platform, self.version, self.instance_id), verbosity=1) return None if not always_raise_on: always_raise_on = [] if self.connection and self.connection.running: return self.connection while True: tries -= 1 response = self.client.get(self._uri) if response.status_code == 200: break error = self._create_http_error(response) if not tries or response.status_code in always_raise_on: raise error display.warning('%s. Trying again after %d seconds.' % (error, sleep)) time.sleep(sleep) if self.args.explain: self.connection = InstanceConnection( running=True, hostname='cloud.example.com', port=12345, username='username', password='password' if self.platform == 'windows' else None, ) else: response_json = response.json() status = response_json['status'] con = response_json.get('connection') if con: self.connection = InstanceConnection( running=status == 'running', hostname=con['hostname'], port=int(con['port']), username=con['username'], password=con.get('password'), response_json=response_json, ) else: self.connection = InstanceConnection( running=status == 'running', response_json=response_json, ) if self.connection.password: display.sensitive.add(str(self.connection.password)) status = 'running' if self.connection.running else 'starting' display.info('Status update: %s/%s on instance %s is %s.' % (self.platform, self.version, self.instance_id, status), verbosity=1) return self.connection def wait(self, iterations=90): # type: (t.Optional[int]) -> None """Wait for the instance to become ready.""" for _iteration in range(1, iterations): if self.get().running: return time.sleep(10) raise ApplicationError('Timeout waiting for %s/%s instance %s.' % (self.platform, self.version, self.instance_id)) @property def _uri(self): return '%s/%s/%s/%s' % (self.endpoint, self.stage, self.provider, self.instance_id) def _start(self, auth): """Start instance.""" display.info('Initializing new %s/%s instance %s.' % (self.platform, self.version, self.instance_id), verbosity=1) if self.platform == 'windows': winrm_config = read_text_file(os.path.join(ANSIBLE_TEST_DATA_ROOT, 'setup', 'ConfigureRemotingForAnsible.ps1')) else: winrm_config = None data = dict( config=dict( platform=self.platform, version=self.version, public_key=self.ssh_key.pub_contents, query=False, winrm_config=winrm_config, ) ) data.update(dict(auth=auth)) headers = { 'Content-Type': 'application/json', } response = self._start_endpoint(data, headers) self.started = True self._save() display.info('Started %s/%s from: %s' % (self.platform, self.version, self._uri), verbosity=1) if self.args.explain: return {} return response.json() def _start_endpoint(self, data, headers): """ :type data: dict[str, any] :type headers: dict[str, str] :rtype: HttpResponse """ tries = self.retries sleep = 15 display.info('Trying endpoint: %s' % self.endpoint, verbosity=1) while True: tries -= 1 response = self.client.put(self._uri, data=json.dumps(data), headers=headers) if response.status_code == 200: return response error = self._create_http_error(response) if response.status_code == 503: raise error if not tries: raise error display.warning('%s. Trying again after %d seconds.' % (error, sleep)) time.sleep(sleep) def _clear(self): """Clear instance information.""" try: self.connection = None os.remove(self.path) except OSError as ex: if ex.errno != errno.ENOENT: raise def _load(self): """Load instance information.""" try: data = read_text_file(self.path) except IOError as ex: if ex.errno != errno.ENOENT: raise return False if not data.startswith('{'): return False # legacy format config = json.loads(data) return self.load(config) def load(self, config): """ :type config: dict[str, str] :rtype: bool """ self.instance_id = str(config['instance_id']) self.endpoint = config['endpoint'] self.started = True display.sensitive.add(self.instance_id) return True def _save(self): """Save instance information.""" if self.args.explain: return config = self.save() write_json_file(self.path, config, create_directories=True) def save(self): """ :rtype: dict[str, str] """ return dict( platform_version='%s/%s' % (self.platform, self.version), instance_id=self.instance_id, endpoint=self.endpoint, ) @staticmethod def _create_http_error(response): """ :type response: HttpResponse :rtype: ApplicationError """ response_json = response.json() stack_trace = '' if 'message' in response_json: message = response_json['message'] elif 'errorMessage' in response_json: message = response_json['errorMessage'].strip() if 'stackTrace' in response_json: traceback_lines = response_json['stackTrace'] # AWS Lambda on Python 2.7 returns a list of tuples # AWS Lambda on Python 3.7 returns a list of strings if traceback_lines and isinstance(traceback_lines[0], list): traceback_lines = traceback.format_list(traceback_lines) trace = '\n'.join([x.rstrip() for x in traceback_lines]) stack_trace = ('\nTraceback (from remote server):\n%s' % trace) else: message = str(response_json) return CoreHttpError(response.status_code, message, stack_trace) class CoreHttpError(HttpError): """HTTP response as an error.""" def __init__(self, status, remote_message, remote_stack_trace): """ :type status: int :type remote_message: str :type remote_stack_trace: str """ super(CoreHttpError, self).__init__(status, '%s%s' % (remote_message, remote_stack_trace)) self.remote_message = remote_message self.remote_stack_trace = remote_stack_trace class SshKey: """Container for SSH key used to connect to remote instances.""" KEY_TYPE = 'rsa' # RSA is used to maintain compatibility with paramiko and EC2 KEY_NAME = 'id_%s' % KEY_TYPE PUB_NAME = '%s.pub' % KEY_NAME def __init__(self, args): """ :type args: EnvironmentConfig """ key_pair = self.get_key_pair() if not key_pair: key_pair = self.generate_key_pair(args) key, pub = key_pair key_dst, pub_dst = self.get_in_tree_key_pair_paths() def ssh_key_callback(files): # type: (t.List[t.Tuple[str, str]]) -> None """ Add the SSH keys to the payload file list. They are either outside the source tree or in the cache dir which is ignored by default. """ files.append((key, os.path.relpath(key_dst, data_context().content.root))) files.append((pub, os.path.relpath(pub_dst, data_context().content.root))) data_context().register_payload_callback(ssh_key_callback) self.key, self.pub = key, pub if args.explain: self.pub_contents = None self.key_contents = None else: self.pub_contents = read_text_file(self.pub).strip() self.key_contents = read_text_file(self.key).strip() def get_in_tree_key_pair_paths(self): # type: () -> t.Optional[t.Tuple[str, str]] """Return the ansible-test SSH key pair paths from the content tree.""" temp_dir = ResultType.TMP.path key = os.path.join(temp_dir, self.KEY_NAME) pub = os.path.join(temp_dir, self.PUB_NAME) return key, pub def get_source_key_pair_paths(self): # type: () -> t.Optional[t.Tuple[str, str]] """Return the ansible-test SSH key pair paths for the current user.""" base_dir = os.path.expanduser('~/.ansible/test/') key = os.path.join(base_dir, self.KEY_NAME) pub = os.path.join(base_dir, self.PUB_NAME) return key, pub def get_key_pair(self): # type: () -> t.Optional[t.Tuple[str, str]] """Return the ansible-test SSH key pair paths if present, otherwise return None.""" key, pub = self.get_in_tree_key_pair_paths() if os.path.isfile(key) and os.path.isfile(pub): return key, pub key, pub = self.get_source_key_pair_paths() if os.path.isfile(key) and os.path.isfile(pub): return key, pub return None def generate_key_pair(self, args): # type: (EnvironmentConfig) -> t.Tuple[str, str] """Generate an SSH key pair for use by all ansible-test invocations for the current user.""" key, pub = self.get_source_key_pair_paths() if not args.explain: make_dirs(os.path.dirname(key)) if not os.path.isfile(key) or not os.path.isfile(pub): run_command(args, ['ssh-keygen', '-m', 'PEM', '-q', '-t', self.KEY_TYPE, '-N', '', '-f', key]) if args.explain: return key, pub # newer ssh-keygen PEM output (such as on RHEL 8.1) is not recognized by paramiko key_contents = read_text_file(key) key_contents = re.sub(r'(BEGIN|END) PRIVATE KEY', r'\1 RSA PRIVATE KEY', key_contents) write_text_file(key, key_contents) return key, pub class InstanceConnection: """Container for remote instance status and connection details.""" def __init__(self, running, # type: bool hostname=None, # type: t.Optional[str] port=None, # type: t.Optional[int] username=None, # type: t.Optional[str] password=None, # type: t.Optional[str] response_json=None, # type: t.Optional[t.Dict[str, t.Any]] ): # type: (...) -> None self.running = running self.hostname = hostname self.port = port self.username = username self.password = password self.response_json = response_json or {} def __str__(self): if self.password: return '%s:%s [%s:%s]' % (self.hostname, self.port, self.username, self.password) return '%s:%s [%s]' % (self.hostname, self.port, self.username)