Prevent stdio deadlock in forked children (#79522)

* background threads writing to stdout/stderr can cause children to deadlock if a thread in the parent holds the internal lock on the BufferedWriter wrapper
* prevent writes to std handles during fork by monkeypatching stdout/stderr during display startup to require a mutex lock with fork(); this ensures no background threads can hold the lock during a fork operation
* add integration test that fails reliably on Linux without this fix
pull/79541/head
Matt Davis 2 years ago committed by GitHub
parent 80d2f8da02
commit 1424484be0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,2 @@
bugfixes:
- display - reduce risk of post-fork output deadlocks (https://github.com/ansible/ansible/pull/79522)

@ -87,10 +87,12 @@ class WorkerProcess(multiprocessing_context.Process): # type: ignore[name-defin
'''
self._save_stdin()
try:
return super(WorkerProcess, self).start()
finally:
self._new_stdin.close()
# FUTURE: this lock can be removed once a more generalized pre-fork thread pause is in place
with display._lock:
try:
return super(WorkerProcess, self).start()
finally:
self._new_stdin.close()
def _hard_exit(self, e):
'''

@ -41,6 +41,7 @@ from ansible.utils.color import stringc
from ansible.utils.multiprocessing import context as multiprocessing_context
from ansible.utils.singleton import Singleton
from ansible.utils.unsafe_proxy import wrap_var
from functools import wraps
_LIBC = ctypes.cdll.LoadLibrary(ctypes.util.find_library('c'))
@ -163,12 +164,33 @@ b_COW_PATHS = (
)
def _synchronize_textiowrapper(tio, lock):
# Ensure that a background thread can't hold the internal buffer lock on a file object
# during a fork, which causes forked children to hang. We're using display's existing lock for
# convenience (and entering the lock before a fork).
def _wrap_with_lock(f, lock):
@wraps(f)
def locking_wrapper(*args, **kwargs):
with lock:
return f(*args, **kwargs)
return locking_wrapper
buffer = tio.buffer
# monkeypatching the underlying file-like object isn't great, but likely safer than subclassing
buffer.write = _wrap_with_lock(buffer.write, lock)
buffer.flush = _wrap_with_lock(buffer.flush, lock)
class Display(metaclass=Singleton):
def __init__(self, verbosity=0):
self._final_q = None
# NB: this lock is used to both prevent intermingled output between threads and to block writes during forks.
# Do not change the type of this lock or upgrade to a shared lock (eg multiprocessing.RLock).
self._lock = threading.RLock()
self.columns = None
@ -199,6 +221,13 @@ class Display(metaclass=Singleton):
self._set_column_width()
try:
# NB: we're relying on the display singleton behavior to ensure this only runs once
_synchronize_textiowrapper(sys.stdout, self._lock)
_synchronize_textiowrapper(sys.stderr, self._lock)
except Exception as ex:
self.warning(f"failed to patch stdout/stderr for fork-safety: {ex}")
def set_queue(self, queue):
"""Set the _final_q on Display, so that we know to proxy display over the queue
instead of directly writing to stdout/stderr from forks

@ -0,0 +1,3 @@
shippable/posix/group3
context/controller
skip/macos

@ -0,0 +1,58 @@
import atexit
import os
import sys
from ansible.plugins.callback import CallbackBase
from ansible.utils.display import Display
from threading import Thread
# This callback plugin reliably triggers the deadlock from https://github.com/ansible/ansible-runner/issues/1164 when
# run on a TTY/PTY. It starts a thread in the controller that spews unprintable characters to stdout as fast as
# possible, while causing forked children to write directly to the inherited stdout immediately post-fork. If a fork
# occurs while the spew thread holds stdout's internal BufferedIOWriter lock, the lock will be orphaned in the child,
# and attempts to write to stdout there will hang forever.
# Any mechanism that ensures non-main threads do not hold locks before forking should allow this test to pass.
# ref: https://docs.python.org/3/library/io.html#multi-threading
# ref: https://github.com/python/cpython/blob/0547a981ae413248b21a6bb0cb62dda7d236fe45/Modules/_io/bufferedio.c#L268
class CallbackModule(CallbackBase):
CALLBACK_VERSION = 2.0
CALLBACK_NAME = 'spewstdio'
def __init__(self):
super().__init__()
self.display = Display()
if os.environ.get('SPEWSTDIO_ENABLED', '0') != '1':
self.display.warning('spewstdio test plugin loaded but disabled; set SPEWSTDIO_ENABLED=1 to enable')
return
self.display = Display()
self._keep_spewing = True
# cause the child to write directly to stdout immediately post-fork
os.register_at_fork(after_in_child=lambda: print(f"hi from forked child pid {os.getpid()}"))
# in passing cases, stop spewing when the controller is exiting to prevent fatal errors on final flush
atexit.register(self.stop_spew)
self._spew_thread = Thread(target=self.spew, daemon=True)
self._spew_thread.start()
def stop_spew(self):
self._keep_spewing = False
def spew(self):
# dump a message so we know the callback thread has started
self.display.warning("spewstdio STARTING NONPRINTING SPEW ON BACKGROUND THREAD")
while self._keep_spewing:
# dump a non-printing control character directly to stdout to avoid junking up the screen while still
# doing lots of writes and flushes.
sys.stdout.write('\x1b[K')
sys.stdout.flush()
self.display.warning("spewstdio STOPPING SPEW")

@ -0,0 +1,5 @@
[all]
local-[1:10]
[all:vars]
ansible_connection=local

@ -0,0 +1,11 @@
#!/usr/bin/env python
"""Run a command using a PTY."""
import sys
if sys.version_info < (3, 10):
import vendored_pty as pty
else:
import pty
sys.exit(1 if pty.spawn(sys.argv[1:]) else 0)

@ -0,0 +1,20 @@
#!/usr/bin/env bash
set -eu
echo "testing for stdio deadlock on forked workers (10s timeout)..."
# Enable a callback that trips deadlocks on forked-child stdout, time out after 10s; forces running
# in a pty, since that tends to be much slower than raw file I/O and thus more likely to trigger the deadlock.
# Redirect stdout to /dev/null since it's full of non-printable garbage we don't want to display unless it failed
ANSIBLE_CALLBACKS_ENABLED=spewstdio SPEWSTDIO_ENABLED=1 python run-with-pty.py timeout 10s ansible-playbook -i hosts -f 5 test.yml > stdout.txt && RC=$? || RC=$?
if [ $RC != 0 ]; then
echo "failed; likely stdout deadlock. dumping raw output (may be very large)"
cat stdout.txt
exit 1
fi
grep -q -e "spewstdio STARTING NONPRINTING SPEW ON BACKGROUND THREAD" stdout.txt || (echo "spewstdio callback was not enabled"; exit 1)
echo "PASS"

@ -0,0 +1,5 @@
- hosts: all
gather_facts: no
tasks:
- debug:
msg: yo

@ -0,0 +1,189 @@
# Vendored copy of https://github.com/python/cpython/blob/3680ebed7f3e529d01996dd0318601f9f0d02b4b/Lib/pty.py
# PSF License (see licenses/PSF-license.txt or https://opensource.org/licenses/Python-2.0)
"""Pseudo terminal utilities."""
# Bugs: No signal handling. Doesn't set slave termios and window size.
# Only tested on Linux, FreeBSD, and macOS.
# See: W. Richard Stevens. 1992. Advanced Programming in the
# UNIX Environment. Chapter 19.
# Author: Steen Lumholt -- with additions by Guido.
from select import select
import os
import sys
import tty
# names imported directly for test mocking purposes
from os import close, waitpid
from tty import setraw, tcgetattr, tcsetattr
__all__ = ["openpty", "fork", "spawn"]
STDIN_FILENO = 0
STDOUT_FILENO = 1
STDERR_FILENO = 2
CHILD = 0
def openpty():
"""openpty() -> (master_fd, slave_fd)
Open a pty master/slave pair, using os.openpty() if possible."""
try:
return os.openpty()
except (AttributeError, OSError):
pass
master_fd, slave_name = _open_terminal()
slave_fd = slave_open(slave_name)
return master_fd, slave_fd
def master_open():
"""master_open() -> (master_fd, slave_name)
Open a pty master and return the fd, and the filename of the slave end.
Deprecated, use openpty() instead."""
try:
master_fd, slave_fd = os.openpty()
except (AttributeError, OSError):
pass
else:
slave_name = os.ttyname(slave_fd)
os.close(slave_fd)
return master_fd, slave_name
return _open_terminal()
def _open_terminal():
"""Open pty master and return (master_fd, tty_name)."""
for x in 'pqrstuvwxyzPQRST':
for y in '0123456789abcdef':
pty_name = '/dev/pty' + x + y
try:
fd = os.open(pty_name, os.O_RDWR)
except OSError:
continue
return (fd, '/dev/tty' + x + y)
raise OSError('out of pty devices')
def slave_open(tty_name):
"""slave_open(tty_name) -> slave_fd
Open the pty slave and acquire the controlling terminal, returning
opened filedescriptor.
Deprecated, use openpty() instead."""
result = os.open(tty_name, os.O_RDWR)
try:
from fcntl import ioctl, I_PUSH
except ImportError:
return result
try:
ioctl(result, I_PUSH, "ptem")
ioctl(result, I_PUSH, "ldterm")
except OSError:
pass
return result
def fork():
"""fork() -> (pid, master_fd)
Fork and make the child a session leader with a controlling terminal."""
try:
pid, fd = os.forkpty()
except (AttributeError, OSError):
pass
else:
if pid == CHILD:
try:
os.setsid()
except OSError:
# os.forkpty() already set us session leader
pass
return pid, fd
master_fd, slave_fd = openpty()
pid = os.fork()
if pid == CHILD:
# Establish a new session.
os.setsid()
os.close(master_fd)
# Slave becomes stdin/stdout/stderr of child.
os.dup2(slave_fd, STDIN_FILENO)
os.dup2(slave_fd, STDOUT_FILENO)
os.dup2(slave_fd, STDERR_FILENO)
if slave_fd > STDERR_FILENO:
os.close(slave_fd)
# Explicitly open the tty to make it become a controlling tty.
tmp_fd = os.open(os.ttyname(STDOUT_FILENO), os.O_RDWR)
os.close(tmp_fd)
else:
os.close(slave_fd)
# Parent and child process.
return pid, master_fd
def _writen(fd, data):
"""Write all the data to a descriptor."""
while data:
n = os.write(fd, data)
data = data[n:]
def _read(fd):
"""Default read function."""
return os.read(fd, 1024)
def _copy(master_fd, master_read=_read, stdin_read=_read):
"""Parent copy loop.
Copies
pty master -> standard output (master_read)
standard input -> pty master (stdin_read)"""
fds = [master_fd, STDIN_FILENO]
while fds:
rfds, _wfds, _xfds = select(fds, [], [])
if master_fd in rfds:
# Some OSes signal EOF by returning an empty byte string,
# some throw OSErrors.
try:
data = master_read(master_fd)
except OSError:
data = b""
if not data: # Reached EOF.
return # Assume the child process has exited and is
# unreachable, so we clean up.
else:
os.write(STDOUT_FILENO, data)
if STDIN_FILENO in rfds:
data = stdin_read(STDIN_FILENO)
if not data:
fds.remove(STDIN_FILENO)
else:
_writen(master_fd, data)
def spawn(argv, master_read=_read, stdin_read=_read):
"""Create a spawned process."""
if isinstance(argv, str):
argv = (argv,)
sys.audit('pty.spawn', argv)
pid, master_fd = fork()
if pid == CHILD:
os.execlp(argv[0], *argv)
try:
mode = tcgetattr(STDIN_FILENO)
setraw(STDIN_FILENO)
restore = True
except tty.error: # This is the same as termios.error
restore = False
try:
_copy(master_fd, master_read, stdin_read)
finally:
if restore:
tcsetattr(STDIN_FILENO, tty.TCSAFLUSH, mode)
close(master_fd)
return waitpid(pid, 0)[1]

@ -48,11 +48,7 @@ def main():
__import__(name)
return sys.modules[name]
try:
# noinspection PyCompatibility
from StringIO import StringIO
except ImportError:
from io import StringIO
from io import BytesIO, TextIOWrapper
try:
from importlib.util import spec_from_loader, module_from_spec
@ -436,8 +432,9 @@ def main():
class Capture:
"""Captured output and/or exception."""
def __init__(self):
self.stdout = StringIO()
self.stderr = StringIO()
# use buffered IO to simulate StringIO; allows Ansible's stream patching to behave without warnings
self.stdout = TextIOWrapper(BytesIO())
self.stderr = TextIOWrapper(BytesIO())
def capture_report(path, capture, messages):
"""Report on captured output.
@ -445,12 +442,17 @@ def main():
:type capture: Capture
:type messages: set[str]
"""
if capture.stdout.getvalue():
first = capture.stdout.getvalue().strip().splitlines()[0].strip()
# since we're using buffered IO, flush before checking for data
capture.stdout.flush()
capture.stderr.flush()
stdout_value = capture.stdout.buffer.getvalue()
if stdout_value:
first = stdout_value.decode().strip().splitlines()[0].strip()
report_message(path, 0, 0, 'stdout', first, messages)
if capture.stderr.getvalue():
first = capture.stderr.getvalue().strip().splitlines()[0].strip()
stderr_value = capture.stderr.buffer.getvalue()
if stderr_value:
first = stderr_value.decode().strip().splitlines()[0].strip()
report_message(path, 0, 0, 'stderr', first, messages)
def report_message(path, line, column, code, message, messages):

@ -133,6 +133,7 @@ test/integration/targets/ansible-test-docker/ansible_collections/ns/col/tests/un
test/integration/targets/ansible-test-no-tty/ansible_collections/ns/col/vendored_pty.py pep8!skip # vendored code
test/integration/targets/collections_relative_imports/collection_root/ansible_collections/my_ns/my_col/plugins/modules/my_module.py pylint:relative-beyond-top-level
test/integration/targets/collections_relative_imports/collection_root/ansible_collections/my_ns/my_col/plugins/module_utils/my_util2.py pylint:relative-beyond-top-level
test/integration/targets/fork_safe_stdio/vendored_pty.py pep8!skip # vendored code
test/integration/targets/gathering_facts/library/bogus_facts shebang
test/integration/targets/gathering_facts/library/facts_one shebang
test/integration/targets/gathering_facts/library/facts_two shebang

Loading…
Cancel
Save