From b276b91c21256b43f07792221d99aa28cb5bd3f5 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Thu, 15 Dec 2022 11:49:26 +0200 Subject: [PATCH] Allow fork in limited cases (#2624) --- sanic/application/ext.py | 7 +------ sanic/compat.py | 24 +++++++++++++++++++++++- sanic/mixins/startup.py | 20 +++++++++++++------- tests/conftest.py | 5 +++-- tests/test_app.py | 10 +++++----- tests/test_cli.py | 8 +++++++- tests/test_coffee.py | 2 +- tests/test_create_task.py | 3 +-- tests/test_logging.py | 19 +++++++++---------- tests/test_multiprocessing.py | 25 +++++++++++++++++++++---- tests/test_request_stream.py | 7 +------ tests/test_tls.py | 18 +++++++++++++++--- tests/test_unix_socket.py | 19 ++++++++++++++----- tests/worker/test_manager.py | 9 ++++++++- tests/worker/test_multiplexer.py | 10 +++++++++- tests/worker/test_startup.py | 25 +++++++++++++++++++++++++ 16 files changed, 156 insertions(+), 55 deletions(-) create mode 100644 tests/worker/test_startup.py diff --git a/sanic/application/ext.py b/sanic/application/ext.py index 405d0a7fe1..0f4bdfb173 100644 --- a/sanic/application/ext.py +++ b/sanic/application/ext.py @@ -8,11 +8,6 @@ if TYPE_CHECKING: from sanic import Sanic - try: - from sanic_ext import Extend # type: ignore - except ImportError: - ... - def setup_ext(app: Sanic, *, fail: bool = False, **kwargs): if not app.config.AUTO_EXTEND: @@ -33,7 +28,7 @@ def setup_ext(app: Sanic, *, fail: bool = False, **kwargs): return if not getattr(app, "_ext", None): - Ext: Extend = getattr(sanic_ext, "Extend") + Ext = getattr(sanic_ext, "Extend") app._ext = Ext(app, **kwargs) return app.ext diff --git a/sanic/compat.py b/sanic/compat.py index 4ea2ed91ff..26b0bcde2e 100644 --- a/sanic/compat.py +++ b/sanic/compat.py @@ -3,10 +3,22 @@ import signal import sys -from typing import Awaitable +from contextlib import contextmanager +from typing import Awaitable, Union from multidict import CIMultiDict # type: ignore +from sanic.helpers import Default + + +if sys.version_info < (3, 8): # no cov + StartMethod = Union[Default, str] +else: # no cov + from typing import Literal + + StartMethod = Union[ + Default, Literal["fork"], Literal["forkserver"], Literal["spawn"] + ] OS_IS_WINDOWS = os.name == "nt" UVLOOP_INSTALLED = False @@ -19,6 +31,16 @@ pass +@contextmanager +def use_context(method: StartMethod): + from sanic import Sanic + + orig = Sanic.start_method + Sanic.start_method = method + yield + Sanic.start_method = orig + + def enable_windows_color_support(): import ctypes diff --git a/sanic/mixins/startup.py b/sanic/mixins/startup.py index 78abb88439..140ecd2258 100644 --- a/sanic/mixins/startup.py +++ b/sanic/mixins/startup.py @@ -40,9 +40,9 @@ from sanic.application.motd import MOTD from sanic.application.state import ApplicationServerInfo, Mode, ServerStage from sanic.base.meta import SanicMeta -from sanic.compat import OS_IS_WINDOWS, is_atty +from sanic.compat import OS_IS_WINDOWS, StartMethod, is_atty from sanic.exceptions import ServerKilled -from sanic.helpers import Default +from sanic.helpers import Default, _default from sanic.http.constants import HTTP from sanic.http.tls import get_ssl_context, process_to_context from sanic.http.tls.context import SanicSSLContext @@ -88,6 +88,7 @@ class StartupMixin(metaclass=SanicMeta): state: ApplicationState websocket_enabled: bool multiplexer: WorkerMultiplexer + start_method: StartMethod = _default def setup_loop(self): if not self.asgi: @@ -692,12 +693,17 @@ def should_auto_reload(cls) -> bool: return any(app.state.auto_reload for app in cls._app_registry.values()) @classmethod - def _get_context(cls) -> BaseContext: - method = ( - "spawn" - if "linux" not in sys.platform or cls.should_auto_reload() - else "fork" + def _get_startup_method(cls) -> str: + return ( + cls.start_method + if not isinstance(cls.start_method, Default) + else "spawn" ) + + @classmethod + def _get_context(cls) -> BaseContext: + method = cls._get_startup_method() + logger.debug("Creating multiprocessing context using '%s'", method) return get_context(method) @classmethod diff --git a/tests/conftest.py b/tests/conftest.py index 18e74daf81..a84c34dc7e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,7 @@ from contextlib import suppress from logging import LogRecord -from typing import List, Tuple +from typing import Any, Dict, List, Tuple from unittest.mock import MagicMock import pytest @@ -54,7 +54,7 @@ async def _handler(request): "uuid": lambda: str(uuid.uuid1()), } -CACHE = {} +CACHE: Dict[str, Any] = {} class RouteStringGenerator: @@ -147,6 +147,7 @@ def app(request): for target, method_name in TouchUp._registry: CACHE[method_name] = getattr(target, method_name) app = Sanic(slugify.sub("-", request.node.name)) + yield app for target, method_name in TouchUp._registry: setattr(target, method_name, CACHE[method_name]) diff --git a/tests/test_app.py b/tests/test_app.py index b6cf98e395..f5eec8e2ca 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -349,11 +349,11 @@ def test_get_app_does_not_exist(): with pytest.raises( SanicException, match="Sanic app name 'does-not-exist' not found.\n" - "App instantiation must occur outside " - "if __name__ == '__main__' " - "block or by using an AppLoader.\nSee " - "https://sanic.dev/en/guide/deployment/app-loader.html" - " for more details." + "App instantiation must occur outside " + "if __name__ == '__main__' " + "block or by using an AppLoader.\nSee " + "https://sanic.dev/en/guide/deployment/app-loader.html" + " for more details.", ): Sanic.get_app("does-not-exist") diff --git a/tests/test_cli.py b/tests/test_cli.py index fd055c50ab..47979fa264 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -117,7 +117,13 @@ def test_error_with_path_as_instance_without_simple_arg(caplog): ), ) def test_tls_options(cmd: Tuple[str, ...], caplog): - command = ["fake.server.app", *cmd, "--port=9999", "--debug"] + command = [ + "fake.server.app", + *cmd, + "--port=9999", + "--debug", + "--single-process", + ] lines = capture(command, caplog) assert "Goin' Fast @ https://127.0.0.1:9999" in lines diff --git a/tests/test_coffee.py b/tests/test_coffee.py index 6143f17f92..43864b3d48 100644 --- a/tests/test_coffee.py +++ b/tests/test_coffee.py @@ -39,7 +39,7 @@ async def shutdown(*_): with patch("sys.stdout.isatty") as isatty: isatty.return_value = True with caplog.at_level(logging.DEBUG): - app.make_coffee() + app.make_coffee(single_process=True) # Only in the regular logo assert " ▄███ █████ ██ " not in caplog.text diff --git a/tests/test_create_task.py b/tests/test_create_task.py index a11bc302be..946c7aaa5f 100644 --- a/tests/test_create_task.py +++ b/tests/test_create_task.py @@ -2,7 +2,6 @@ import sys from threading import Event -from unittest.mock import Mock import pytest @@ -75,7 +74,7 @@ async def stop(app, _): app.stop() - app.run() + app.run(single_process=True) def test_named_task_called(app): diff --git a/tests/test_logging.py b/tests/test_logging.py index 23a7d9a336..3a7ba4f5ac 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -10,8 +10,7 @@ import sanic from sanic import Sanic -from sanic.log import Colors -from sanic.log import LOGGING_CONFIG_DEFAULTS, logger +from sanic.log import LOGGING_CONFIG_DEFAULTS, Colors, logger from sanic.response import text @@ -254,11 +253,11 @@ def log_info(request): def test_colors_enum_format(): - assert f'{Colors.END}' == Colors.END.value - assert f'{Colors.BOLD}' == Colors.BOLD.value - assert f'{Colors.BLUE}' == Colors.BLUE.value - assert f'{Colors.GREEN}' == Colors.GREEN.value - assert f'{Colors.PURPLE}' == Colors.PURPLE.value - assert f'{Colors.RED}' == Colors.RED.value - assert f'{Colors.SANIC}' == Colors.SANIC.value - assert f'{Colors.YELLOW}' == Colors.YELLOW.value + assert f"{Colors.END}" == Colors.END.value + assert f"{Colors.BOLD}" == Colors.BOLD.value + assert f"{Colors.BLUE}" == Colors.BLUE.value + assert f"{Colors.GREEN}" == Colors.GREEN.value + assert f"{Colors.PURPLE}" == Colors.PURPLE.value + assert f"{Colors.RED}" == Colors.RED.value + assert f"{Colors.SANIC}" == Colors.SANIC.value + assert f"{Colors.YELLOW}" == Colors.YELLOW.value diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index 877d3410e7..6333cf9b90 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -3,6 +3,7 @@ import pickle import random import signal +import sys from asyncio import sleep @@ -11,6 +12,7 @@ from sanic_testing.testing import HOST, PORT from sanic import Blueprint, text +from sanic.compat import use_context from sanic.log import logger from sanic.server.socket import configure_socket @@ -20,6 +22,10 @@ reason="SIGALRM is not implemented for this platform, we have to come " "up with another timeout strategy to test these", ) +@pytest.mark.skipif( + sys.platform not in ("linux", "darwin"), + reason="This test requires fork context", +) def test_multiprocessing(app): """Tests that the number of children we produce is correct""" # Selects a number at random so we can spot check @@ -37,7 +43,8 @@ def stop_on_alarm(*args): signal.signal(signal.SIGALRM, stop_on_alarm) signal.alarm(2) - app.run(HOST, 4120, workers=num_workers, debug=True) + with use_context("fork"): + app.run(HOST, 4120, workers=num_workers, debug=True) assert len(process_list) == num_workers + 1 @@ -136,6 +143,10 @@ def stop_on_alarm(*args): not hasattr(signal, "SIGALRM"), reason="SIGALRM is not implemented for this platform", ) +@pytest.mark.skipif( + sys.platform not in ("linux", "darwin"), + reason="This test requires fork context", +) def test_multiprocessing_with_blueprint(app): # Selects a number at random so we can spot check num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1)) @@ -155,7 +166,8 @@ def stop_on_alarm(*args): bp = Blueprint("test_text") app.blueprint(bp) - app.run(HOST, 4121, workers=num_workers, debug=True) + with use_context("fork"): + app.run(HOST, 4121, workers=num_workers, debug=True) assert len(process_list) == num_workers + 1 @@ -213,6 +225,10 @@ def test_pickle_app_with_static(app, protocol): up_p_app.run(single_process=True) +@pytest.mark.skipif( + sys.platform not in ("linux", "darwin"), + reason="This test requires fork context", +) def test_main_process_event(app, caplog): # Selects a number at random so we can spot check num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1)) @@ -235,8 +251,9 @@ def main_process_start2(app, loop): def main_process_stop2(app, loop): logger.info("main_process_stop") - with caplog.at_level(logging.INFO): - app.run(HOST, PORT, workers=num_workers) + with use_context("fork"): + with caplog.at_level(logging.INFO): + app.run(HOST, PORT, workers=num_workers) assert ( caplog.record_tuples.count(("sanic.root", 20, "main_process_start")) diff --git a/tests/test_request_stream.py b/tests/test_request_stream.py index a77baf487f..1513f87876 100644 --- a/tests/test_request_stream.py +++ b/tests/test_request_stream.py @@ -1,8 +1,5 @@ import asyncio -from contextlib import closing -from socket import socket - import pytest from sanic import Sanic @@ -623,6 +620,4 @@ async def read_chunk(): res = await read_chunk() assert res == None - # Use random port for tests - with closing(socket()) as sock: - app.run(access_log=False) + app.run(access_log=False, single_process=True) diff --git a/tests/test_tls.py b/tests/test_tls.py index 297033d77f..e256178cf7 100644 --- a/tests/test_tls.py +++ b/tests/test_tls.py @@ -2,6 +2,7 @@ import os import ssl import subprocess +import sys from contextlib import contextmanager from multiprocessing import Event @@ -17,6 +18,7 @@ from sanic import Sanic from sanic.application.constants import Mode +from sanic.compat import use_context from sanic.constants import LocalCertCreator from sanic.exceptions import SanicException from sanic.helpers import _default @@ -426,7 +428,12 @@ def stop(*args): app.stop() with caplog.at_level(logging.INFO): - app.run(host="127.0.0.1", port=42102, ssl=[localhost_dir, sanic_dir]) + app.run( + host="127.0.0.1", + port=42102, + ssl=[localhost_dir, sanic_dir], + single_process=True, + ) logmsg = [ m for s, l, m in caplog.record_tuples if m.startswith("Certificate") @@ -642,6 +649,10 @@ def test_sanic_ssl_context_create(): assert isinstance(sanic_context, SanicSSLContext) +@pytest.mark.skipif( + sys.platform not in ("linux", "darwin"), + reason="This test requires fork context", +) def test_ssl_in_multiprocess_mode(app: Sanic, caplog): ssl_dict = {"cert": localhost_cert, "key": localhost_key} @@ -657,8 +668,9 @@ async def shutdown(app): app.stop() assert not event.is_set() - with caplog.at_level(logging.INFO): - app.run(ssl=ssl_dict) + with use_context("fork"): + with caplog.at_level(logging.INFO): + app.run(ssl=ssl_dict) assert event.is_set() assert ( diff --git a/tests/test_unix_socket.py b/tests/test_unix_socket.py index adb80b9cc9..4760e0e4c5 100644 --- a/tests/test_unix_socket.py +++ b/tests/test_unix_socket.py @@ -1,6 +1,7 @@ # import asyncio import logging import os +import sys from asyncio import AbstractEventLoop, sleep from string import ascii_lowercase @@ -12,6 +13,7 @@ from pytest import LogCaptureFixture from sanic import Sanic +from sanic.compat import use_context from sanic.request import Request from sanic.response import text @@ -174,7 +176,9 @@ def handler(request: Request): async def client(app: Sanic, loop: AbstractEventLoop): try: - async with httpx.AsyncClient(uds=SOCKPATH) as client: + + transport = httpx.AsyncHTTPTransport(uds=SOCKPATH) + async with httpx.AsyncClient(transport=transport) as client: r = await client.get("http://myhost.invalid/") assert r.status_code == 200 assert r.text == os.path.abspath(SOCKPATH) @@ -183,11 +187,16 @@ async def client(app: Sanic, loop: AbstractEventLoop): app.stop() +@pytest.mark.skipif( + sys.platform not in ("linux", "darwin"), + reason="This test requires fork context", +) def test_unix_connection_multiple_workers(): - app_multi = Sanic(name="test") - app_multi.get("/")(handler) - app_multi.listener("after_server_start")(client) - app_multi.run(host="myhost.invalid", unix=SOCKPATH, workers=2) + with use_context("fork"): + app_multi = Sanic(name="test") + app_multi.get("/")(handler) + app_multi.listener("after_server_start")(client) + app_multi.run(host="myhost.invalid", unix=SOCKPATH, workers=2) # @pytest.mark.xfail( diff --git a/tests/worker/test_manager.py b/tests/worker/test_manager.py index d17b9971d3..85f673dd81 100644 --- a/tests/worker/test_manager.py +++ b/tests/worker/test_manager.py @@ -1,13 +1,20 @@ from logging import ERROR, INFO -from signal import SIGINT, SIGKILL +from signal import SIGINT from unittest.mock import Mock, call, patch import pytest +from sanic.compat import OS_IS_WINDOWS from sanic.exceptions import ServerKilled from sanic.worker.manager import WorkerManager +if not OS_IS_WINDOWS: + from signal import SIGKILL +else: + SIGKILL = SIGINT + + def fake_serve(): ... diff --git a/tests/worker/test_multiplexer.py b/tests/worker/test_multiplexer.py index e0b1a3688d..075a677ca4 100644 --- a/tests/worker/test_multiplexer.py +++ b/tests/worker/test_multiplexer.py @@ -1,3 +1,5 @@ +import sys + from multiprocessing import Event from os import environ, getpid from typing import Any, Dict, Type, Union @@ -6,6 +8,7 @@ import pytest from sanic import Sanic +from sanic.compat import use_context from sanic.worker.multiplexer import WorkerMultiplexer from sanic.worker.state import WorkerState @@ -28,6 +31,10 @@ def m(monitor_publisher, worker_state): del environ["SANIC_WORKER_NAME"] +@pytest.mark.skipif( + sys.platform not in ("linux", "darwin"), + reason="This test requires fork context", +) def test_has_multiplexer_default(app: Sanic): event = Event() @@ -41,7 +48,8 @@ def stop(app): app.shared_ctx.event.set() app.stop() - app.run() + with use_context("fork"): + app.run() assert event.is_set() diff --git a/tests/worker/test_startup.py b/tests/worker/test_startup.py new file mode 100644 index 0000000000..c66d9b099b --- /dev/null +++ b/tests/worker/test_startup.py @@ -0,0 +1,25 @@ +from unittest.mock import patch + +import pytest + +from sanic import Sanic + + +@pytest.mark.parametrize( + "start_method,platform,expected", + ( + (None, "linux", "spawn"), + (None, "other", "spawn"), + ("fork", "linux", "fork"), + ("fork", "other", "fork"), + ("forkserver", "linux", "forkserver"), + ("forkserver", "other", "forkserver"), + ("spawn", "linux", "spawn"), + ("spawn", "other", "spawn"), + ), +) +def test_get_context(start_method, platform, expected): + if start_method: + Sanic.start_method = start_method + with patch("sys.platform", platform): + assert Sanic._get_startup_method() == expected