Skip to content

Commit

Permalink
Allow fork in limited cases (#2624)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahopkins committed Dec 15, 2022
1 parent 064168f commit b276b91
Show file tree
Hide file tree
Showing 16 changed files with 156 additions and 55 deletions.
7 changes: 1 addition & 6 deletions sanic/application/ext.py
Expand Up @@ -8,11 +8,6 @@
if TYPE_CHECKING:
from sanic import Sanic

try:
from sanic_ext import Extend # type: ignore
except ImportError:
...


def setup_ext(app: Sanic, *, fail: bool = False, **kwargs):
if not app.config.AUTO_EXTEND:
Expand All @@ -33,7 +28,7 @@ def setup_ext(app: Sanic, *, fail: bool = False, **kwargs):
return

if not getattr(app, "_ext", None):
Ext: Extend = getattr(sanic_ext, "Extend")
Ext = getattr(sanic_ext, "Extend")
app._ext = Ext(app, **kwargs)

return app.ext
24 changes: 23 additions & 1 deletion sanic/compat.py
Expand Up @@ -3,10 +3,22 @@
import signal
import sys

from typing import Awaitable
from contextlib import contextmanager
from typing import Awaitable, Union

from multidict import CIMultiDict # type: ignore

from sanic.helpers import Default


if sys.version_info < (3, 8): # no cov
StartMethod = Union[Default, str]
else: # no cov
from typing import Literal

StartMethod = Union[
Default, Literal["fork"], Literal["forkserver"], Literal["spawn"]
]

OS_IS_WINDOWS = os.name == "nt"
UVLOOP_INSTALLED = False
Expand All @@ -19,6 +31,16 @@
pass


@contextmanager
def use_context(method: StartMethod):
from sanic import Sanic

orig = Sanic.start_method
Sanic.start_method = method
yield
Sanic.start_method = orig


def enable_windows_color_support():
import ctypes

Expand Down
20 changes: 13 additions & 7 deletions sanic/mixins/startup.py
Expand Up @@ -40,9 +40,9 @@
from sanic.application.motd import MOTD
from sanic.application.state import ApplicationServerInfo, Mode, ServerStage
from sanic.base.meta import SanicMeta
from sanic.compat import OS_IS_WINDOWS, is_atty
from sanic.compat import OS_IS_WINDOWS, StartMethod, is_atty
from sanic.exceptions import ServerKilled
from sanic.helpers import Default
from sanic.helpers import Default, _default
from sanic.http.constants import HTTP
from sanic.http.tls import get_ssl_context, process_to_context
from sanic.http.tls.context import SanicSSLContext
Expand Down Expand Up @@ -88,6 +88,7 @@ class StartupMixin(metaclass=SanicMeta):
state: ApplicationState
websocket_enabled: bool
multiplexer: WorkerMultiplexer
start_method: StartMethod = _default

def setup_loop(self):
if not self.asgi:
Expand Down Expand Up @@ -692,12 +693,17 @@ def should_auto_reload(cls) -> bool:
return any(app.state.auto_reload for app in cls._app_registry.values())

@classmethod
def _get_context(cls) -> BaseContext:
method = (
"spawn"
if "linux" not in sys.platform or cls.should_auto_reload()
else "fork"
def _get_startup_method(cls) -> str:
return (
cls.start_method
if not isinstance(cls.start_method, Default)
else "spawn"
)

@classmethod
def _get_context(cls) -> BaseContext:
method = cls._get_startup_method()
logger.debug("Creating multiprocessing context using '%s'", method)
return get_context(method)

@classmethod
Expand Down
5 changes: 3 additions & 2 deletions tests/conftest.py
Expand Up @@ -8,7 +8,7 @@

from contextlib import suppress
from logging import LogRecord
from typing import List, Tuple
from typing import Any, Dict, List, Tuple
from unittest.mock import MagicMock

import pytest
Expand Down Expand Up @@ -54,7 +54,7 @@ async def _handler(request):
"uuid": lambda: str(uuid.uuid1()),
}

CACHE = {}
CACHE: Dict[str, Any] = {}


class RouteStringGenerator:
Expand Down Expand Up @@ -147,6 +147,7 @@ def app(request):
for target, method_name in TouchUp._registry:
CACHE[method_name] = getattr(target, method_name)
app = Sanic(slugify.sub("-", request.node.name))

yield app
for target, method_name in TouchUp._registry:
setattr(target, method_name, CACHE[method_name])
Expand Down
10 changes: 5 additions & 5 deletions tests/test_app.py
Expand Up @@ -349,11 +349,11 @@ def test_get_app_does_not_exist():
with pytest.raises(
SanicException,
match="Sanic app name 'does-not-exist' not found.\n"
"App instantiation must occur outside "
"if __name__ == '__main__' "
"block or by using an AppLoader.\nSee "
"https://sanic.dev/en/guide/deployment/app-loader.html"
" for more details."
"App instantiation must occur outside "
"if __name__ == '__main__' "
"block or by using an AppLoader.\nSee "
"https://sanic.dev/en/guide/deployment/app-loader.html"
" for more details.",
):
Sanic.get_app("does-not-exist")

Expand Down
8 changes: 7 additions & 1 deletion tests/test_cli.py
Expand Up @@ -117,7 +117,13 @@ def test_error_with_path_as_instance_without_simple_arg(caplog):
),
)
def test_tls_options(cmd: Tuple[str, ...], caplog):
command = ["fake.server.app", *cmd, "--port=9999", "--debug"]
command = [
"fake.server.app",
*cmd,
"--port=9999",
"--debug",
"--single-process",
]
lines = capture(command, caplog)
assert "Goin' Fast @ https://127.0.0.1:9999" in lines

Expand Down
2 changes: 1 addition & 1 deletion tests/test_coffee.py
Expand Up @@ -39,7 +39,7 @@ async def shutdown(*_):
with patch("sys.stdout.isatty") as isatty:
isatty.return_value = True
with caplog.at_level(logging.DEBUG):
app.make_coffee()
app.make_coffee(single_process=True)

# Only in the regular logo
assert " ▄███ █████ ██ " not in caplog.text
Expand Down
3 changes: 1 addition & 2 deletions tests/test_create_task.py
Expand Up @@ -2,7 +2,6 @@
import sys

from threading import Event
from unittest.mock import Mock

import pytest

Expand Down Expand Up @@ -75,7 +74,7 @@ async def stop(app, _):

app.stop()

app.run()
app.run(single_process=True)


def test_named_task_called(app):
Expand Down
19 changes: 9 additions & 10 deletions tests/test_logging.py
Expand Up @@ -10,8 +10,7 @@
import sanic

from sanic import Sanic
from sanic.log import Colors
from sanic.log import LOGGING_CONFIG_DEFAULTS, logger
from sanic.log import LOGGING_CONFIG_DEFAULTS, Colors, logger
from sanic.response import text


Expand Down Expand Up @@ -254,11 +253,11 @@ def log_info(request):


def test_colors_enum_format():
assert f'{Colors.END}' == Colors.END.value
assert f'{Colors.BOLD}' == Colors.BOLD.value
assert f'{Colors.BLUE}' == Colors.BLUE.value
assert f'{Colors.GREEN}' == Colors.GREEN.value
assert f'{Colors.PURPLE}' == Colors.PURPLE.value
assert f'{Colors.RED}' == Colors.RED.value
assert f'{Colors.SANIC}' == Colors.SANIC.value
assert f'{Colors.YELLOW}' == Colors.YELLOW.value
assert f"{Colors.END}" == Colors.END.value
assert f"{Colors.BOLD}" == Colors.BOLD.value
assert f"{Colors.BLUE}" == Colors.BLUE.value
assert f"{Colors.GREEN}" == Colors.GREEN.value
assert f"{Colors.PURPLE}" == Colors.PURPLE.value
assert f"{Colors.RED}" == Colors.RED.value
assert f"{Colors.SANIC}" == Colors.SANIC.value
assert f"{Colors.YELLOW}" == Colors.YELLOW.value
25 changes: 21 additions & 4 deletions tests/test_multiprocessing.py
Expand Up @@ -3,6 +3,7 @@
import pickle
import random
import signal
import sys

from asyncio import sleep

Expand All @@ -11,6 +12,7 @@
from sanic_testing.testing import HOST, PORT

from sanic import Blueprint, text
from sanic.compat import use_context
from sanic.log import logger
from sanic.server.socket import configure_socket

Expand All @@ -20,6 +22,10 @@
reason="SIGALRM is not implemented for this platform, we have to come "
"up with another timeout strategy to test these",
)
@pytest.mark.skipif(
sys.platform not in ("linux", "darwin"),
reason="This test requires fork context",
)
def test_multiprocessing(app):
"""Tests that the number of children we produce is correct"""
# Selects a number at random so we can spot check
Expand All @@ -37,7 +43,8 @@ def stop_on_alarm(*args):

signal.signal(signal.SIGALRM, stop_on_alarm)
signal.alarm(2)
app.run(HOST, 4120, workers=num_workers, debug=True)
with use_context("fork"):
app.run(HOST, 4120, workers=num_workers, debug=True)

assert len(process_list) == num_workers + 1

Expand Down Expand Up @@ -136,6 +143,10 @@ def stop_on_alarm(*args):
not hasattr(signal, "SIGALRM"),
reason="SIGALRM is not implemented for this platform",
)
@pytest.mark.skipif(
sys.platform not in ("linux", "darwin"),
reason="This test requires fork context",
)
def test_multiprocessing_with_blueprint(app):
# Selects a number at random so we can spot check
num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1))
Expand All @@ -155,7 +166,8 @@ def stop_on_alarm(*args):

bp = Blueprint("test_text")
app.blueprint(bp)
app.run(HOST, 4121, workers=num_workers, debug=True)
with use_context("fork"):
app.run(HOST, 4121, workers=num_workers, debug=True)

assert len(process_list) == num_workers + 1

Expand Down Expand Up @@ -213,6 +225,10 @@ def test_pickle_app_with_static(app, protocol):
up_p_app.run(single_process=True)


@pytest.mark.skipif(
sys.platform not in ("linux", "darwin"),
reason="This test requires fork context",
)
def test_main_process_event(app, caplog):
# Selects a number at random so we can spot check
num_workers = random.choice(range(2, multiprocessing.cpu_count() * 2 + 1))
Expand All @@ -235,8 +251,9 @@ def main_process_start2(app, loop):
def main_process_stop2(app, loop):
logger.info("main_process_stop")

with caplog.at_level(logging.INFO):
app.run(HOST, PORT, workers=num_workers)
with use_context("fork"):
with caplog.at_level(logging.INFO):
app.run(HOST, PORT, workers=num_workers)

assert (
caplog.record_tuples.count(("sanic.root", 20, "main_process_start"))
Expand Down
7 changes: 1 addition & 6 deletions tests/test_request_stream.py
@@ -1,8 +1,5 @@
import asyncio

from contextlib import closing
from socket import socket

import pytest

from sanic import Sanic
Expand Down Expand Up @@ -623,6 +620,4 @@ async def read_chunk():
res = await read_chunk()
assert res == None

# Use random port for tests
with closing(socket()) as sock:
app.run(access_log=False)
app.run(access_log=False, single_process=True)
18 changes: 15 additions & 3 deletions tests/test_tls.py
Expand Up @@ -2,6 +2,7 @@
import os
import ssl
import subprocess
import sys

from contextlib import contextmanager
from multiprocessing import Event
Expand All @@ -17,6 +18,7 @@

from sanic import Sanic
from sanic.application.constants import Mode
from sanic.compat import use_context
from sanic.constants import LocalCertCreator
from sanic.exceptions import SanicException
from sanic.helpers import _default
Expand Down Expand Up @@ -426,7 +428,12 @@ def stop(*args):
app.stop()

with caplog.at_level(logging.INFO):
app.run(host="127.0.0.1", port=42102, ssl=[localhost_dir, sanic_dir])
app.run(
host="127.0.0.1",
port=42102,
ssl=[localhost_dir, sanic_dir],
single_process=True,
)

logmsg = [
m for s, l, m in caplog.record_tuples if m.startswith("Certificate")
Expand Down Expand Up @@ -642,6 +649,10 @@ def test_sanic_ssl_context_create():
assert isinstance(sanic_context, SanicSSLContext)


@pytest.mark.skipif(
sys.platform not in ("linux", "darwin"),
reason="This test requires fork context",
)
def test_ssl_in_multiprocess_mode(app: Sanic, caplog):

ssl_dict = {"cert": localhost_cert, "key": localhost_key}
Expand All @@ -657,8 +668,9 @@ async def shutdown(app):
app.stop()

assert not event.is_set()
with caplog.at_level(logging.INFO):
app.run(ssl=ssl_dict)
with use_context("fork"):
with caplog.at_level(logging.INFO):
app.run(ssl=ssl_dict)
assert event.is_set()

assert (
Expand Down

0 comments on commit b276b91

Please sign in to comment.