diff --git a/dumb-init.c b/dumb-init.c index ac7fd15..febd4b8 100644 --- a/dumb-init.c +++ b/dumb-init.c @@ -28,40 +28,18 @@ pid_t child_pid = -1; char debug = 0; char use_setsid = 1; -void signal_handler(int signum) { - DEBUG("Received signal %d.\n", signum); - +void forward_signal(int signum) { if (child_pid > 0) { kill(use_setsid ? -child_pid : child_pid, signum); - DEBUG("Forwarded signal to child.\n"); + DEBUG("Forwarded signal %d to child.\n", signum); } else { - DEBUG("Didn't forward signal, no child exists yet."); + DEBUG("Didn't forward signal %d, no child exists yet.\n", signum); } } -void reap_zombies(int signum) { - /* - * As PID 1, dumb-init is expected to handle reaping of zombie processes. - * - * If a process's parent exits, the child is orphaned and its new parent is - * PID 1. If that child later exits, it becomes a zombie process until its - * parent (now dumb-init) calls wait() on it. - */ - int status, exit_status; - pid_t killed_pid; - - assert(signum == SIGCHLD); - DEBUG("Received SIGCHLD, calling waitpid().\n"); - - while ((killed_pid = waitpid(-1, &status, WNOHANG)) > 0) { - exit_status = WEXITSTATUS(status); - DEBUG("A child with PID %d exited with exit status %d.\n", killed_pid, exit_status); - - if (killed_pid == child_pid) { - DEBUG("Child exited with status %d. Goodbye.\n", exit_status); - exit(exit_status); - } - } +void handle_signal(int signum) { + DEBUG("Received signal %d.\n", signum); + forward_signal(signum); } void print_help(char *argv[]) { @@ -120,14 +98,12 @@ int main(int argc, char *argv[]) { if (signum == SIGKILL || signum == SIGSTOP || signum == SIGCHLD) continue; - if (signal(signum, signal_handler) == SIG_ERR) { + if (signal(signum, handle_signal) == SIG_ERR) { fprintf(stderr, "Error: Couldn't register signal handler for signal `%d`. Exiting.\n", signum); return 1; } } - signal(SIGCHLD, reap_zombies); - /* launch our process */ child_pid = fork(); @@ -153,9 +129,22 @@ int main(int argc, char *argv[]) { execvp(argv[1], &argv[1]); } else { + pid_t killed_pid; + int exit_status, status; + DEBUG("Child spawned with PID %d.\n", child_pid); - for (;;) { - pause(); + + while ((killed_pid = waitpid(-1, &status, 0))) { + exit_status = WEXITSTATUS(status); + DEBUG("A child with PID %d exited with exit status %d.\n", killed_pid, exit_status); + + if (killed_pid == child_pid) { + // send SIGTERM to any remaining children + forward_signal(SIGTERM); + + DEBUG("Child exited with status %d. Goodbye.\n", exit_status); + exit(exit_status); + } } } diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..1ea6b80 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +timeout = 5 diff --git a/requirements-dev.txt b/requirements-dev.txt index c906dfc..2dac807 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,2 +1,3 @@ pre-commit>=0.5.0 pytest +pytest-timeout diff --git a/tests/child_processes_test.py b/tests/child_processes_test.py index 6fff40f..5e2d686 100644 --- a/tests/child_processes_test.py +++ b/tests/child_processes_test.py @@ -1,6 +1,9 @@ import os +import re import signal +import sys import time +from subprocess import PIPE from subprocess import Popen from tests.lib.testing import is_alive @@ -29,23 +32,94 @@ def living_pids(pids): return set(pid for pid in pids if is_alive(pid)) -def test_setsid_signals_entire_group(both_debug_modes): +def test_setsid_signals_entire_group(both_debug_modes, setsid_enabled): """When dumb-init is running in setsid mode, it should only signal the entire process group rooted at it. """ - os.environ['DUMB_INIT_SETSID'] = '1' pids = spawn_and_kill_pipeline() assert len(living_pids(pids)) == 0 -def test_no_setsid_doesnt_signal_entire_group(both_debug_modes): +def test_no_setsid_doesnt_signal_entire_group( + both_debug_modes, + setsid_disabled, +): """When dumb-init is not running in setsid mode, it should only signal its immediate child. """ - os.environ['DUMB_INIT_SETSID'] = '0' pids = spawn_and_kill_pipeline() living = living_pids(pids) assert len(living) == 4 for pid in living: os.kill(pid, signal.SIGKILL) + + +def spawn_process_which_dies_with_children(): + """Spawn a process which spawns some children and then dies without + signaling them, wrapped in dumb-init. + + Returns a tuple (child pid, child stdout pipe), where the child is + print_signals. This is useful because you can signal the PID and see if + anything gets printed onto the stdout pipe. + """ + proc = Popen( + ( + 'dumb-init', + 'sh', '-c', + + # we need to sleep before the shell exits, or dumb-init might send + # TERM to print_signals before it has had time to register custom + # signal handlers + '{python} -m tests.lib.print_signals & sleep 0.1'.format( + python=sys.executable, + ), + ), + stdout=PIPE, + ) + proc.wait() + assert proc.returncode == 0 + + # read a line from print_signals, figure out its pid + line = proc.stdout.readline() + match = re.match(b'ready \(pid: ([0-9]+)\)\n', line) + assert match, 'print_signals should print "ready" and its pid, not ' + \ + str(line) + child_pid = int(match.group(1)) + + # at this point, the shell and dumb-init have both exited, but + # print_signals may or may not still be running (depending on whether + # setsid mode is enabled) + + return child_pid, proc.stdout + + +def test_all_processes_receive_term_on_exit_if_setsid( + both_debug_modes, + setsid_enabled, +): + """If the child exits for some reason, dumb-init should send TERM to all + processes in its session if setsid mode is enabled.""" + child_pid, child_stdout = spawn_process_which_dies_with_children() + + # print_signals should have received TERM + assert child_stdout.readline() == b'15\n' + + os.kill(child_pid, signal.SIGKILL) + + +def test_processes_dont_receive_term_on_exit_if_no_setsid( + both_debug_modes, + setsid_disabled, +): + """If the child exits for some reason, dumb-init should not send TERM to + any other processes if setsid mode is disabled.""" + child_pid, child_stdout = spawn_process_which_dies_with_children() + + # print_signals should not have received TERM; to test this, we send it + # some other signals and ensure they were received (and TERM wasn't) + for signum in [1, 2, 3]: + os.kill(child_pid, signum) + assert child_stdout.readline() == str(signum).encode('ascii') + b'\n' + + os.kill(child_pid, signal.SIGKILL) diff --git a/tests/conftest.py b/tests/conftest.py index 674be28..7cc9833 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,21 @@ def both_debug_modes(request): os.environ['DUMB_INIT_DEBUG'] = request.param +@pytest.fixture +def debug_disabled(): + os.environ['DUMB_INIT_DEBUG'] = '0' + + @pytest.fixture(params=['1', '0']) def both_setsid_modes(request): os.environ['DUMB_INIT_SETSID'] = request.param + + +@pytest.fixture +def setsid_enabled(): + os.environ['DUMB_INIT_SETSID'] = '1' + + +@pytest.fixture +def setsid_disabled(): + os.environ['DUMB_INIT_SETSID'] = '0' diff --git a/tests/lib/print_signals.py b/tests/lib/print_signals.py index 1388fb5..2b79d28 100755 --- a/tests/lib/print_signals.py +++ b/tests/lib/print_signals.py @@ -6,6 +6,7 @@ SIGKILL (kill -9) to this process to actually end it. """ from __future__ import print_function +import os import signal import sys import time @@ -29,7 +30,7 @@ if __name__ == '__main__': for signum in CATCHABLE_SIGNALS: signal.signal(signum, print_signal) - unbuffered_print('ready') + unbuffered_print('ready (pid: {0})'.format(os.getpid())) # loop forever just printing signals while True: diff --git a/tests/proxies_signals_test.py b/tests/proxies_signals_test.py index 65efeb2..a05fd78 100644 --- a/tests/proxies_signals_test.py +++ b/tests/proxies_signals_test.py @@ -1,4 +1,5 @@ import os +import re import signal import sys from subprocess import PIPE @@ -14,7 +15,7 @@ def test_prints_signals(both_debug_modes, both_setsid_modes): stdout=PIPE, ) - assert proc.stdout.readline() == b'ready\n' + assert re.match(b'^ready \(pid: (?:[0-9]+)\)\n$', proc.stdout.readline()) for signum in CATCHABLE_SIGNALS: proc.send_signal(signum) diff --git a/tests/tty_test.py b/tests/tty_test.py index fb311b4..65d9959 100644 --- a/tests/tty_test.py +++ b/tests/tty_test.py @@ -1,6 +1,3 @@ -import os - - EOF = b'\x04' @@ -53,12 +50,11 @@ def _test(fd): print('PASS') -def test_tty(): +# disable debug output so it doesn't break our assertion +def test_tty(debug_disabled): """ Ensure processes wrapped by dumb-init can write successfully, given a tty """ - # disable debug output so it doesn't break our assertion - os.environ['DUMB_INIT_DEBUG'] = '0' import pty pid, fd = pty.fork() if pid == 0: