You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ansible/test/lib/ansible_test/_internal/connections.py

259 lines
7.9 KiB
Python

"""Connection abstraction for interacting with test hosts."""
from __future__ import annotations
import abc
import shlex
import tempfile
import typing as t
from .io import (
read_text_file,
)
from .config import (
EnvironmentConfig,
)
from .util import (
Display,
OutputStream,
SubprocessError,
retry,
)
from .util_common import (
run_command,
)
from .docker_util import (
DockerInspect,
docker_exec,
docker_inspect,
docker_network_disconnect,
)
from .ssh import (
SshConnectionDetail,
ssh_options_to_list,
)
from .become import (
Become,
)
class Connection(metaclass=abc.ABCMeta):
    """Abstract interface for running commands on a test host."""

    @abc.abstractmethod
    def run(self,
            command: list[str],
            capture: bool,
            interactive: bool = False,
            data: t.Optional[str] = None,
            stdin: t.Optional[t.IO[bytes]] = None,
            stdout: t.Optional[t.IO[bytes]] = None,
            output_stream: t.Optional[OutputStream] = None,
            ) -> tuple[t.Optional[str], t.Optional[str]]:
        """Execute the given command on the host and return its result as a pair of optional strings."""

    def extract_archive(self,
                        chdir: str,
                        src: t.IO[bytes],
                        ):
        """Unpack the gzip-compressed tar stream `src` into the directory `chdir` on the host."""
        # The archive is read from stdin ('-'), so `src` is passed as the stdin of the tar process.
        unpack_cmd = ['tar', 'oxzf', '-', '-C', chdir]

        def unpack() -> tuple[t.Optional[str], t.Optional[str]]:
            """Run tar once, feeding it the source stream."""
            return self.run(unpack_cmd, stdin=src, capture=True)

        # NOTE(review): `src` is not rewound between attempts; presumably `retry` may call this more than once -- confirm.
        retry(unpack)

    def create_archive(self,
                       chdir: str,
                       name: str,
                       dst: t.IO[bytes],
                       exclude: t.Optional[str] = None,
                       ):
        """Write a gzip-compressed tar archive of `name`, relative to `chdir`, to the stream `dst`, skipping `exclude` when one is given."""
        exclusion = ['--exclude', exclude] if exclude else []
        archive_cmd = ['tar', 'cf', '-', '-C', chdir, *exclusion, name]

        # Compression is done by a separate gzip process rather than by tar itself,
        # which allows this to work on all POSIX systems we support.
        pipeline = ' | '.join(map(shlex.join, (archive_cmd, ['gzip'])))
        shell_cmd = ['sh', '-c', pipeline]

        def archive() -> tuple[t.Optional[str], t.Optional[str]]:
            """Run the tar/gzip pipeline once, sending its output to the destination stream."""
            return self.run(shell_cmd, stdout=dst, capture=True)

        # NOTE(review): `dst` is not reset between attempts; presumably `retry` may call this more than once -- confirm.
        retry(archive)
class LocalConnection(Connection):
    """Connection which runs commands on the local machine."""

    def __init__(self, args: EnvironmentConfig) -> None:
        self.args = args

    def run(self,
            command: list[str],
            capture: bool,
            interactive: bool = False,
            data: t.Optional[str] = None,
            stdin: t.Optional[t.IO[bytes]] = None,
            stdout: t.Optional[t.IO[bytes]] = None,
            output_stream: t.Optional[OutputStream] = None,
            ) -> tuple[t.Optional[str], t.Optional[str]]:
        """Execute the given command locally and return the result."""
        # No wrapping is needed for a local command; everything is handed to run_command unchanged.
        run_kwargs = dict(
            cmd=command,
            capture=capture,
            interactive=interactive,
            data=data,
            stdin=stdin,
            stdout=stdout,
            output_stream=output_stream,
        )

        return run_command(args=self.args, **run_kwargs)
class SshConnection(Connection):
    """Connection which runs commands on a host through the `ssh` command line client."""

    def __init__(self, args: EnvironmentConfig, settings: SshConnectionDetail, become: t.Optional[Become] = None) -> None:
        self.args = args
        self.settings = settings
        self.become = become

        # Built-in defaults come first so that entries in `settings.options` override them.
        ssh_options: dict[str, t.Union[int, str]] = {
            'BatchMode': 'yes',
            'StrictHostKeyChecking': 'no',
            'UserKnownHostsFile': '/dev/null',
            'ServerAliveInterval': 15,
            'ServerAliveCountMax': 4,
        }
        ssh_options.update(settings.options)

        self.options = ['-i', settings.identity_file, *ssh_options_to_list(ssh_options)]

    def run(self,
            command: list[str],
            capture: bool,
            interactive: bool = False,
            data: t.Optional[str] = None,
            stdin: t.Optional[t.IO[bytes]] = None,
            stdout: t.Optional[t.IO[bytes]] = None,
            output_stream: t.Optional[OutputStream] = None,
            ) -> tuple[t.Optional[str], t.Optional[str]]:
        """Execute the given command over SSH, applying `become` when configured, and return the result."""
        ssh_cmd = ['ssh', *self.options]

        if self.become:
            command = self.become.prepare_command(command)

        ssh_cmd.append('-q')

        if interactive:
            ssh_cmd.append('-tt')

        # The debug log only needs to exist for the duration of this one ssh invocation.
        with tempfile.NamedTemporaryFile(prefix='ansible-test-ssh-debug-', suffix='.log') as ssh_logfile:
            log_path = ssh_logfile.name
            ssh_cmd += ['-vvv', '-E', log_path]

            if self.settings.port:
                ssh_cmd += ['-p', str(self.settings.port)]

            # The destination comes next, followed by the command as a single shell-quoted string.
            ssh_cmd += [f'{self.settings.user}@{self.settings.host}', shlex.join(command)]

            def error_callback(ex: SubprocessError) -> None:
                """Attach relevant SSH debug log output to the error."""
                self.capture_log_details(log_path, ex)

            return run_command(
                args=self.args,
                cmd=ssh_cmd,
                capture=capture,
                interactive=interactive,
                data=data,
                stdin=stdin,
                stdout=stdout,
                output_stream=output_stream,
                error_callback=error_callback,
            )

    @staticmethod
    def capture_log_details(path: str, ex: SubprocessError) -> None:
        """Append the tail of the SSH debug log at `path` to the message of `ex`, but only when its status is 255."""
        # Any other status leaves the exception untouched; presumably 255 is the status ssh uses for its own failures -- confirm.
        if ex.status != 255:
            return

        markers = (
            'debug1: Connection Established',
            'debug1: Authentication successful',
            'debug1: Entering interactive session',
            'debug1: Sending command',
            'debug2: PTY allocation request accepted',
            'debug2: exec request accepted',
        )

        log_lines = read_text_file(path).splitlines()

        # Keep everything from the last line which starts with a marker, or the whole log when no line does.
        first_kept = 0

        for index in range(len(log_lines) - 1, -1, -1):
            if log_lines[index].startswith(markers):
                first_kept = index
                break

        message = '\n'.join(log_lines[first_kept:])
        details = '%s%s\n' % (message.strip(), Display.clear)

        ex.message += '>>> SSH Debug Output\n' + details
class DockerConnection(Connection):
    """Connection which runs commands inside a Docker container."""

    def __init__(self, args: EnvironmentConfig, container_id: str, user: t.Optional[str] = None) -> None:
        self.args = args
        self.container_id = container_id
        self.user: t.Optional[str] = user

    def run(self,
            command: list[str],
            capture: bool,
            interactive: bool = False,
            data: t.Optional[str] = None,
            stdin: t.Optional[t.IO[bytes]] = None,
            stdout: t.Optional[t.IO[bytes]] = None,
            output_stream: t.Optional[OutputStream] = None,
            ) -> tuple[t.Optional[str], t.Optional[str]]:
        """Execute the given command in the container and return the result."""
        # A user is only requested when one was configured, and '-it' only for interactive use.
        user_options = ['--user', self.user] if self.user else []
        tty_options = ['-it'] if interactive else []

        return docker_exec(
            args=self.args,
            container_id=self.container_id,
            cmd=command,
            options=user_options + tty_options,
            capture=capture,
            interactive=interactive,
            data=data,
            stdin=stdin,
            stdout=stdout,
            output_stream=output_stream,
        )

    def inspect(self) -> DockerInspect:
        """Return the DockerInspect result obtained by inspecting this container."""
        return docker_inspect(self.args, self.container_id)

    def disconnect_network(self, network: str) -> None:
        """Detach this container from the named network."""
        docker_network_disconnect(self.args, self.container_id, network)