Fix signal propagation (#85907)

pull/85985/head
sivel / Matt Martz 2 months ago committed by GitHub
parent 9ee667030f
commit 5a9afe4409
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -0,0 +1,3 @@
bugfixes:
- SIGINT/SIGTERM Handling - Make SIGINT/SIGTERM handling more robust by splitting concerns
between forks and the parent.

@ -17,6 +17,7 @@
from __future__ import annotations from __future__ import annotations
import errno
import io import io
import os import os
import signal import signal
@ -103,11 +104,19 @@ class WorkerProcess(multiprocessing_context.Process): # type: ignore[name-defin
self._cliargs = cliargs self._cliargs = cliargs
def _term(self, signum, frame) -> None: def _term(self, signum, frame) -> None:
""" """In child termination when notified by the parent"""
terminate the process group created by calling setsid when signal.signal(signum, signal.SIG_DFL)
a terminate signal is received by the fork
""" try:
os.killpg(self.pid, signum) os.killpg(self.pid, signum)
os.kill(self.pid, signum)
except OSError as e:
if e.errno != errno.ESRCH:
signame = signal.strsignal(signum)
display.error(f'Unable to send {signame} to child[{self.pid}]: {e}')
# fallthrough, if we are still here, just die
os._exit(1)
def start(self) -> None: def start(self) -> None:
""" """
@ -121,11 +130,6 @@ class WorkerProcess(multiprocessing_context.Process): # type: ignore[name-defin
# FUTURE: this lock can be removed once a more generalized pre-fork thread pause is in place # FUTURE: this lock can be removed once a more generalized pre-fork thread pause is in place
with display._lock: with display._lock:
super(WorkerProcess, self).start() super(WorkerProcess, self).start()
# Since setsid is called later, if the worker is termed
# it won't term the new process group
# register a handler to propagate the signal
signal.signal(signal.SIGTERM, self._term)
signal.signal(signal.SIGINT, self._term)
def _hard_exit(self, e: str) -> t.NoReturn: def _hard_exit(self, e: str) -> t.NoReturn:
""" """
@ -170,7 +174,6 @@ class WorkerProcess(multiprocessing_context.Process): # type: ignore[name-defin
# to give better errors, and to prevent fd 0 reuse # to give better errors, and to prevent fd 0 reuse
sys.stdin.close() sys.stdin.close()
except Exception as e: except Exception as e:
display.debug(f'Could not detach from stdio: {traceback.format_exc()}')
display.error(f'Could not detach from stdio: {e}') display.error(f'Could not detach from stdio: {e}')
os._exit(1) os._exit(1)
@ -187,6 +190,9 @@ class WorkerProcess(multiprocessing_context.Process): # type: ignore[name-defin
# Set the queue on Display so calls to Display.display are proxied over the queue # Set the queue on Display so calls to Display.display are proxied over the queue
display.set_queue(self._final_q) display.set_queue(self._final_q)
self._detach() self._detach()
# propagate signals
signal.signal(signal.SIGINT, self._term)
signal.signal(signal.SIGTERM, self._term)
try: try:
with _task.TaskContext(self._task): with _task.TaskContext(self._task):
return self._run() return self._run()

@ -18,8 +18,10 @@
from __future__ import annotations from __future__ import annotations
import dataclasses import dataclasses
import errno
import os import os
import sys import sys
import signal
import tempfile import tempfile
import threading import threading
import time import time
@ -185,8 +187,48 @@ class TaskQueueManager:
# plugins for inter-process locking. # plugins for inter-process locking.
self._connection_lockfile = tempfile.TemporaryFile() self._connection_lockfile = tempfile.TemporaryFile()
self._workers: list[WorkerProcess | None] = []
# signal handlers to propagate signals to workers
signal.signal(signal.SIGTERM, self._signal_handler)
signal.signal(signal.SIGINT, self._signal_handler)
def _initialize_processes(self, num: int) -> None: def _initialize_processes(self, num: int) -> None:
self._workers: list[WorkerProcess | None] = [None] * num # mutable update to ensure the reference stays the same
self._workers[:] = [None] * num
def _signal_handler(self, signum, frame) -> None:
    """Propagate SIGINT/SIGTERM from the parent to all live worker processes.

    Workers call setsid after forking and become process-group leaders, so
    signals delivered to the parent's group do not reach them automatically;
    this handler forwards the signal to each live worker explicitly.

    :param signum: the signal number being handled (SIGINT or SIGTERM)
    :param frame: current stack frame (unused, required by the handler API)
    """
    # Restore the default disposition first so re-delivering the same signal
    # to ourselves below terminates us instead of re-entering this handler.
    signal.signal(signum, signal.SIG_DFL)
    for worker in self._workers:
        if worker is None or not worker.is_alive():
            continue
        if worker.pid:
            try:
                # notify workers
                os.kill(worker.pid, signum)
            except OSError as e:
                # ESRCH: the worker exited between the is_alive() check and
                # the kill — nothing to report in that case
                if e.errno != errno.ESRCH:
                    signame = signal.strsignal(signum)
                    display.error(f'Unable to send {signame} to child[{worker.pid}]: {e}')
    if signum == signal.SIGINT:
        # Defer to CLI handling
        raise KeyboardInterrupt()
    # For SIGTERM, re-deliver the signal to ourselves; with SIG_DFL restored
    # above this terminates the parent with the conventional signal status.
    pid = os.getpid()
    try:
        os.kill(pid, signum)
    except OSError as e:
        signame = signal.strsignal(signum)
        display.error(f'Unable to send {signame} to {pid}: {e}')
def load_callbacks(self): def load_callbacks(self):
""" """

@ -2,4 +2,5 @@ needs/ssh
shippable/posix/group3 shippable/posix/group3
needs/target/connection needs/target/connection
needs/target/setup_test_user needs/target/setup_test_user
needs/target/test_utils
setup/always/setup_passlib_controller # required for setup_test_user setup/always/setup_passlib_controller # required for setup_test_user

@ -17,7 +17,7 @@ if command -v sshpass > /dev/null; then
# ansible with timeout. If we time out, our custom prompt was successfully # ansible with timeout. If we time out, our custom prompt was successfully
# searched for. It's a weird way of doing things, but it does ensure # searched for. It's a weird way of doing things, but it does ensure
# that the flag gets passed to sshpass. # that the flag gets passed to sshpass.
timeout 5 ansible -m ping \ ../test_utils/scripts/timeout.py 5 -- ansible -m ping \
-e ansible_connection=ssh \ -e ansible_connection=ssh \
-e ansible_ssh_password_mechanism=sshpass \ -e ansible_ssh_password_mechanism=sshpass \
-e ansible_sshpass_prompt=notThis: \ -e ansible_sshpass_prompt=notThis: \

@ -0,0 +1,3 @@
shippable/posix/group4
context/controller
needs/target/test_utils

@ -0,0 +1,14 @@
localhost0
localhost1
localhost2
localhost3
localhost4
localhost5
localhost6
localhost7
localhost8
localhost9
[all:vars]
ansible_connection=local
ansible_python_interpreter={{ansible_playbook_python}}

@ -0,0 +1,21 @@
#!/usr/bin/env bash
set -x
# Run ansible against 10 local hosts with 10 forks; each task spawns a
# long-running "sleep 33" via the pipe lookup. The timeout wrapper sends
# SIGINT after 3 seconds, which should propagate to all worker processes.
../test_utils/scripts/timeout.py -s SIGINT 3 -- \
ansible all -i inventory -m debug -a 'msg={{lookup("pipe", "sleep 33")}}' -f 10
# timeout.py exits 124 when it had to signal the child; anything else means
# ansible finished (or died) before the timeout fired, which is unexpected here.
if [[ "$?" != "124" ]]; then
echo "Process was not terminated due to timeout"
exit 1
fi
# a short sleep to let processes die
sleep 2
# If signal propagation worked, no "sleep 33" children should survive;
# pgrep returning success (rc 0) means orphans were left behind.
sleeps="$(pgrep -alf 'sleep\ 33')"
rc="$?"
if [[ "$rc" == "0" ]]; then
echo "Found lingering processes:"
echo "$sleeps"
exit 1
fi

@ -2,21 +2,32 @@
from __future__ import annotations from __future__ import annotations
import argparse import argparse
import signal
import subprocess import subprocess
import sys import sys
def signal_type(v: str) -> signal.Signals:
    """argparse ``type=`` converter mapping a CLI value to a ``signal.Signals`` member.

    Accepts a numeric signal ('15'), a short name ('TERM'), or a full
    name ('SIGTERM').

    :param v: the raw command-line argument string
    :returns: the matching ``signal.Signals`` member
    :raises argparse.ArgumentTypeError: for an unknown signal, so argparse
        reports a usage error instead of an unhandled traceback (argparse
        only converts ValueError/TypeError/ArgumentTypeError into errors;
        the previous ``getattr`` raised AttributeError, which crashed).
    """
    try:
        if v.isdecimal():
            return signal.Signals(int(v))
        if not v.startswith('SIG'):
            v = f'SIG{v}'
        return signal.Signals[v]
    except (KeyError, ValueError) as e:
        raise argparse.ArgumentTypeError(f'unknown signal: {v}') from e
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('duration', type=int) parser.add_argument('duration', type=int)
parser.add_argument('--signal', '-s', default=signal.SIGTERM, type=signal_type)
parser.add_argument('command', nargs='+') parser.add_argument('command', nargs='+')
args = parser.parse_args() args = parser.parse_args()
p: subprocess.Popen | None = None
try: try:
p = subprocess.run( p = subprocess.Popen(args.command)
' '.join(args.command), p.wait(timeout=args.duration)
shell=True,
timeout=args.duration,
check=False,
)
sys.exit(p.returncode) sys.exit(p.returncode)
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
if p and p.poll() is None:
p.send_signal(args.signal)
p.wait()
sys.exit(124) sys.exit(124)

Loading…
Cancel
Save