diff --git a/docs/deployment.md b/docs/deployment.md index 478632b28..7a2c7972c 100644 --- a/docs/deployment.md +++ b/docs/deployment.md @@ -178,7 +178,23 @@ Running Uvicorn using a process manager ensures that you can run multiple proces A process manager will handle the socket setup, start-up multiple server processes, monitor process aliveness, and listen for signals to provide for processes restarts, shutdowns, or dialing up and down the number of running processes. -Uvicorn provides a lightweight way to run multiple worker processes, for example `--workers 4`, but does not provide any process monitoring. +### Built-in + +Uvicorn includes a `--workers` option that allows you to run multiple worker processes. + +```bash +$ uvicorn main:app --workers 4 +``` + +Unlike Gunicorn, Uvicorn does not use pre-fork, but uses [`spawn`](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods), which allows Uvicorn's multiprocess manager to still work well on Windows. + +The default process manager monitors the status of child processes and automatically restarts child processes that die unexpectedly. It also checks the health of each child process through a pipe. When a child process hangs and stops responding, it is killed through an unstoppable system signal or interface. + +You can also manage child processes by sending specific signals to the main process. (Not supported on Windows.) + +- `SIGHUP`: Worker processes are gracefully restarted one after another. If you update the code, the new worker process will use the new code. +- `SIGTTIN`: Increase the number of worker processes by one. +- `SIGTTOU`: Decrease the number of worker processes by one. 
### Gunicorn diff --git a/pyproject.toml b/pyproject.toml index b59d4e0e0..9fb3e3315 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,7 +115,11 @@ exclude_lines = [ ] [tool.coverage.coverage_conditional_plugin.omit] -"sys_platform == 'win32'" = ["uvicorn/loops/uvloop.py"] +"sys_platform == 'win32'" = [ + "uvicorn/loops/uvloop.py", + "uvicorn/supervisors/multiprocess.py", + "tests/supervisors/test_multiprocess.py", +] "sys_platform != 'win32'" = ["uvicorn/loops/asyncio.py"] [tool.coverage.coverage_conditional_plugin.rules] diff --git a/tests/supervisors/__init__.py b/tests/supervisors/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/supervisors/test_multiprocess.py b/tests/supervisors/test_multiprocess.py index 391b66a73..5365907aa 100644 --- a/tests/supervisors/test_multiprocess.py +++ b/tests/supervisors/test_multiprocess.py @@ -1,11 +1,43 @@ from __future__ import annotations +import functools +import os import signal import socket +import threading +import time +from typing import Any, Callable + +import pytest from uvicorn import Config from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope from uvicorn.supervisors import Multiprocess +from uvicorn.supervisors.multiprocess import Process + + +def new_console_in_windows(test_function: Callable[[], Any]) -> Callable[[], Any]: # pragma: no cover + if os.name != "nt": + return test_function + + @functools.wraps(test_function) + def new_function(): + import subprocess + import sys + + module = test_function.__module__ + name = test_function.__name__ + + subprocess.check_call( + [ + sys.executable, + "-c", + f"from {module} import {name}; {name}.__wrapped__()", + ], + creationflags=subprocess.CREATE_NO_WINDOW, # type: ignore[attr-defined] + ) + + return new_function async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: @@ -13,9 +45,22 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable def 
run(sockets: list[socket.socket] | None) -> None: - pass # pragma: no cover + while True: + time.sleep(1) + + +def test_process_ping_pong() -> None: + process = Process(Config(app=app), target=lambda x: None, sockets=[]) + threading.Thread(target=process.always_pong, daemon=True).start() + assert process.ping() + + +def test_process_ping_pong_timeout() -> None: + process = Process(Config(app=app), target=lambda x: None, sockets=[]) + assert not process.ping(0.1) +@new_console_in_windows def test_multiprocess_run() -> None: """ A basic sanity check. @@ -25,5 +70,102 @@ def test_multiprocess_run() -> None: """ config = Config(app=app, workers=2) supervisor = Multiprocess(config, target=run, sockets=[]) - supervisor.signal_handler(sig=signal.SIGINT, frame=None) - supervisor.run() + threading.Thread(target=supervisor.run, daemon=True).start() + supervisor.signal_queue.append(signal.SIGINT) + supervisor.join_all() + + +@new_console_in_windows +def test_multiprocess_health_check() -> None: + """ + Ensure that the health check works as expected. + """ + config = Config(app=app, workers=2) + supervisor = Multiprocess(config, target=run, sockets=[]) + threading.Thread(target=supervisor.run, daemon=True).start() + time.sleep(1) + process = supervisor.processes[0] + process.kill() + assert not process.is_alive() + time.sleep(1) + for p in supervisor.processes: + assert p.is_alive() + supervisor.signal_queue.append(signal.SIGINT) + supervisor.join_all() + + +@new_console_in_windows +def test_multiprocess_sigterm() -> None: + """ + Ensure that the SIGTERM signal is handled as expected. 
+ """ + config = Config(app=app, workers=2) + supervisor = Multiprocess(config, target=run, sockets=[]) + threading.Thread(target=supervisor.run, daemon=True).start() + time.sleep(1) + supervisor.signal_queue.append(signal.SIGTERM) + supervisor.join_all() + + +@pytest.mark.skipif(not hasattr(signal, "SIGBREAK"), reason="platform unsupports SIGBREAK") +@new_console_in_windows +def test_multiprocess_sigbreak() -> None: # pragma: py-not-win32 + """ + Ensure that the SIGBREAK signal is handled as expected. + """ + config = Config(app=app, workers=2) + supervisor = Multiprocess(config, target=run, sockets=[]) + threading.Thread(target=supervisor.run, daemon=True).start() + time.sleep(1) + supervisor.signal_queue.append(getattr(signal, "SIGBREAK")) + supervisor.join_all() + + +@pytest.mark.skipif(not hasattr(signal, "SIGHUP"), reason="platform unsupports SIGHUP") +def test_multiprocess_sighup() -> None: + """ + Ensure that the SIGHUP signal is handled as expected. + """ + config = Config(app=app, workers=2) + supervisor = Multiprocess(config, target=run, sockets=[]) + threading.Thread(target=supervisor.run, daemon=True).start() + time.sleep(1) + pids = [p.pid for p in supervisor.processes] + supervisor.signal_queue.append(signal.SIGHUP) + time.sleep(1) + assert pids != [p.pid for p in supervisor.processes] + supervisor.signal_queue.append(signal.SIGINT) + supervisor.join_all() + + +@pytest.mark.skipif(not hasattr(signal, "SIGTTIN"), reason="platform unsupports SIGTTIN") +def test_multiprocess_sigttin() -> None: + """ + Ensure that the SIGTTIN signal is handled as expected. 
+ """ + config = Config(app=app, workers=2) + supervisor = Multiprocess(config, target=run, sockets=[]) + threading.Thread(target=supervisor.run, daemon=True).start() + supervisor.signal_queue.append(signal.SIGTTIN) + time.sleep(1) + assert len(supervisor.processes) == 3 + supervisor.signal_queue.append(signal.SIGINT) + supervisor.join_all() + + +@pytest.mark.skipif(not hasattr(signal, "SIGTTOU"), reason="platform unsupports SIGTTOU") +def test_multiprocess_sigttou() -> None: + """ + Ensure that the SIGTTOU signal is handled as expected. + """ + config = Config(app=app, workers=2) + supervisor = Multiprocess(config, target=run, sockets=[]) + threading.Thread(target=supervisor.run, daemon=True).start() + supervisor.signal_queue.append(signal.SIGTTOU) + time.sleep(1) + assert len(supervisor.processes) == 1 + supervisor.signal_queue.append(signal.SIGTTOU) + time.sleep(1) + assert len(supervisor.processes) == 1 + supervisor.signal_queue.append(signal.SIGINT) + supervisor.join_all() diff --git a/uvicorn/supervisors/multiprocess.py b/uvicorn/supervisors/multiprocess.py index e0916721b..c242fed9a 100644 --- a/uvicorn/supervisors/multiprocess.py +++ b/uvicorn/supervisors/multiprocess.py @@ -4,24 +4,101 @@ import os import signal import threading -from multiprocessing.context import SpawnProcess +from multiprocessing import Pipe from socket import socket -from types import FrameType -from typing import Callable +from typing import Any, Callable import click from uvicorn._subprocess import get_subprocess from uvicorn.config import Config -HANDLED_SIGNALS = ( - signal.SIGINT, # Unix signal 2. Sent by Ctrl+C. - signal.SIGTERM, # Unix signal 15. Sent by `kill `. 
-) +SIGNALS = { + getattr(signal, f"SIG{x}"): x + for x in "INT TERM BREAK HUP QUIT TTIN TTOU USR1 USR2 WINCH".split() + if hasattr(signal, f"SIG{x}") +} logger = logging.getLogger("uvicorn.error") +class Process: + def __init__( + self, + config: Config, + target: Callable[[list[socket] | None], None], + sockets: list[socket], + ) -> None: + self.real_target = target + + self.parent_conn, self.child_conn = Pipe() + self.process = get_subprocess(config, self.target, sockets) + + def ping(self, timeout: float = 5) -> bool: + self.parent_conn.send(b"ping") + if self.parent_conn.poll(timeout): + self.parent_conn.recv() + return True + return False + + def pong(self) -> None: + self.child_conn.recv() + self.child_conn.send(b"pong") + + def always_pong(self) -> None: + while True: + self.pong() + + def target(self, sockets: list[socket] | None = None) -> Any: # pragma: no cover + if os.name == "nt": # pragma: py-not-win32 + # Windows doesn't support SIGTERM, so we use SIGBREAK instead. + # And then we raise SIGTERM when SIGBREAK is received. + # https://learn.microsoft.com/zh-cn/cpp/c-runtime-library/reference/signal?view=msvc-170 + signal.signal( + signal.SIGBREAK, # type: ignore[attr-defined] + lambda sig, frame: signal.raise_signal(signal.SIGTERM), + ) + + threading.Thread(target=self.always_pong, daemon=True).start() + return self.real_target(sockets) + + def is_alive(self, timeout: float = 5) -> bool: + if not self.process.is_alive(): + return False + + return self.ping(timeout) + + def start(self) -> None: + self.process.start() + + def terminate(self) -> None: + if self.process.exitcode is None: # Process is still running + assert self.process.pid is not None + if os.name == "nt": # pragma: py-not-win32 + # Windows doesn't support SIGTERM. + # So send SIGBREAK, and then in process raise SIGTERM. 
+ os.kill(self.process.pid, signal.CTRL_BREAK_EVENT) # type: ignore[attr-defined] + else: + os.kill(self.process.pid, signal.SIGTERM) + logger.info(f"Terminated child process [{self.process.pid}]") + + self.parent_conn.close() + self.child_conn.close() + + def kill(self) -> None: + # In Windows, the method will call `TerminateProcess` to kill the process. + # In Unix, the method will send SIGKILL to the process. + self.process.kill() + + def join(self) -> None: + logger.info(f"Waiting for child process [{self.process.pid}]") + self.process.join() + + @property + def pid(self) -> int | None: + return self.process.pid + + class Multiprocess: def __init__( self, @@ -32,39 +109,115 @@ def __init__( self.config = config self.target = target self.sockets = sockets - self.processes: list[SpawnProcess] = [] - self.should_exit = threading.Event() - self.pid = os.getpid() - - def signal_handler(self, sig: int, frame: FrameType | None) -> None: - """ - A signal handler that is registered with the parent process. 
- """ - self.should_exit.set() - def run(self) -> None: - self.startup() - self.should_exit.wait() - self.shutdown() + self.processes_num = config.workers + self.processes: list[Process] = [] - def startup(self) -> None: - message = f"Started parent process [{str(self.pid)}]" - color_message = "Started parent process [{}]".format(click.style(str(self.pid), fg="cyan", bold=True)) - logger.info(message, extra={"color_message": color_message}) + self.should_exit = threading.Event() - for sig in HANDLED_SIGNALS: - signal.signal(sig, self.signal_handler) + self.signal_queue: list[int] = [] + for sig in SIGNALS: + signal.signal(sig, lambda sig, frame: self.signal_queue.append(sig)) - for _idx in range(self.config.workers): - process = get_subprocess(config=self.config, target=self.target, sockets=self.sockets) + def init_processes(self) -> None: + for _ in range(self.processes_num): + process = Process(self.config, self.target, self.sockets) process.start() self.processes.append(process) - def shutdown(self) -> None: + def terminate_all(self) -> None: for process in self.processes: process.terminate() + + def join_all(self) -> None: + for process in self.processes: + process.join() + + def restart_all(self) -> None: + for idx, process in enumerate(tuple(self.processes)): + process.terminate() process.join() + new_process = Process(self.config, self.target, self.sockets) + new_process.start() + self.processes[idx] = new_process + + def run(self) -> None: + message = f"Started parent process [{os.getpid()}]" + color_message = "Started parent process [{}]".format(click.style(str(os.getpid()), fg="cyan", bold=True)) + logger.info(message, extra={"color_message": color_message}) - message = f"Stopping parent process [{str(self.pid)}]" - color_message = "Stopping parent process [{}]".format(click.style(str(self.pid), fg="cyan", bold=True)) + self.init_processes() + + while not self.should_exit.wait(0.5): + self.handle_signals() + self.keep_subprocess_alive() + + 
self.terminate_all() + self.join_all() + + message = f"Stopping parent process [{os.getpid()}]" + color_message = "Stopping parent process [{}]".format(click.style(str(os.getpid()), fg="cyan", bold=True)) logger.info(message, extra={"color_message": color_message}) + + def keep_subprocess_alive(self) -> None: + if self.should_exit.is_set(): + return # parent process is exiting, no need to keep subprocess alive + + for idx, process in enumerate(tuple(self.processes)): + if process.is_alive(): + continue + + process.kill() # process is hung, kill it + process.join() + + if self.should_exit.is_set(): + return + + logger.info(f"Child process [{process.pid}] died") + del self.processes[idx] + process = Process(self.config, self.target, self.sockets) + process.start() + self.processes.append(process) + + def handle_signals(self) -> None: + for sig in tuple(self.signal_queue): + self.signal_queue.remove(sig) + sig_name = SIGNALS[sig] + sig_handler = getattr(self, f"handle_{sig_name.lower()}", None) + if sig_handler is not None: + sig_handler() + else: # pragma: no cover + logger.debug(f"Received signal {sig_name}, but no handler is defined for it.") + + def handle_int(self) -> None: + logger.info("Received SIGINT, exiting.") + self.should_exit.set() + + def handle_term(self) -> None: + logger.info("Received SIGTERM, exiting.") + self.should_exit.set() + + def handle_break(self) -> None: # pragma: py-not-win32 + logger.info("Received SIGBREAK, exiting.") + self.should_exit.set() + + def handle_hup(self) -> None: # pragma: py-win32 + logger.info("Received SIGHUP, restarting processes.") + self.restart_all() + + def handle_ttin(self) -> None: # pragma: py-win32 + logger.info("Received SIGTTIN, increasing the number of processes.") + self.processes_num += 1 + process = Process(self.config, self.target, self.sockets) + process.start() + self.processes.append(process) + + def handle_ttou(self) -> None: # pragma: py-win32 + logger.info("Received SIGTTOU, decreasing number of 
processes.") + if self.processes_num <= 1: + logger.info("Already reached one process, cannot decrease the number of processes anymore.") + return + self.processes_num -= 1 + process = self.processes.pop() + process.terminate() + process.join()