diff --git a/.coveragerc b/.coveragerc index 228560650c..1a042c34e9 100644 --- a/.coveragerc +++ b/.coveragerc @@ -20,6 +20,7 @@ exclude_lines = noqa NOQA pragma: no cover + TYPE_CHECKING omit = site-packages sanic/__main__.py diff --git a/pyproject.toml b/pyproject.toml index 578c40c62f..7c5e29600d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,3 +16,10 @@ lines_after_imports = 2 lines_between_types = 1 multi_line_output = 3 profile = "black" + +[[tool.mypy.overrides]] +module = [ + "trustme.*", + "sanic_routing.*", +] +ignore_missing_imports = true diff --git a/sanic/app.py b/sanic/app.py index c928f028ce..a58c5fe69a 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -43,11 +43,8 @@ from urllib.parse import urlencode, urlunparse from warnings import filterwarnings -from sanic_routing.exceptions import ( # type: ignore - FinalizationError, - NotFound, -) -from sanic_routing.route import Route # type: ignore +from sanic_routing.exceptions import FinalizationError, NotFound +from sanic_routing.route import Route from sanic.application.ext import setup_ext from sanic.application.state import ApplicationState, Mode, ServerStage @@ -64,6 +61,7 @@ URLBuildError, ) from sanic.handlers import ErrorHandler +from sanic.helpers import _default from sanic.http import Stage from sanic.log import ( LOGGING_CONFIG_DEFAULTS, @@ -92,7 +90,7 @@ from sanic.touchup import TouchUp, TouchUpMeta -if TYPE_CHECKING: # no cov +if TYPE_CHECKING: try: from sanic_ext import Extend # type: ignore from sanic_ext.extensions.base import Extension # type: ignore @@ -949,6 +947,7 @@ async def handle_request(self, request: Request): # no cov "response": response, }, ) + ... await response.send(end_stream=True) elif isinstance(response, ResponseStream): resp = await response(request) @@ -1532,8 +1531,10 @@ async def _startup(self): if hasattr(self, "_ext"): self.ext._display() - if self.state.is_debug: + if self.state.is_debug and self.config.TOUCHUP is not True: self.config.TOUCHUP = False + elif self.config.TOUCHUP is _default: + self.config.TOUCHUP = True # Setup routers self.signalize(self.config.TOUCHUP) diff --git a/sanic/application/constants.py b/sanic/application/constants.py new file mode 100644 index 0000000000..9d46cb8e60 --- /dev/null +++ b/sanic/application/constants.py @@ -0,0 +1,23 @@ +from enum import Enum, IntEnum, auto + + +class StrEnum(str, Enum): + def _generate_next_value_(name: str, *args) -> str: # type: ignore + return name.lower() + + +class Server(StrEnum): + SANIC = auto() + ASGI = auto() + GUNICORN = auto() + + +class Mode(StrEnum): + PRODUCTION = auto() + DEBUG = auto() + + +class ServerStage(IntEnum): + STOPPED = auto() + PARTIAL = auto() + SERVING = auto() diff --git a/sanic/application/ext.py b/sanic/application/ext.py index deb7c5d4c2..eac1e3179a 100644 --- a/sanic/application/ext.py +++ b/sanic/application/ext.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING -if TYPE_CHECKING: # no cov +if TYPE_CHECKING: from sanic import Sanic try: diff --git a/sanic/application/spinner.py b/sanic/application/spinner.py new file mode 100644 index 0000000000..e89513ea42 --- /dev/null +++ b/sanic/application/spinner.py @@ -0,0 +1,86 @@ +import os +import sys +import time + +from contextlib import contextmanager +from queue import Queue +from threading import Thread + + +if os.name == "nt": # noqa + import ctypes # noqa + + class _CursorInfo(ctypes.Structure): + _fields_ = [("size", ctypes.c_int), ("visible", ctypes.c_byte)] + + +class Spinner: # noqa + def __init__(self, message: str) -> None: + self.message = message + self.queue: Queue[int] = Queue() + self.spinner = self.cursor() + self.thread = Thread(target=self.run) + + def start(self): + self.queue.put(1) + self.thread.start() + self.hide() + + def run(self): + while self.queue.get(): + output = f"\r{self.message} [{next(self.spinner)}]" + sys.stdout.write(output) + sys.stdout.flush() + time.sleep(0.1) + self.queue.put(1) + + def stop(self): + self.queue.put(0) + self.thread.join() + self.show() + + @staticmethod + def cursor(): + while True: + for cursor in "|/-\\": + yield cursor + + @staticmethod + def hide(): + if os.name == "nt": + ci = _CursorInfo() + handle = ctypes.windll.kernel32.GetStdHandle(-11) + ctypes.windll.kernel32.GetConsoleCursorInfo( + handle, ctypes.byref(ci) + ) + ci.visible = False + ctypes.windll.kernel32.SetConsoleCursorInfo( + handle, ctypes.byref(ci) + ) + elif os.name == "posix": + sys.stdout.write("\033[?25l") + sys.stdout.flush() + + @staticmethod + def show(): + if os.name == "nt": + ci = _CursorInfo() + handle = ctypes.windll.kernel32.GetStdHandle(-11) + ctypes.windll.kernel32.GetConsoleCursorInfo( + handle, ctypes.byref(ci) + ) + ci.visible = True + ctypes.windll.kernel32.SetConsoleCursorInfo( + handle, ctypes.byref(ci) + ) + elif os.name == "posix": + sys.stdout.write("\033[?25h") + sys.stdout.flush() + + +@contextmanager +def loading(message: str = "Loading"): # noqa + spinner = Spinner(message) + spinner.start() + yield + spinner.stop() diff --git a/sanic/application/state.py b/sanic/application/state.py index 5975c2a6f2..f308f2c6ec 100644 --- a/sanic/application/state.py +++ b/sanic/application/state.py @@ -3,42 +3,20 @@ import logging from dataclasses import dataclass, field -from enum import Enum, IntEnum, auto from pathlib import Path from socket import socket from ssl import SSLContext from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union +from sanic.application.constants import Mode, Server, ServerStage from sanic.log import VerbosityFilter, logger from sanic.server.async_server import AsyncioServer -if TYPE_CHECKING: # no cov +if TYPE_CHECKING: from sanic import Sanic -class StrEnum(str, Enum): - def _generate_next_value_(name: str, *args) -> str: # type: ignore - return name.lower() - - -class Server(StrEnum): - SANIC = auto() - ASGI = auto() - GUNICORN = auto() - - -class Mode(StrEnum): - PRODUCTION = auto() - DEBUG = auto() - - -class ServerStage(IntEnum): - STOPPED = auto() - PARTIAL = auto() - SERVING = auto() - - @dataclass class ApplicationServerInfo: settings: Dict[str, Any] diff --git a/sanic/asgi.py b/sanic/asgi.py index 3dbd95a702..10357ae876 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -17,7 +17,7 @@ from sanic.server.websockets.connection import WebSocketConnection -if TYPE_CHECKING: # no cov +if TYPE_CHECKING: from sanic import Sanic diff --git a/sanic/blueprint_group.py b/sanic/blueprint_group.py index b16d8c58e6..a9b514106b 100644 --- a/sanic/blueprint_group.py +++ b/sanic/blueprint_group.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, List, Optional, Union -if TYPE_CHECKING: # no cov +if TYPE_CHECKING: from sanic.blueprints import Blueprint diff --git a/sanic/blueprints.py b/sanic/blueprints.py index df4501ddcc..2b9e52e09e 100644 --- a/sanic/blueprints.py +++ b/sanic/blueprints.py @@ -21,8 +21,8 @@ Union, ) -from sanic_routing.exceptions import NotFound # type: ignore -from sanic_routing.route import Route # type: ignore +from sanic_routing.exceptions import NotFound +from sanic_routing.route import Route from sanic.base.root import BaseSanic from sanic.blueprint_group import BlueprintGroup @@ -36,7 +36,7 @@ ) -if TYPE_CHECKING: # no cov +if TYPE_CHECKING: from sanic import Sanic diff --git a/sanic/cli/app.py b/sanic/cli/app.py index b90536a0fa..e012c8b35b 100644 --- a/sanic/cli/app.py +++ b/sanic/cli/app.py @@ -58,10 +58,13 @@ def __init__(self) -> None: os.environ.get("SANIC_RELOADER_PROCESS", "") != "true" ) self.args: List[Any] = [] + self.groups: List[Group] = [] def attach(self): for group in Group._registry: - group.create(self.parser).attach() + instance = group.create(self.parser) + instance.attach() + self.groups.append(instance) def run(self): # This is to provide backwards compat -v to display version @@ -81,9 +84,13 @@ def run(self): try: app = self._get_app() kwargs = self._build_run_kwargs() - app.run(**kwargs) except ValueError: error_logger.exception("Failed to run app") + else: + for http_version in self.args.http: + app.prepare(**kwargs, version=http_version) + + Sanic.serve() def _precheck(self): # # Custom TLS mismatch handling for better diagnostics @@ -163,11 +170,14 @@ def _get_app(self): " Example File: project/sanic_server.py -> app\n" " Example Module: project.sanic_server.app" ) + sys.exit(1) else: raise e return app def _build_run_kwargs(self): + for group in self.groups: + group.prepare(self.args) ssl: Union[None, dict, str, list] = [] if self.args.tlshost: ssl.append(None) @@ -192,6 +202,7 @@ def _build_run_kwargs(self): "unix": self.args.unix, "verbosity": self.args.verbosity or 0, "workers": self.args.workers, + "auto_tls": self.args.auto_tls, } for maybe_arg in ("auto_reload", "dev"): @@ -201,4 +212,5 @@ def _build_run_kwargs(self): if self.args.path: kwargs["auto_reload"] = True kwargs["reload_dir"] = self.args.path + return kwargs diff --git a/sanic/cli/arguments.py b/sanic/cli/arguments.py index 6ee084346d..cde125fb98 100644 --- a/sanic/cli/arguments.py +++ b/sanic/cli/arguments.py @@ -3,9 +3,10 @@ from argparse import ArgumentParser, _ArgumentGroup from typing import List, Optional, Type, Union -from sanic_routing import __version__ as __routing_version__ # type: ignore +from sanic_routing import __version__ as __routing_version__ from sanic import __version__ +from sanic.http.constants import HTTP class Group: @@ -38,6 +39,9 @@ def add_bool_arguments(self, *args, **kwargs): "--no-" + args[0][2:], *args[1:], action="store_false", **kwargs ) + def prepare(self, args) -> None: + ... + class GeneralGroup(Group): name = None @@ -83,6 +87,44 @@ def attach(self): ) +class HTTPVersionGroup(Group): + name = "HTTP version" + + def attach(self): + http_values = [http.value for http in HTTP.__members__.values()] + + self.container.add_argument( + "--http", + dest="http", + action="append", + choices=http_values, + type=int, + help=( + "Which HTTP version to use: HTTP/1.1 or HTTP/3. Value should\n" + "be either 1, or 3. [default 1]" + ), + ) + self.container.add_argument( + "-1", + dest="http", + action="append_const", + const=1, + help=("Run Sanic server using HTTP/1.1"), + ) + self.container.add_argument( + "-3", + dest="http", + action="append_const", + const=3, + help=("Run Sanic server using HTTP/3"), + ) + + def prepare(self, args): + if not args.http: + args.http = [1] + args.http = tuple(sorted(set(map(HTTP, args.http)), reverse=True)) + + class SocketGroup(Group): name = "Socket binding" @@ -92,7 +134,6 @@ def attach(self): "--host", dest="host", type=str, - default="127.0.0.1", help="Host address [default 127.0.0.1]", ) self.container.add_argument( @@ -100,7 +141,6 @@ def attach(self): "--port", dest="port", type=int, - default=8000, help="Port to serve on [default 8000]", ) self.container.add_argument( @@ -180,11 +220,7 @@ def attach(self): "--debug", dest="debug", action="store_true", - help=( - "Run the server in DEBUG mode. It includes DEBUG logging,\n" - "additional context on exceptions, and other settings\n" - "not-safe for PRODUCTION, but helpful for debugging problems." - ), + help="Run the server in debug mode", ) self.container.add_argument( "-r", @@ -209,7 +245,16 @@ def attach(self): "--dev", dest="dev", action="store_true", - help=("debug + auto reload."), + help=("debug + auto reload"), + ) + self.container.add_argument( + "--auto-tls", + dest="auto_tls", + action="store_true", + help=( + "Create a temporary TLS certificate for local development " + "(requires mkcert or trustme)" + ), ) diff --git a/sanic/config.py b/sanic/config.py index a07a0f4c03..fd63ca5479 100644 --- a/sanic/config.py +++ b/sanic/config.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import Any, Callable, Dict, Optional, Sequence, Union +from sanic.constants import LocalCertCreator from sanic.errorpages import DEFAULT_FORMAT, check_error_format from sanic.helpers import Default, _default from sanic.http import Http @@ -26,6 +27,10 @@ "GRACEFUL_SHUTDOWN_TIMEOUT": 15.0, # 15 sec "KEEP_ALIVE_TIMEOUT": 5, # 5 seconds "KEEP_ALIVE": True, + "LOCAL_CERT_CREATOR": LocalCertCreator.AUTO, + "LOCAL_TLS_KEY": _default, + "LOCAL_TLS_CERT": _default, + "LOCALHOST": "localhost", "MOTD": True, "MOTD_DISPLAY": {}, "NOISY_EXCEPTIONS": False, @@ -38,7 +43,8 @@ "REQUEST_MAX_SIZE": 100000000, # 100 megabytes "REQUEST_TIMEOUT": 60, # 60 seconds "RESPONSE_TIMEOUT": 60, # 60 seconds - "TOUCHUP": True, + "TLS_CERT_PASSWORD": "", + "TOUCHUP": _default, "USE_UVLOOP": _default, "WEBSOCKET_MAX_SIZE": 2**20, # 1 megabyte "WEBSOCKET_PING_INTERVAL": 20, @@ -69,9 +75,13 @@ class Config(dict, metaclass=DescriptorMeta): GRACEFUL_SHUTDOWN_TIMEOUT: float KEEP_ALIVE_TIMEOUT: int KEEP_ALIVE: bool - NOISY_EXCEPTIONS: bool + LOCAL_CERT_CREATOR: Union[str, LocalCertCreator] + LOCAL_TLS_KEY: Union[Path, str, Default] + LOCAL_TLS_CERT: Union[Path, str, Default] + LOCALHOST: str MOTD: bool MOTD_DISPLAY: Dict[str, str] + NOISY_EXCEPTIONS: bool PROXIES_COUNT: Optional[int] REAL_IP_HEADER: Optional[str] REGISTER: bool @@ -82,7 +92,8 @@ class Config(dict, metaclass=DescriptorMeta): REQUEST_TIMEOUT: int RESPONSE_TIMEOUT: int SERVER_NAME: str - TOUCHUP: bool + TLS_CERT_PASSWORD: str + TOUCHUP: Union[Default, bool] USE_UVLOOP: Union[Default, bool] WEBSOCKET_MAX_SIZE: int WEBSOCKET_PING_INTERVAL: int @@ -157,13 +168,19 @@ def _post_set(self, attr, value) -> None: "REQUEST_MAX_SIZE", ): self._configure_header_size() - elif attr == "LOGO": - self._LOGO = value - deprecation( - "Setting the config.LOGO is deprecated and will no longer " - "be supported starting in v22.6.", - 22.6, - ) + if attr == "LOGO": + self._LOGO = value + deprecation( + "Setting the config.LOGO is deprecated and will no longer " + "be supported starting in v22.6.", + 22.6, + ) + elif attr == "LOCAL_CERT_CREATOR" and not isinstance( + self.LOCAL_CERT_CREATOR, LocalCertCreator + ): + self.LOCAL_CERT_CREATOR = LocalCertCreator[ + self.LOCAL_CERT_CREATOR.upper() + ] @property def LOGO(self): diff --git a/sanic/constants.py b/sanic/constants.py index 80f1d2a9bf..52ec50ef00 100644 --- a/sanic/constants.py +++ b/sanic/constants.py @@ -24,5 +24,16 @@ def __str__(self) -> str: DELETE = auto() +class LocalCertCreator(str, Enum): + def _generate_next_value_(name, start, count, last_values): + return name.upper() + + AUTO = auto() + TRUSTME = auto() + MKCERT = auto() + + HTTP_METHODS = tuple(HTTPMethod.__members__.values()) DEFAULT_HTTP_CONTENT_TYPE = "application/octet-stream" +DEFAULT_LOCAL_TLS_KEY = "key.pem" +DEFAULT_LOCAL_TLS_CERT = "cert.pem" diff --git a/sanic/http/__init__.py b/sanic/http/__init__.py new file mode 100644 index 0000000000..8a96102926 --- /dev/null +++ b/sanic/http/__init__.py @@ -0,0 +1,5 @@ +from .constants import Stage +from .http1 import Http + + +__all__ = ("Http", "Stage") diff --git a/sanic/http/constants.py b/sanic/http/constants.py new file mode 100644 index 0000000000..df3eebeb5e --- /dev/null +++ b/sanic/http/constants.py @@ -0,0 +1,29 @@ +from enum import Enum, IntEnum + + +class Stage(Enum): + """ + Enum for representing the stage of the request/response cycle + + | ``IDLE`` Waiting for request + | ``REQUEST`` Request headers being received + | ``HANDLER`` Headers done, handler running + | ``RESPONSE`` Response headers sent, body in progress + | ``FAILED`` Unrecoverable state (error while sending response) + | + """ + + IDLE = 0 # Waiting for request + REQUEST = 1 # Request headers being received + HANDLER = 3 # Headers done, handler running + RESPONSE = 4 # Response headers sent, body in progress + FAILED = 100 # Unrecoverable state (error while sending response) + + +class HTTP(IntEnum): + VERSION_1 = 1 + VERSION_3 = 3 + + def display(self) -> str: + value = 1.1 if self.value == 1 else self.value + return f"HTTP/{value}" diff --git a/sanic/http.py b/sanic/http/http1.py similarity index 96% rename from sanic/http.py rename to sanic/http/http1.py index b63e243d3c..1f7870ee64 100644 --- a/sanic/http.py +++ b/sanic/http/http1.py @@ -3,12 +3,11 @@ from typing import TYPE_CHECKING, Optional -if TYPE_CHECKING: # no cov +if TYPE_CHECKING: from sanic.request import Request from sanic.response import BaseHTTPResponse from asyncio import CancelledError, sleep -from enum import Enum from sanic.compat import Header from sanic.exceptions import ( @@ -20,33 +19,16 @@ ) from sanic.headers import format_http1_response from sanic.helpers import has_message_body +from sanic.http.constants import Stage +from sanic.http.stream import Stream from sanic.log import access_logger, error_logger, logger from sanic.touchup import TouchUpMeta -class Stage(Enum): - """ - Enum for representing the stage of the request/response cycle - - | ``IDLE`` Waiting for request - | ``REQUEST`` Request headers being received - | ``HANDLER`` Headers done, handler running - | ``RESPONSE`` Response headers sent, body in progress - | ``FAILED`` Unrecoverable state (error while sending response) - | - """ - - IDLE = 0 # Waiting for request - REQUEST = 1 # Request headers being received - HANDLER = 3 # Headers done, handler running - RESPONSE = 4 # Response headers sent, body in progress - FAILED = 100 # Unrecoverable state (error while sending response) - - HTTP_CONTINUE = b"HTTP/1.1 100 Continue\r\n\r\n" -class Http(metaclass=TouchUpMeta): +class Http(Stream, metaclass=TouchUpMeta): """ Internal helper for managing the HTTP request/response cycle @@ -67,7 +49,6 @@ class Http(metaclass=TouchUpMeta): HEADER_CEILING = 16_384 HEADER_MAX_SIZE = 0 - __touchup__ = ( "http1_request_header", "http1_response_header", @@ -353,6 +334,12 @@ async def http1_response_header( self.response_func = self.head_response_ignored headers["connection"] = "keep-alive" if self.keep_alive else "close" + + # This header may be removed or modified by the AltSvcCheck Touchup + # service. At server start, we either remove this header from ever + # being assigned, or we change the value as required. + headers["alt-svc"] = "" + ret = format_http1_response(status, res.processed_headers) if data: ret += data diff --git a/sanic/http/http3.py b/sanic/http/http3.py new file mode 100644 index 0000000000..09ecad5b9f --- /dev/null +++ b/sanic/http/http3.py @@ -0,0 +1,397 @@ +from __future__ import annotations + +import asyncio + +from abc import ABC, abstractmethod +from ssl import SSLContext +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Union, + cast, +) + +from aioquic.h0.connection import H0_ALPN, H0Connection +from aioquic.h3.connection import H3_ALPN, H3Connection +from aioquic.h3.events import ( + DatagramReceived, + DataReceived, + H3Event, + HeadersReceived, + WebTransportStreamDataReceived, +) +from aioquic.quic.configuration import QuicConfiguration +from aioquic.tls import SessionTicket + +from sanic.compat import Header +from sanic.constants import LocalCertCreator +from sanic.exceptions import PayloadTooLarge, SanicException, ServerError +from sanic.helpers import has_message_body +from sanic.http.constants import Stage +from sanic.http.stream import Stream +from sanic.http.tls.context import CertSelector, CertSimple, SanicSSLContext +from sanic.log import Colors, logger +from sanic.models.protocol_types import TransportProtocol +from sanic.models.server_types import ConnInfo + + +if TYPE_CHECKING: + from sanic import Sanic + from sanic.request import Request + from sanic.response import BaseHTTPResponse + from sanic.server.protocols.http_protocol import Http3Protocol + + +HttpConnection = Union[H0Connection, H3Connection] + + +class HTTP3Transport(TransportProtocol): + __slots__ = ("_protocol",) + + def __init__(self, protocol: Http3Protocol): + self._protocol = protocol + + def get_protocol(self) -> Http3Protocol: + return self._protocol + + def get_extra_info(self, info: str, default: Any = None) -> Any: + if ( + info in ("socket", "sockname", "peername") + and self._protocol._transport + ): + return self._protocol._transport.get_extra_info(info, default) + elif info == "network_paths": + return self._protocol._quic._network_paths + elif info == "ssl_context": + return self._protocol.app.state.ssl + return default + + +class Receiver(ABC): + future: asyncio.Future + + def __init__(self, transmit, protocol, request: Request) -> None: + self.transmit = transmit + self.protocol = protocol + self.request = request + + @abstractmethod + async def run(self): # no cov + ... + + +class HTTPReceiver(Receiver, Stream): + stage: Stage + request: Request + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.request_body = None + self.stage = Stage.IDLE + self.headers_sent = False + self.response: Optional[BaseHTTPResponse] = None + self.request_max_size = self.protocol.request_max_size + self.request_bytes = 0 + + async def run(self, exception: Optional[Exception] = None): + self.stage = Stage.HANDLER + self.head_only = self.request.method.upper() == "HEAD" + + if exception: + logger.info( # no cov + f"{Colors.BLUE}[exception]: " + f"{Colors.RED}{exception}{Colors.END}", + exc_info=True, + extra={"verbosity": 1}, + ) + await self.error_response(exception) + else: + try: + logger.info( # no cov + f"{Colors.BLUE}[request]:{Colors.END} {self.request}", + extra={"verbosity": 1}, + ) + await self.protocol.request_handler(self.request) + except Exception as e: # no cov + # This should largely be handled within the request handler. + # But, just in case... + await self.run(e) + self.stage = Stage.IDLE + + async def error_response(self, exception: Exception) -> None: + """ + Handle response when exception encountered + """ + # From request and handler states we can respond, otherwise be silent + app = self.protocol.app + + await app.handle_exception(self.request, exception) + + def _prepare_headers( + self, response: BaseHTTPResponse + ) -> List[Tuple[bytes, bytes]]: + size = len(response.body) if response.body else 0 + headers = response.headers + status = response.status + + if not has_message_body(status) and ( + size + or "content-length" in headers + or "transfer-encoding" in headers + ): + headers.pop("content-length", None) + headers.pop("transfer-encoding", None) + logger.warning( # no cov + f"Message body set in response on {self.request.path}. " + f"A {status} response may only have headers, no body." + ) + elif "content-length" not in headers: + if size: + headers["content-length"] = size + else: + headers["transfer-encoding"] = "chunked" + + headers = [ + (b":status", str(response.status).encode()), + *response.processed_headers, + ] + return headers + + def send_headers(self) -> None: + logger.debug( # no cov + f"{Colors.BLUE}[send]: {Colors.GREEN}HEADERS{Colors.END}", + extra={"verbosity": 2}, + ) + if not self.response: + raise RuntimeError("no response") + + response = self.response + headers = self._prepare_headers(response) + + self.protocol.connection.send_headers( + stream_id=self.request.stream_id, + headers=headers, + ) + self.headers_sent = True + self.stage = Stage.RESPONSE + + if self.response.body and not self.head_only: + self._send(self.response.body, False) + elif self.head_only: + self.future.cancel() + + def respond(self, response: BaseHTTPResponse) -> BaseHTTPResponse: + logger.debug( # no cov + f"{Colors.BLUE}[respond]:{Colors.END} {response}", + extra={"verbosity": 2}, + ) + + if self.stage is not Stage.HANDLER: + self.stage = Stage.FAILED + raise RuntimeError("Response already started") + + # Disconnect any earlier but unused response object + if self.response is not None: + self.response.stream = None + + self.response, response.stream = response, self + + return response + + def receive_body(self, data: bytes) -> None: + self.request_bytes += len(data) + if self.request_bytes > self.request_max_size: + raise PayloadTooLarge("Request body exceeds the size limit") + + self.request.body += data + + async def send(self, data: bytes, end_stream: bool) -> None: + logger.debug( # no cov + f"{Colors.BLUE}[send]: {Colors.GREEN}data={data.decode()} " + f"end_stream={end_stream}{Colors.END}", + extra={"verbosity": 2}, + ) + self._send(data, end_stream) + + def _send(self, data: bytes, end_stream: bool) -> None: + if not self.headers_sent: + self.send_headers() + if self.stage is not Stage.RESPONSE: + raise ServerError(f"not ready to send: {self.stage}") + + # Chunked + if ( + self.response + and self.response.headers.get("transfer-encoding") == "chunked" + ): + size = len(data) + if end_stream: + data = ( + b"%x\r\n%b\r\n0\r\n\r\n" % (size, data) + if size + else b"0\r\n\r\n" + ) + elif size: + data = b"%x\r\n%b\r\n" % (size, data) + + logger.debug( # no cov + f"{Colors.BLUE}[transmitting]{Colors.END}", + extra={"verbosity": 2}, + ) + self.protocol.connection.send_data( + stream_id=self.request.stream_id, + data=data, + end_stream=end_stream, + ) + self.transmit() + + if end_stream: + self.stage = Stage.IDLE + + +class WebsocketReceiver(Receiver): # noqa + async def run(self): + ... + + +class WebTransportReceiver(Receiver): # noqa + async def run(self): + ... + + +class Http3: + HANDLER_PROPERTY_MAPPING = { + DataReceived: "stream_id", + HeadersReceived: "stream_id", + DatagramReceived: "flow_id", + WebTransportStreamDataReceived: "session_id", + } + + def __init__( + self, + protocol: Http3Protocol, + transmit: Callable[[], None], + ) -> None: + self.protocol = protocol + self.transmit = transmit + self.receivers: Dict[int, Receiver] = {} + + def http_event_received(self, event: H3Event) -> None: + logger.debug( # no cov + f"{Colors.BLUE}[http_event_received]: " + f"{Colors.YELLOW}{event}{Colors.END}", + extra={"verbosity": 2}, + ) + receiver, created_new = self.get_or_make_receiver(event) + receiver = cast(HTTPReceiver, receiver) + + if isinstance(event, HeadersReceived) and created_new: + receiver.future = asyncio.ensure_future(receiver.run()) + elif isinstance(event, DataReceived): + try: + receiver.receive_body(event.data) + except Exception as e: + receiver.future.cancel() + receiver.future = asyncio.ensure_future(receiver.run(e)) + else: + ... # Intentionally here to help out Touchup + logger.debug( # no cov + f"{Colors.RED}DOING NOTHING{Colors.END}", + extra={"verbosity": 2}, + ) + + def get_or_make_receiver(self, event: H3Event) -> Tuple[Receiver, bool]: + if ( + isinstance(event, HeadersReceived) + and event.stream_id not in self.receivers + ): + request = self._make_request(event) + receiver = HTTPReceiver(self.transmit, self.protocol, request) + request.stream = receiver + + self.receivers[event.stream_id] = receiver + return receiver, True + else: + ident = getattr(event, self.HANDLER_PROPERTY_MAPPING[type(event)]) + return self.receivers[ident], False + + def get_receiver_by_stream_id(self, stream_id: int) -> Receiver: + return self.receivers[stream_id] + + def _make_request(self, event: HeadersReceived) -> Request: + headers = Header(((k.decode(), v.decode()) for k, v in event.headers)) + method = headers[":method"] + path = headers[":path"] + scheme = headers.pop(":scheme", "") + authority = headers.pop(":authority", "") + + if authority: + headers["host"] = authority + + transport = HTTP3Transport(self.protocol) + request = self.protocol.request_class( + path.encode(), + headers, + "3", + method, + transport, + self.protocol.app, + b"", + ) + request.conn_info = ConnInfo(transport) + request._stream_id = event.stream_id + request._scheme = scheme + + return request + + +class SessionTicketStore: + """ + Simple in-memory store for session tickets. + """ + + def __init__(self) -> None: + self.tickets: Dict[bytes, SessionTicket] = {} + + def add(self, ticket: SessionTicket) -> None: + self.tickets[ticket.ticket] = ticket + + def pop(self, label: bytes) -> Optional[SessionTicket]: + return self.tickets.pop(label, None) + + +def get_config( + app: Sanic, ssl: Union[SanicSSLContext, CertSelector, SSLContext] +): + # TODO: + # - proper selection needed if servince with multiple certs insted of + # just taking the first + if isinstance(ssl, CertSelector): + ssl = cast(SanicSSLContext, ssl.sanic_select[0]) + if app.config.LOCAL_CERT_CREATOR is LocalCertCreator.TRUSTME: + raise SanicException( + "Sorry, you cannot currently use trustme as a local certificate " + "generator for an HTTP/3 server. This is not yet supported. You " + "should be able to use mkcert instead. For more information, see: " + "https://github.com/aiortc/aioquic/issues/295." + ) + if not isinstance(ssl, CertSimple): + raise SanicException("SSLContext is not CertSimple") + + config = QuicConfiguration( + alpn_protocols=H3_ALPN + H0_ALPN + ["siduck"], + is_client=False, + max_datagram_frame_size=65536, + ) + password = app.config.TLS_CERT_PASSWORD or None + + config.load_cert_chain( + ssl.sanic["cert"], ssl.sanic["key"], password=password + ) + + return config diff --git a/sanic/http/stream.py b/sanic/http/stream.py new file mode 100644 index 0000000000..9b413195bd --- /dev/null +++ b/sanic/http/stream.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, Tuple, Union + +from sanic.http.constants import Stage + + +if TYPE_CHECKING: + from sanic.response import BaseHTTPResponse + from sanic.server.protocols.http_protocol import HttpProtocol + + +class Stream: + stage: Stage + response: Optional[BaseHTTPResponse] + protocol: HttpProtocol + url: Optional[str] + request_body: Optional[bytes] + request_max_size: Union[int, float] + + __touchup__: Tuple[str, ...] = tuple() + __slots__ = () + + def respond( + self, response: BaseHTTPResponse + ) -> BaseHTTPResponse: # no cov + raise NotImplementedError("Not implemented") diff --git a/sanic/http/tls/__init__.py b/sanic/http/tls/__init__.py new file mode 100644 index 0000000000..b12fe529f8 --- /dev/null +++ b/sanic/http/tls/__init__.py @@ -0,0 +1,5 @@ +from .context import process_to_context +from .creators import get_ssl_context + + +__all__ = ("get_ssl_context", "process_to_context") diff --git a/sanic/tls.py b/sanic/http/tls/context.py similarity index 95% rename from sanic/tls.py rename to sanic/http/tls/context.py index be30f4a263..f77fa56051 100644 --- a/sanic/tls.py +++ b/sanic/http/tls/context.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import os import ssl -from typing import Iterable, Optional, Union +from typing import Any, Dict, Iterable, Optional, Union from sanic.log import logger @@ -77,65 +79,6 @@ def load_cert_dir(p: str) -> ssl.SSLContext: return CertSimple(certfile, keyfile) -class CertSimple(ssl.SSLContext): - """A wrapper for creating SSLContext with a sanic attribute.""" - - def __new__(cls, cert, key, **kw): - # try common aliases, rename to cert/key - certfile = kw["cert"] = kw.pop("certificate", None) or cert - keyfile = kw["key"] = kw.pop("keyfile", None) or key - password = kw.pop("password", None) - if not certfile or not keyfile: - raise ValueError("SSL dict needs filenames for cert and key.") - subject = {} - if "names" not in kw: - cert = ssl._ssl._test_decode_cert(certfile) # type: ignore - kw["names"] = [ - name - for t, name in cert["subjectAltName"] - if t in ["DNS", "IP Address"] - ] - subject = {k: v for item in cert["subject"] for k, v in item} - self = create_context(certfile, keyfile, password) - self.__class__ = cls - self.sanic = {**subject, **kw} - return self - - def __init__(self, cert, key, **kw): - pass # Do not call super().__init__ because it is already initialized - - -class CertSelector(ssl.SSLContext): - """Automatically select SSL certificate based on the hostname that the - client is trying to access, via SSL SNI. Paths to certificate folders - with privkey.pem and fullchain.pem in them should be provided, and - will be matched in the order given whenever there is a new connection. - """ - - def __new__(cls, ctxs): - return super().__new__(cls) - - def __init__(self, ctxs: Iterable[Optional[ssl.SSLContext]]): - super().__init__() - self.sni_callback = selector_sni_callback # type: ignore - self.sanic_select = [] - self.sanic_fallback = None - all_names = [] - for i, ctx in enumerate(ctxs): - if not ctx: - continue - names = dict(getattr(ctx, "sanic", {})).get("names", []) - all_names += names - self.sanic_select.append(ctx) - if i == 0: - self.sanic_fallback = ctx - if not all_names: - raise ValueError( - "No certificates with SubjectAlternativeNames found." - ) - logger.info(f"Certificate vhosts: {', '.join(all_names)}") - - def find_cert(self: CertSelector, server_name: str): """Find the first certificate that matches the given SNI. @@ -194,3 +137,73 @@ def server_name_callback( ) -> None: """Store the received SNI as sslobj.sanic_server_name.""" sslobj.sanic_server_name = server_name # type: ignore + + +class SanicSSLContext(ssl.SSLContext): + sanic: Dict[str, os.PathLike] + + @classmethod + def create_from_ssl_context(cls, context: ssl.SSLContext): + context.__class__ = cls + return context + + +class CertSimple(SanicSSLContext): + """A wrapper for creating SSLContext with a sanic attribute.""" + + sanic: Dict[str, Any] + + def __new__(cls, cert, key, **kw): + # try common aliases, rename to cert/key + certfile = kw["cert"] = kw.pop("certificate", None) or cert + keyfile = kw["key"] = kw.pop("keyfile", None) or key + password = kw.pop("password", None) + if not certfile or not keyfile: + raise ValueError("SSL dict needs filenames for cert and key.") + subject = {} + if "names" not in kw: + cert = ssl._ssl._test_decode_cert(certfile) # type: ignore + kw["names"] = [ + name + for t, name in cert["subjectAltName"] + if t in ["DNS", "IP Address"] + ] + subject = {k: v for item in cert["subject"] for k, v in item} + self = create_context(certfile, keyfile, password) + self.__class__ = cls + self.sanic = {**subject, **kw} + return self + + def __init__(self, cert, key, **kw): + pass # Do not call super().__init__ because it is already initialized + + +class CertSelector(ssl.SSLContext): + """Automatically select SSL certificate based on the hostname that the + client is trying to access, via SSL SNI. Paths to certificate folders + with privkey.pem and fullchain.pem in them should be provided, and + will be matched in the order given whenever there is a new connection. + """ + + def __new__(cls, ctxs): + return super().__new__(cls) + + def __init__(self, ctxs: Iterable[Optional[ssl.SSLContext]]): + super().__init__() + self.sni_callback = selector_sni_callback # type: ignore + self.sanic_select = [] + self.sanic_fallback = None + all_names = [] + for i, ctx in enumerate(ctxs): + if not ctx: + continue + names = dict(getattr(ctx, "sanic", {})).get("names", []) + all_names += names + self.sanic_select.append(ctx) + if i == 0: + self.sanic_fallback = ctx + if not all_names: + raise ValueError( + "No certificates with SubjectAlternativeNames found." + ) + logger.info(f"Certificate vhosts: {', '.join(all_names)}") diff --git a/sanic/http/tls/creators.py b/sanic/http/tls/creators.py new file mode 100644 index 0000000000..2043cfd681 --- /dev/null +++ b/sanic/http/tls/creators.py @@ -0,0 +1,278 @@ +from __future__ import annotations + +import ssl +import subprocess +import sys + +from abc import ABC, abstractmethod +from contextlib import suppress +from pathlib import Path +from tempfile import mkdtemp +from types import ModuleType +from typing import TYPE_CHECKING, Optional, Tuple, Type, Union, cast + +from sanic.application.constants import Mode +from sanic.application.spinner import loading +from sanic.constants import ( + DEFAULT_LOCAL_TLS_CERT, + DEFAULT_LOCAL_TLS_KEY, + LocalCertCreator, +) +from sanic.exceptions import SanicException +from sanic.helpers import Default +from sanic.http.tls.context import CertSimple, SanicSSLContext + + +try: + import trustme + + TRUSTME_INSTALLED = True +except (ImportError, ModuleNotFoundError): + trustme = ModuleType("trustme") + TRUSTME_INSTALLED = False + +if TYPE_CHECKING: + from sanic import Sanic + + +# Only allow secure ciphers, notably leaving out AES-CBC mode +# OpenSSL chooses ECDSA or RSA depending on the cert in use +CIPHERS_TLS12 = [ + "ECDHE-ECDSA-CHACHA20-POLY1305", + "ECDHE-ECDSA-AES256-GCM-SHA384", + "ECDHE-ECDSA-AES128-GCM-SHA256", + "ECDHE-RSA-CHACHA20-POLY1305", + "ECDHE-RSA-AES256-GCM-SHA384", + "ECDHE-RSA-AES128-GCM-SHA256", +] + + +def _make_path(maybe_path: Union[Path, str], tmpdir: Optional[Path]) -> Path: + if isinstance(maybe_path, Path): + return maybe_path + else: + path = Path(maybe_path) + if not path.exists(): + if not tmpdir: + raise RuntimeError("Reached an unknown state. No tmpdir.") + return tmpdir / maybe_path + + return path + + +def get_ssl_context( + app: Sanic, ssl: Optional[ssl.SSLContext] +) -> ssl.SSLContext: + if ssl: + return ssl + + if app.state.mode is Mode.PRODUCTION: + raise SanicException( + "Cannot run Sanic as an HTTPS server in PRODUCTION mode " + "without passing a TLS certificate. If you are developing " + "locally, please enable DEVELOPMENT mode and Sanic will " + "generate a localhost TLS certificate. For more information " + "please see: ___." + ) + + creator = CertCreator.select( + app, + cast(LocalCertCreator, app.config.LOCAL_CERT_CREATOR), + app.config.LOCAL_TLS_KEY, + app.config.LOCAL_TLS_CERT, + ) + context = creator.generate_cert(app.config.LOCALHOST) + return context + + +class CertCreator(ABC): + def __init__(self, app, key, cert) -> None: + self.app = app + self.key = key + self.cert = cert + self.tmpdir = None + + if isinstance(self.key, Default) or isinstance(self.cert, Default): + self.tmpdir = Path(mkdtemp()) + + key = ( + DEFAULT_LOCAL_TLS_KEY + if isinstance(self.key, Default) + else self.key + ) + cert = ( + DEFAULT_LOCAL_TLS_CERT + if isinstance(self.cert, Default) + else self.cert + ) + + self.key_path = _make_path(key, self.tmpdir) + self.cert_path = _make_path(cert, self.tmpdir) + + @abstractmethod + def check_supported(self) -> None: # no cov + ... + + @abstractmethod + def generate_cert(self, localhost: str) -> ssl.SSLContext: # no cov + ... + + @classmethod + def select( + cls, + app: Sanic, + cert_creator: LocalCertCreator, + local_tls_key, + local_tls_cert, + ) -> CertCreator: + + creator: Optional[CertCreator] = None + + cert_creator_options: Tuple[ + Tuple[Type[CertCreator], LocalCertCreator], ... + ] = ( + (MkcertCreator, LocalCertCreator.MKCERT), + (TrustmeCreator, LocalCertCreator.TRUSTME), + ) + for creator_class, local_creator in cert_creator_options: + creator = cls._try_select( + app, + creator, + creator_class, + local_creator, + cert_creator, + local_tls_key, + local_tls_cert, + ) + if creator: + break + + if not creator: + raise SanicException( + "Sanic could not find package to create a TLS certificate. " + "You must have either mkcert or trustme installed. See " + "_____ for more details." + ) + + return creator + + @staticmethod + def _try_select( + app: Sanic, + creator: Optional[CertCreator], + creator_class: Type[CertCreator], + creator_requirement: LocalCertCreator, + creator_requested: LocalCertCreator, + local_tls_key, + local_tls_cert, + ): + if creator or ( + creator_requested is not LocalCertCreator.AUTO + and creator_requested is not creator_requirement + ): + return creator + + instance = creator_class(app, local_tls_key, local_tls_cert) + try: + instance.check_supported() + except SanicException: + if creator_requested is creator_requirement: + raise + else: + return None + + return instance + + +class MkcertCreator(CertCreator): + def check_supported(self) -> None: + try: + subprocess.run( # nosec B603 B607 + ["mkcert", "-help"], + check=True, + stderr=subprocess.DEVNULL, + stdout=subprocess.DEVNULL, + ) + except Exception as e: + raise SanicException( + "Sanic is attempting to use mkcert to generate local TLS " + "certificates since you did not supply a certificate, but " + "one is required. Sanic cannot proceed since mkcert does not " + "appear to be installed. Alternatively, you can use trustme. " + "Please install mkcert, trustme, or supply TLS certificates " + "to proceed. Installation instructions can be found here: " + "https://github.com/FiloSottile/mkcert.\n" + "Find out more information about your options here: " + "_____" + ) from e + + def generate_cert(self, localhost: str) -> ssl.SSLContext: + try: + if not self.cert_path.exists(): + message = "Generating TLS certificate" + # TODO: Validate input for security + with loading(message): + cmd = [ + "mkcert", + "-key-file", + str(self.key_path), + "-cert-file", + str(self.cert_path), + localhost, + ] + resp = subprocess.run( # nosec B603 + cmd, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + sys.stdout.write("\r" + " " * (len(message) + 4)) + sys.stdout.flush() + sys.stdout.write(resp.stdout) + finally: + + @self.app.main_process_stop + async def cleanup(*_): # no cov + if self.tmpdir: + with suppress(FileNotFoundError): + self.key_path.unlink() + self.cert_path.unlink() + self.tmpdir.rmdir() + + return CertSimple(self.cert_path, self.key_path) + + +class TrustmeCreator(CertCreator): + def check_supported(self) -> None: + if not TRUSTME_INSTALLED: + raise SanicException( + "Sanic is attempting to use trustme to generate local TLS " + "certificates since you did not supply a certificate, but " + "one is required. Sanic cannot proceed since trustme does not " + "appear to be installed. Alternatively, you can use mkcert. " + "Please install mkcert, trustme, or supply TLS certificates " + "to proceed. Installation instructions can be found here: " + "https://github.com/python-trio/trustme.\n" + "Find out more information about your options here: " + "_____" + ) + + def generate_cert(self, localhost: str) -> ssl.SSLContext: + context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + sanic_context = SanicSSLContext.create_from_ssl_context(context) + sanic_context.sanic = { + "cert": self.cert_path.absolute(), + "key": self.key_path.absolute(), + } + ca = trustme.CA() + server_cert = ca.issue_cert(localhost) + server_cert.configure_cert(sanic_context) + ca.configure_trust(context) + + ca.cert_pem.write_to_path(str(self.cert_path.absolute())) + server_cert.private_key_and_cert_chain_pem.write_to_path( + str(self.key_path.absolute()) + ) + + return context diff --git a/sanic/mixins/routes.py b/sanic/mixins/routes.py index ca390abe8e..5704c600db 100644 --- a/sanic/mixins/routes.py +++ b/sanic/mixins/routes.py @@ -21,7 +21,7 @@ ) from urllib.parse import unquote -from sanic_routing.route import Route # type: ignore +from sanic_routing.route import Route from sanic.base.meta import SanicMeta from sanic.compat import stat_async diff --git a/sanic/mixins/runner.py b/sanic/mixins/runner.py index d2d1e66ceb..ee787776e6 100644 --- a/sanic/mixins/runner.py +++ b/sanic/mixins/runner.py @@ -2,6 +2,7 @@ import os import platform +import sys from asyncio import ( AbstractEventLoop, @@ -18,7 +19,18 @@ from pathlib import Path from socket import socket from ssl import SSLContext -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Set, + Tuple, + Type, + Union, + cast, +) from sanic import reloader_helpers from sanic.application.logo import get_logo @@ -27,7 +39,9 @@ from sanic.base.meta import SanicMeta from sanic.compat import OS_IS_WINDOWS, is_atty from sanic.helpers import _default -from sanic.log import Colors, error_logger, logger +from sanic.http.constants import HTTP +from sanic.http.tls import get_ssl_context, process_to_context +from sanic.log import Colors, deprecation, error_logger, logger from sanic.models.handler_types import ListenerType from sanic.server import Signal as ServerSignal from sanic.server import try_use_uvloop @@ -36,16 +50,22 @@ from sanic.server.protocols.http_protocol import HttpProtocol from sanic.server.protocols.websocket_protocol import WebSocketProtocol from sanic.server.runners import serve, serve_multiple, serve_single -from sanic.tls import process_to_context -if TYPE_CHECKING: # no cov +if TYPE_CHECKING: from sanic import Sanic from sanic.application.state import ApplicationState from sanic.config import Config SANIC_PACKAGES = ("sanic-routing", "sanic-testing", "sanic-ext") +if sys.version_info < (3, 8): + HTTPVersion = Union[HTTP, int] +else: + from typing import Literal + + HTTPVersion = Union[HTTP, Literal[1], Literal[3]] + class RunnerMixin(metaclass=SanicMeta): _app_registry: Dict[str, Sanic] @@ -66,6 +86,7 @@ def run( dev: bool = False, debug: bool = False, auto_reload: Optional[bool] = None, + version: HTTPVersion = HTTP.VERSION_1, ssl: Union[None, SSLContext, dict, str, list, tuple] = None, sock: Optional[socket] = None, workers: int = 1, @@ -81,6 +102,7 @@ def run( fast: bool = False, verbosity: int = 0, motd_display: Optional[Dict[str, str]] = None, + auto_tls: bool = False, ) -> None: """ Run the HTTP Server and listen until keyboard interrupt or term @@ -124,6 +146,7 @@ def run( dev=dev, debug=debug, auto_reload=auto_reload, + version=version, ssl=ssl, sock=sock, workers=workers, @@ -139,6 +162,7 @@ def run( fast=fast, verbosity=verbosity, motd_display=motd_display, + auto_tls=auto_tls, ) self.__class__.serve(primary=self) # type: ignore @@ -151,6 +175,7 @@ def prepare( dev: bool = False, debug: bool = False, auto_reload: Optional[bool] = None, + version: HTTPVersion = HTTP.VERSION_1, ssl: Union[None, SSLContext, dict, str, list, tuple] = None, sock: Optional[socket] = None, workers: int = 1, @@ -166,7 +191,15 @@ def prepare( fast: bool = False, verbosity: int = 0, motd_display: Optional[Dict[str, str]] = None, + auto_tls: bool = False, ) -> None: + if version == 3 and self.state.server_info: + raise RuntimeError( + "Serving HTTP/3 instances as a secondary server is " + "not supported. There can only be a single HTTP/3 worker " + "and it must be the first instance prepared." + ) + if dev: debug = True auto_reload = True @@ -208,7 +241,7 @@ def prepare( return if sock is None: - host, port = host or "127.0.0.1", port or 8000 + host, port = self.get_address(host, port, version, auto_tls) if protocol is None: protocol = ( @@ -236,6 +269,7 @@ def prepare( host=host, port=port, debug=debug, + version=version, ssl=ssl, sock=sock, unix=unix, @@ -243,6 +277,7 @@ def prepare( protocol=protocol, backlog=backlog, register_sys_signals=register_sys_signals, + auto_tls=auto_tls, ) self.state.server_info.append( ApplicationServerInfo(settings=server_settings) @@ -312,7 +347,7 @@ async def create_server( """ if sock is None: - host, port = host or "127.0.0.1", port or 8000 + host, port = host, port = self.get_address(host, port) if protocol is None: protocol = ( @@ -377,6 +412,7 @@ def _helper( host: Optional[str] = None, port: Optional[int] = None, debug: bool = False, + version: HTTPVersion = HTTP.VERSION_1, ssl: Union[None, SSLContext, dict, str, list, tuple] = None, sock: Optional[socket] = None, unix: Optional[str] = None, @@ -386,6 +422,7 @@ def _helper( backlog: int = 100, register_sys_signals: bool = True, run_async: bool = False, + auto_tls: bool = False, ) -> Dict[str, Any]: """Helper function used by `run` and `create_server`.""" if self.config.PROXIES_COUNT and self.config.PROXIES_COUNT < 0: @@ -395,11 +432,18 @@ def _helper( "#proxy-configuration" ) - ssl = process_to_context(ssl) - if not self.state.is_debug: self.state.mode = Mode.DEBUG if debug else Mode.PRODUCTION + if isinstance(version, int): + version = HTTP(version) + + ssl = process_to_context(ssl) + if version is HTTP.VERSION_3 or auto_tls: + if TYPE_CHECKING: + self = cast(Sanic, self) + ssl = get_ssl_context(self, ssl) + self.state.host = host or "" self.state.port = port or 0 self.state.workers = workers @@ -411,6 +455,7 @@ def _helper( "protocol": protocol, "host": host, "port": port, + "version": version, "sock": sock, "unix": unix, "ssl": ssl, @@ -421,7 +466,7 @@ def _helper( "backlog": backlog, } - self.motd(self.serve_location) + self.motd(server_settings=server_settings) if is_atty() and not self.state.is_debug: error_logger.warning( @@ -447,7 +492,19 @@ def _helper( return server_settings - def motd(self, serve_location): + def motd( + self, + serve_location: str = "", + server_settings: Optional[Dict[str, Any]] = None, + ): + if serve_location: + deprecation( + "Specifying a serve_location in the MOTD is deprecated and " + "will be removed.", + 22.9, + ) + else: + serve_location = self.get_server_location(server_settings) if self.config.MOTD: mode = [f"{self.state.mode},"] if self.state.fast: @@ -460,9 +517,19 @@ def motd(self, serve_location): else: mode.append(f"w/ {self.state.workers} workers") + if server_settings: + server = ", ".join( + ( + self.state.server, + server_settings["version"].display(), # type: ignore + ) + ) + else: + server = "" + display = { "mode": " ".join(mode), - "server": self.state.server, + "server": server, "python": platform.python_version(), "platform": platform.platform(), } @@ -486,7 +553,9 @@ def motd(self, serve_location): module_name = package_name.replace("-", "_") try: module = import_module(module_name) - packages.append(f"{package_name}=={module.__version__}") + packages.append( + f"{package_name}=={module.__version__}" # type: ignore + ) except ImportError: ... @@ -506,25 +575,50 @@ def motd(self, serve_location): @property def serve_location(self) -> str: + server_settings = self.state.server_info[0].settings + return self.get_server_location(server_settings) + + @staticmethod + def get_server_location( + server_settings: Optional[Dict[str, Any]] = None + ) -> str: serve_location = "" proto = "http" - if self.state.ssl is not None: + if not server_settings: + return serve_location + + if server_settings["ssl"] is not None: proto = "https" - if self.state.unix: - serve_location = f"{self.state.unix} {proto}://..." - elif self.state.sock: - serve_location = f"{self.state.sock.getsockname()} {proto}://..." - elif self.state.host and self.state.port: + if server_settings["unix"]: + serve_location = f'{server_settings["unix"]} {proto}://...' + elif server_settings["sock"]: + serve_location = ( + f'{server_settings["sock"].getsockname()} {proto}://...' + ) + elif server_settings["host"] and server_settings["port"]: # colon(:) is legal for a host only in an ipv6 address display_host = ( - f"[{self.state.host}]" - if ":" in self.state.host - else self.state.host + f'[{server_settings["host"]}]' + if ":" in server_settings["host"] + else server_settings["host"] + ) + serve_location = ( + f'{proto}://{display_host}:{server_settings["port"]}' ) - serve_location = f"{proto}://{display_host}:{self.state.port}" return serve_location + @staticmethod + def get_address( + host: Optional[str], + port: Optional[int], + version: HTTPVersion = HTTP.VERSION_1, + auto_tls: bool = False, + ) -> Tuple[str, int]: + host = host or "127.0.0.1" + port = port or (8443 if (version == 3 or auto_tls) else 8000) + return host, port + @classmethod def should_auto_reload(cls) -> bool: return any(app.state.auto_reload for app in cls._app_registry.values()) diff --git a/sanic/models/asgi.py b/sanic/models/asgi.py index 2b0ee0ed47..df6ab3d2ee 100644 --- a/sanic/models/asgi.py +++ b/sanic/models/asgi.py @@ -4,6 +4,7 @@ from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union from sanic.exceptions import BadRequest +from sanic.models.protocol_types import TransportProtocol from sanic.server.websockets.connection import WebSocketConnection @@ -56,7 +57,7 @@ async def drain(self) -> None: await self._not_paused.wait() -class MockTransport: # no cov +class MockTransport(TransportProtocol): # no cov _protocol: Optional[MockProtocol] def __init__( @@ -68,17 +69,19 @@ def __init__( self._protocol = None self.loop = None - def get_protocol(self) -> MockProtocol: + def get_protocol(self) -> MockProtocol: # type: ignore if not self._protocol: self._protocol = MockProtocol(self, self.loop) return self._protocol - def get_extra_info(self, info: str) -> Union[str, bool, None]: + def get_extra_info( + self, info: str, default=None + ) -> Optional[Union[str, bool]]: if info == "peername": return self.scope.get("client") elif info == "sslcontext": return self.scope.get("scheme") in ["https", "wss"] - return None + return default def get_websocket_connection(self) -> WebSocketConnection: try: diff --git a/sanic/models/protocol_types.py b/sanic/models/protocol_types.py index 14bc275cbf..24b4361dae 100644 --- a/sanic/models/protocol_types.py +++ b/sanic/models/protocol_types.py @@ -1,32 +1,22 @@ -import sys +from __future__ import annotations -from typing import Any, AnyStr, TypeVar, Union +import sys -from sanic.models.asgi import ASGIScope +from asyncio import BaseTransport +from typing import TYPE_CHECKING, Any, AnyStr -if sys.version_info < (3, 8): - from asyncio import BaseTransport +if TYPE_CHECKING: + from sanic.models.asgi import ASGIScope - # from sanic.models.asgi import MockTransport - MockTransport = TypeVar("MockTransport") - TransportProtocol = Union[MockTransport, BaseTransport] +if sys.version_info < (3, 8): Range = Any HTMLProtocol = Any else: # Protocol is a 3.8+ feature from typing import Protocol - class TransportProtocol(Protocol): - scope: ASGIScope - - def get_protocol(self): - ... - - def get_extra_info(self, info: str) -> Union[str, bool, None]: - ... - class HTMLProtocol(Protocol): def __html__(self) -> AnyStr: ... @@ -46,3 +36,8 @@ def size(self) -> int: def total(self) -> int: ... + + +class TransportProtocol(BaseTransport): + scope: ASGIScope + __slots__ = () diff --git a/sanic/models/server_types.py b/sanic/models/server_types.py index ba9f2918d9..da88a8ff95 100644 --- a/sanic/models/server_types.py +++ b/sanic/models/server_types.py @@ -1,8 +1,8 @@ from __future__ import annotations -from ssl import SSLObject +from ssl import SSLContext, SSLObject from types import SimpleNamespace -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from sanic.models.protocol_types import TransportProtocol @@ -28,6 +28,7 @@ class ConnInfo: "sockname", "ssl", "cert", + "network_paths", ) def __init__(self, transport: TransportProtocol, unix=None): @@ -40,17 +41,22 @@ def __init__(self, transport: TransportProtocol, unix=None): self.ssl = False self.server_name = "" self.cert: Dict[str, Any] = {} + self.network_paths: List[Any] = [] sslobj: Optional[SSLObject] = transport.get_extra_info( "ssl_object" ) # type: ignore + sslctx: Optional[SSLContext] = transport.get_extra_info( + "ssl_context" + ) # type: ignore if sslobj: self.ssl = True self.server_name = getattr(sslobj, "sanic_server_name", None) or "" self.cert = dict(getattr(sslobj.context, "sanic", {})) + if sslctx and not self.cert: + self.cert = dict(getattr(sslctx, "sanic", {})) if isinstance(addr, str): # UNIX socket self.server = unix or addr return - # IPv4 (ip, port) or IPv6 (ip, port, flowinfo, scopeid) if isinstance(addr, tuple): self.server = addr[0] if len(addr) == 2 else f"[{addr[0]}]" @@ -59,6 +65,9 @@ def __init__(self, transport: TransportProtocol, unix=None): if addr[1] != (443 if self.ssl else 80): self.server = f"{self.server}:{addr[1]}" self.peername = addr = transport.get_extra_info("peername") + self.network_paths = transport.get_extra_info( # type: ignore + "network_paths" + ) if isinstance(addr, tuple): self.client = addr[0] if len(addr) == 2 else f"[{addr[0]}]" diff --git a/sanic/request.py b/sanic/request.py index f55283c3e7..c3ef504814 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -1,6 +1,7 @@ from __future__ import annotations from contextvars import ContextVar +from inspect import isawaitable from typing import ( TYPE_CHECKING, Any, @@ -13,13 +14,15 @@ Union, ) -from sanic_routing.route import Route # type: ignore +from sanic_routing.route import Route +from sanic.http.constants import HTTP # type: ignore +from sanic.http.stream import Stream from sanic.models.asgi import ASGIScope from sanic.models.http_types import Credentials -if TYPE_CHECKING: # no cov +if TYPE_CHECKING: from sanic.server import ConnInfo from sanic.app import Sanic @@ -47,7 +50,7 @@ parse_host, parse_xforwarded, ) -from sanic.http import Http, Stage +from sanic.http import Stage from sanic.log import error_logger, logger from sanic.models.protocol_types import TransportProtocol from sanic.response import BaseHTTPResponse, HTTPResponse @@ -94,7 +97,9 @@ class Request: "_port", "_protocol", "_remote_addr", + "_scheme", "_socket", + "_stream_id", "_match_info", "_name", "app", @@ -131,6 +136,7 @@ def __init__( transport: TransportProtocol, app: Sanic, head: bytes = b"", + stream_id: int = 0, ): self.raw_url = url_bytes @@ -140,6 +146,7 @@ def __init__( raise BadURL(f"Bad URL: {url_bytes.decode()}") self._id: Optional[Union[uuid.UUID, str, int]] = None self._name: Optional[str] = None + self._stream_id = stream_id self.app = app self.headers = Header(headers) @@ -166,12 +173,12 @@ def __init__( Tuple[bool, bool, str, str], List[Tuple[str, str]] ] = defaultdict(list) self.request_middleware_started = False + self.responded: bool = False + self.route: Optional[Route] = None + self.stream: Optional[Stream] = None self._cookies: Optional[Dict[str, str]] = None self._match_info: Dict[str, Any] = {} - self.stream: Optional[Http] = None - self.route: Optional[Route] = None self._protocol = None - self.responded: bool = False def __repr__(self): class_name = self.__class__.__name__ @@ -188,6 +195,14 @@ def get_current(cls) -> Request: def generate_id(*_): return uuid.uuid4() + @property + def stream_id(self): + if self.protocol.version is not HTTP.VERSION_3: + raise ServerError( + "Stream ID is only a property of a HTTP/3 request" + ) + return self._stream_id + def reset_response(self): try: if ( @@ -274,6 +289,9 @@ async def add_header(_, response: HTTPResponse): # Connect the response if isinstance(response, BaseHTTPResponse) and self.stream: response = self.stream.respond(response) + + if isawaitable(response): + response = await response # type: ignore # Run response middleware try: response = await self.app._run_response_middleware( @@ -668,6 +686,10 @@ def path(self) -> str: """ return self._parsed_url.path.decode("utf-8") + @property + def network_paths(self): + return self.conn_info.network_paths + # Proxy properties (using SERVER_NAME/forwarded/request/transport info) @property @@ -721,23 +743,25 @@ def scheme(self) -> str: :return: http|https|ws|wss or arbitrary value given by the headers. :rtype: str """ - if "//" in self.app.config.get("SERVER_NAME", ""): - return self.app.config.SERVER_NAME.split("//")[0] - if "proto" in self.forwarded: - return str(self.forwarded["proto"]) + if not hasattr(self, "_scheme"): + if "//" in self.app.config.get("SERVER_NAME", ""): + return self.app.config.SERVER_NAME.split("//")[0] + if "proto" in self.forwarded: + return str(self.forwarded["proto"]) - if ( - self.app.websocket_enabled - and self.headers.getone("upgrade", "").lower() == "websocket" - ): - scheme = "ws" - else: - scheme = "http" + if ( + self.app.websocket_enabled + and self.headers.getone("upgrade", "").lower() == "websocket" + ): + scheme = "ws" + else: + scheme = "http" - if self.transport.get_extra_info("sslcontext"): - scheme += "s" + if self.transport.get_extra_info("sslcontext"): + scheme += "s" + self._scheme = scheme - return scheme + return self._scheme @property def host(self) -> str: diff --git a/sanic/response.py b/sanic/response.py index b1765ed141..adb7a5b615 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -38,6 +38,7 @@ if TYPE_CHECKING: from sanic.asgi import ASGIApp + from sanic.http.http3 import HTTPReceiver from sanic.request import Request else: Request = TypeVar("Request") @@ -74,11 +75,15 @@ def __init__(self): self.asgi: bool = False self.body: Optional[bytes] = None self.content_type: Optional[str] = None - self.stream: Optional[Union[Http, ASGIApp]] = None + self.stream: Optional[Union[Http, ASGIApp, HTTPReceiver]] = None self.status: int = None self.headers = Header({}) self._cookies: Optional[CookieJar] = None + def __repr__(self): + class_name = self.__class__.__name__ + return f"<{class_name}: {self.status} {self.content_type}>" + def _encode_body(self, data: Optional[AnyStr]): if data is None: return b"" @@ -157,7 +162,10 @@ async def send( if hasattr(data, "encode") else data or b"" ) - await self.stream.send(data, end_stream=end_stream) + await self.stream.send( + data, # type: ignore + end_stream=end_stream or False, + ) class HTTPResponse(BaseHTTPResponse): diff --git a/sanic/router.py b/sanic/router.py index 01c268b9a9..ec4d852f5d 100644 --- a/sanic/router.py +++ b/sanic/router.py @@ -5,12 +5,10 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from uuid import UUID -from sanic_routing import BaseRouter # type: ignore -from sanic_routing.exceptions import NoMethod # type: ignore -from sanic_routing.exceptions import ( - NotFound as RoutingNotFound, # type: ignore -) -from sanic_routing.route import Route # type: ignore +from sanic_routing import BaseRouter +from sanic_routing.exceptions import NoMethod +from sanic_routing.exceptions import NotFound as RoutingNotFound +from sanic_routing.route import Route from sanic.constants import HTTP_METHODS from sanic.errorpages import check_error_format diff --git a/sanic/server/events.py b/sanic/server/events.py index 41a89aea1d..ae93c78e9c 100644 --- a/sanic/server/events.py +++ b/sanic/server/events.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional -if TYPE_CHECKING: # no cov +if TYPE_CHECKING: from sanic import Sanic diff --git a/sanic/server/protocols/base_protocol.py b/sanic/server/protocols/base_protocol.py index 3a2716698f..63d4bfb5b7 100644 --- a/sanic/server/protocols/base_protocol.py +++ b/sanic/server/protocols/base_protocol.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Optional -if TYPE_CHECKING: # no cov +if TYPE_CHECKING: from sanic.app import Sanic import asyncio diff --git a/sanic/server/protocols/http_protocol.py b/sanic/server/protocols/http_protocol.py index d8cfc51122..b3d7625b06 100644 --- a/sanic/server/protocols/http_protocol.py +++ b/sanic/server/protocols/http_protocol.py @@ -2,10 +2,14 @@ from typing import TYPE_CHECKING, Optional +from aioquic.h3.connection import H3_ALPN, H3Connection + +from sanic.http.constants import HTTP +from sanic.http.http3 import Http3 from sanic.touchup.meta import TouchUpMeta -if TYPE_CHECKING: # no cov +if TYPE_CHECKING: from sanic.app import Sanic import sys @@ -13,24 +17,68 @@ from asyncio import CancelledError from time import monotonic as current_time +from aioquic.asyncio import QuicConnectionProtocol +from aioquic.quic.events import ( + DatagramFrameReceived, + ProtocolNegotiated, + QuicEvent, +) + from sanic.exceptions import RequestTimeout, ServiceUnavailable from sanic.http import Http, Stage -from sanic.log import error_logger, logger +from sanic.log import Colors, error_logger, logger from sanic.models.server_types import ConnInfo from sanic.request import Request from sanic.server.protocols.base_protocol import SanicProtocol -class HttpProtocol(SanicProtocol, metaclass=TouchUpMeta): +class HttpProtocolMixin: + __slots__ = () + __version__: HTTP + + def _setup_connection(self, *args, **kwargs): + self._http = self.HTTP_CLASS(self, *args, **kwargs) + self._time = current_time() + try: + self.check_timeouts() + except AttributeError: + ... + + def _setup(self): + self.request: Optional[Request] = None + self.access_log = self.app.config.ACCESS_LOG + self.request_handler = self.app.handle_request + self.error_handler = self.app.error_handler + self.request_timeout = self.app.config.REQUEST_TIMEOUT + self.response_timeout = self.app.config.RESPONSE_TIMEOUT + self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT + self.request_max_size = self.app.config.REQUEST_MAX_SIZE + self.request_class = self.app.request_class or Request + + @property + def http(self): + if not hasattr(self, "_http"): + return None + return self._http + + @property + def version(self): + return self.__class__.__version__ + + +class HttpProtocol(HttpProtocolMixin, SanicProtocol, metaclass=TouchUpMeta): """ This class provides implements the HTTP 1.1 protocol on top of our Sanic Server transport """ + HTTP_CLASS = Http + __touchup__ = ( "send", "connection_task", ) + __version__ = HTTP.VERSION_1 __slots__ = ( # request params "request", @@ -72,25 +120,12 @@ def __init__( unix=unix, ) self.url = None - self.request: Optional[Request] = None - self.access_log = self.app.config.ACCESS_LOG - self.request_handler = self.app.handle_request - self.error_handler = self.app.error_handler - self.request_timeout = self.app.config.REQUEST_TIMEOUT - self.response_timeout = self.app.config.RESPONSE_TIMEOUT - self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT - self.request_max_size = self.app.config.REQUEST_MAX_SIZE - self.request_class = self.app.request_class or Request self.state = state if state else {} + self._setup() if "requests_count" not in self.state: self.state["requests_count"] = 0 self._exception = None - def _setup_connection(self): - self._http = Http(self) - self._time = current_time() - self.check_timeouts() - async def connection_task(self): # no cov """ Run a HTTP connection. @@ -241,3 +276,39 @@ def data_received(self, data: bytes): self._data_received.set() except Exception: error_logger.exception("protocol.data_received") + + +class Http3Protocol(HttpProtocolMixin, QuicConnectionProtocol): + HTTP_CLASS = Http3 + __version__ = HTTP.VERSION_3 + + def __init__(self, *args, app: Sanic, **kwargs) -> None: + self.app = app + super().__init__(*args, **kwargs) + self._setup() + self._connection: Optional[H3Connection] = None + + def quic_event_received(self, event: QuicEvent) -> None: + logger.debug( + f"{Colors.BLUE}[quic_event_received]: " + f"{Colors.PURPLE}{event}{Colors.END}", + extra={"verbosity": 2}, + ) + if isinstance(event, ProtocolNegotiated): + self._setup_connection(transmit=self.transmit) + if event.alpn_protocol in H3_ALPN: + self._connection = H3Connection( + self._quic, enable_webtransport=True + ) + elif isinstance(event, DatagramFrameReceived): + if event.data == b"quack": + self._quic.send_datagram_frame(b"quack-ack") + + # pass event to the HTTP layer + if self._connection is not None: + for http_event in self._connection.handle_event(event): + self._http.http_event_received(http_event) + + @property + def connection(self) -> Optional[H3Connection]: + return self._connection diff --git a/sanic/server/protocols/websocket_protocol.py b/sanic/server/protocols/websocket_protocol.py index 866f52460f..e55ebdd610 100644 --- a/sanic/server/protocols/websocket_protocol.py +++ b/sanic/server/protocols/websocket_protocol.py @@ -11,7 +11,7 @@ from ..websockets.impl import WebsocketImplProtocol -if TYPE_CHECKING: # no cov +if TYPE_CHECKING: from websockets import http11 diff --git a/sanic/server/runners.py b/sanic/server/runners.py index 53fb3cfe6f..81c8b64a44 100644 --- a/sanic/server/runners.py +++ b/sanic/server/runners.py @@ -6,6 +6,8 @@ from typing import TYPE_CHECKING, Dict, Optional, Type, Union from sanic.config import Config +from sanic.http.constants import HTTP +from sanic.http.tls import get_ssl_context from sanic.server.events import trigger_events @@ -21,12 +23,15 @@ from signal import SIG_IGN, SIGINT, SIGTERM, Signals from signal import signal as signal_func +from aioquic.asyncio import serve as quic_serve + from sanic.application.ext import setup_ext from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows +from sanic.http.http3 import SessionTicketStore, get_config from sanic.log import error_logger, logger from sanic.models.server_types import Signal from sanic.server.async_server import AsyncioServer -from sanic.server.protocols.http_protocol import HttpProtocol +from sanic.server.protocols.http_protocol import Http3Protocol, HttpProtocol from sanic.server.socket import ( bind_socket, bind_unix_socket, @@ -52,6 +57,7 @@ def serve( signal=Signal(), state=None, asyncio_server_kwargs=None, + version=HTTP.VERSION_1, ): """Start asynchronous HTTP Server on an individual process. @@ -88,6 +94,87 @@ def serve( app.asgi = False + if version is HTTP.VERSION_3: + return _serve_http_3(host, port, app, loop, ssl) + return _serve_http_1( + host, + port, + app, + ssl, + sock, + unix, + reuse_port, + loop, + protocol, + backlog, + register_sys_signals, + run_multiple, + run_async, + connections, + signal, + state, + asyncio_server_kwargs, + ) + + +def _setup_system_signals( + app: Sanic, + run_multiple: bool, + register_sys_signals: bool, + loop: asyncio.AbstractEventLoop, +) -> None: + # Ignore SIGINT when run_multiple + if run_multiple: + signal_func(SIGINT, SIG_IGN) + os.environ["SANIC_WORKER_PROCESS"] = "true" + + # Register signals for graceful termination + if register_sys_signals: + if OS_IS_WINDOWS: + ctrlc_workaround_for_windows(app) + else: + for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]: + loop.add_signal_handler(_signal, app.stop) + + +def _run_server_forever(loop, before_stop, after_stop, cleanup, unix): + pid = os.getpid() + try: + logger.info("Starting worker [%s]", pid) + loop.run_forever() + except KeyboardInterrupt: + pass + finally: + logger.info("Stopping worker [%s]", pid) + + loop.run_until_complete(before_stop()) + + if cleanup: + cleanup() + + loop.run_until_complete(after_stop()) + remove_unix_socket(unix) + + +def _serve_http_1( + host, + port, + app, + ssl, + sock, + unix, + reuse_port, + loop, + protocol, + backlog, + register_sys_signals, + run_multiple, + run_async, + connections, + signal, + state, + asyncio_server_kwargs, +): connections = connections if connections is not None else set() protocol_kwargs = _build_protocol_kwargs(protocol, app.config) server = partial( @@ -135,30 +222,7 @@ def serve( error_logger.exception("Unable to start server", exc_info=True) return - # Ignore SIGINT when run_multiple - if run_multiple: - signal_func(SIGINT, SIG_IGN) - os.environ["SANIC_WORKER_PROCESS"] = "true" - - # Register signals for graceful termination - if register_sys_signals: - if OS_IS_WINDOWS: - ctrlc_workaround_for_windows(app) - else: - for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]: - loop.add_signal_handler(_signal, app.stop) - - loop.run_until_complete(app._server_event("init", "after")) - pid = os.getpid() - try: - logger.info("Starting worker [%s]", pid) - loop.run_forever() - finally: - logger.info("Stopping worker [%s]", pid) - - # Run the on_stop function if provided - loop.run_until_complete(app._server_event("shutdown", "before")) - + def _cleanup(): # Wait for event loop to finish and all connections to drain http_server.close() loop.run_until_complete(http_server.wait_closed()) @@ -188,8 +252,51 @@ def serve( conn.websocket.fail_connection(code=1001) else: conn.abort() - loop.run_until_complete(app._server_event("shutdown", "after")) - remove_unix_socket(unix) + + _setup_system_signals(app, run_multiple, register_sys_signals, loop) + loop.run_until_complete(app._server_event("init", "after")) + _run_server_forever( + loop, + partial(app._server_event, "shutdown", "before"), + partial(app._server_event, "shutdown", "after"), + _cleanup, + unix, + ) + + +def _serve_http_3( + host, + port, + app, + loop, + ssl, + register_sys_signals: bool = True, + run_multiple: bool = False, +): + protocol = partial(Http3Protocol, app=app) + ticket_store = SessionTicketStore() + ssl_context = get_ssl_context(app, ssl) + config = get_config(app, ssl_context) + coro = quic_serve( + host, + port, + configuration=config, + create_protocol=protocol, + session_ticket_fetcher=ticket_store.pop, + session_ticket_handler=ticket_store.add, + ) + server = AsyncioServer(app, loop, coro, []) + loop.run_until_complete(server.startup()) + loop.run_until_complete(server.before_start()) + loop.run_until_complete(server) + _setup_system_signals(app, run_multiple, register_sys_signals, loop) + loop.run_until_complete(server.after_start()) + + # TODO: Create connection cleanup and graceful shutdown + cleanup = None + _run_server_forever( + loop, server.before_stop, server.after_stop, cleanup, None + ) def serve_single(server_settings): diff --git a/sanic/server/websockets/frame.py b/sanic/server/websockets/frame.py index e4972516a3..b31e93c115 100644 --- a/sanic/server/websockets/frame.py +++ b/sanic/server/websockets/frame.py @@ -9,7 +9,7 @@ from sanic.exceptions import ServerError -if TYPE_CHECKING: # no cov +if TYPE_CHECKING: from .impl import WebsocketImplProtocol UTF8Decoder = codecs.getincrementaldecoder("utf-8") @@ -37,7 +37,7 @@ class WebsocketFrameAssembler: "get_id", "put_id", ) - if TYPE_CHECKING: # no cov + if TYPE_CHECKING: protocol: "WebsocketImplProtocol" read_mutex: asyncio.Lock write_mutex: asyncio.Lock diff --git a/sanic/signals.py b/sanic/signals.py index d62a117c52..80c6300b73 100644 --- a/sanic/signals.py +++ b/sanic/signals.py @@ -6,9 +6,9 @@ from inspect import isawaitable from typing import Any, Dict, List, Optional, Tuple, Union, cast -from sanic_routing import BaseRouter, Route, RouteGroup # type: ignore -from sanic_routing.exceptions import NotFound # type: ignore -from sanic_routing.utils import path_to_parts # type: ignore +from sanic_routing import BaseRouter, Route, RouteGroup +from sanic_routing.exceptions import NotFound +from sanic_routing.utils import path_to_parts from sanic.exceptions import InvalidSignal from sanic.log import error_logger, logger diff --git a/sanic/touchup/schemes/__init__.py b/sanic/touchup/schemes/__init__.py index 87057a5fce..dd4145abad 100644 --- a/sanic/touchup/schemes/__init__.py +++ b/sanic/touchup/schemes/__init__.py @@ -1,3 +1,4 @@ +from .altsvc import AltSvcCheck # noqa from .base import BaseScheme from .ode import OptionalDispatchEvent # noqa diff --git a/sanic/touchup/schemes/altsvc.py b/sanic/touchup/schemes/altsvc.py new file mode 100644 index 0000000000..05e7269bbe --- /dev/null +++ b/sanic/touchup/schemes/altsvc.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from ast import Assign, Constant, NodeTransformer, Subscript +from typing import TYPE_CHECKING, Any, List + +from sanic.http.constants import HTTP + +from .base import BaseScheme + + +if TYPE_CHECKING: + from sanic import Sanic + + +class AltSvcCheck(BaseScheme): + ident = "ALTSVC" + + def visitors(self) -> List[NodeTransformer]: + return [RemoveAltSvc(self.app, self.app.state.verbosity)] + + +class RemoveAltSvc(NodeTransformer): + def __init__(self, app: Sanic, verbosity: int = 0) -> None: + self._app = app + self._verbosity = verbosity + self._versions = { + info.settings["version"] for info in app.state.server_info + } + + def visit_Assign(self, node: Assign) -> Any: + if any(self._matches(target) for target in node.targets): + if self._should_remove(): + return None + assert isinstance(node.value, Constant) + node.value.value = self.value() + return node + + def _should_remove(self) -> bool: + return len(self._versions) == 1 + + @staticmethod + def _matches(node) -> bool: + return ( + isinstance(node, Subscript) + and isinstance(node.slice, Constant) + and node.slice.value == "alt-svc" + ) + + def value(self): + values = [] + for info in self._app.state.server_info: + port = info.settings["port"] + version = info.settings["version"] + if version is HTTP.VERSION_3: + values.append(f'h3=":{port}"') + return ", ".join(values) diff --git a/sanic/touchup/schemes/base.py b/sanic/touchup/schemes/base.py index d16619b2f8..9e32c32371 100644 --- a/sanic/touchup/schemes/base.py +++ b/sanic/touchup/schemes/base.py @@ -1,5 +1,8 @@ from abc import ABC, abstractmethod -from typing import Set, Type +from ast import NodeTransformer, parse +from inspect import getsource +from textwrap import dedent +from typing import Any, Dict, List, Set, Type class BaseScheme(ABC): @@ -10,11 +13,26 @@ def __init__(self, app) -> None: self.app = app @abstractmethod - def run(self, method, module_globals) -> None: + def visitors(self) -> List[NodeTransformer]: ... def __init_subclass__(cls): BaseScheme._registry.add(cls) - def __call__(self, method, module_globals): - return self.run(method, module_globals) + def __call__(self): + return self.visitors() + + @classmethod + def build(cls, method, module_globals, app): + raw_source = getsource(method) + src = dedent(raw_source) + node = parse(src) + + for scheme in cls._registry: + for visitor in scheme(app)(): + node = visitor.visit(node) + + compiled_src = compile(node, method.__name__, "exec") + exec_locals: Dict[str, Any] = {} + exec(compiled_src, module_globals, exec_locals) # nosec + return exec_locals[method.__name__] diff --git a/sanic/touchup/schemes/ode.py b/sanic/touchup/schemes/ode.py index 6303ed17cf..ae51df7264 100644 --- a/sanic/touchup/schemes/ode.py +++ b/sanic/touchup/schemes/ode.py @@ -1,7 +1,5 @@ -from ast import Attribute, Await, Dict, Expr, NodeTransformer, parse -from inspect import getsource -from textwrap import dedent -from typing import Any +from ast import Attribute, Await, Expr, NodeTransformer +from typing import Any, List from sanic.log import logger @@ -20,16 +18,8 @@ def __init__(self, app) -> None: signal.name for signal in app.signal_router.routes ] - def run(self, method, module_globals): - raw_source = getsource(method) - src = dedent(raw_source) - tree = parse(src) - node = RemoveDispatch(self._registered_events).visit(tree) - compiled_src = compile(node, method.__name__, "exec") - exec_locals: Dict[str, Any] = {} - exec(compiled_src, module_globals, exec_locals) # nosec - - return exec_locals[method.__name__] + def visitors(self) -> List[NodeTransformer]: + return [RemoveDispatch(self._registered_events)] def _sync_events(self): all_events = set() diff --git a/sanic/touchup/service.py b/sanic/touchup/service.py index 95792dca10..b1b996fb5b 100644 --- a/sanic/touchup/service.py +++ b/sanic/touchup/service.py @@ -21,10 +21,8 @@ def run(cls, app): module = getmodule(target) module_globals = dict(getmembers(module)) - - for scheme in BaseScheme._registry: - modified = scheme(app)(method, module_globals) - setattr(target, method_name, modified) + modified = BaseScheme.build(method, module_globals, app) + setattr(target, method_name, modified) target.__touched__ = True diff --git a/sanic/views.py b/sanic/views.py index 23cd110d22..627f20680e 100644 --- a/sanic/views.py +++ b/sanic/views.py @@ -13,7 +13,7 @@ from sanic.models.handler_types import RouteHandler -if TYPE_CHECKING: # no cov +if TYPE_CHECKING: from sanic import Sanic from sanic.blueprints import Blueprint diff --git a/setup.py b/setup.py index 966006d966..9579ad26b9 100644 --- a/setup.py +++ b/setup.py @@ -149,6 +149,7 @@ def open_local(paths, mode="r", encoding="utf8"): "docs": docs_require, "all": all_require, "ext": ["sanic-ext"], + "http3": ["aioquic"], } setup_kwargs["install_requires"] = requirements diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/asyncmock.py b/tests/asyncmock.py index eec1764664..835012fd8f 100644 --- a/tests/asyncmock.py +++ b/tests/asyncmock.py @@ -25,6 +25,10 @@ async def dummy(): def __await__(self): return self().__await__() + def reset_mock(self, *args, **kwargs): + super().reset_mock(*args, **kwargs) + self.await_count = 0 + def assert_awaited_once(self): if not self.await_count == 1: msg = ( @@ -32,3 +36,13 @@ def assert_awaited_once(self): f" Awaited {self.await_count} times." ) raise AssertionError(msg) + + def assert_awaited_once_with(self, *args, **kwargs): + if not self.await_count == 1: + msg = ( + f"Expected to have been awaited once." + f" Awaited {self.await_count} times." + ) + raise AssertionError(msg) + self.assert_awaited_once() + return self.assert_called_with(*args, **kwargs) diff --git a/tests/client.py b/tests/client.py new file mode 100644 index 0000000000..4c0b29a0fd --- /dev/null +++ b/tests/client.py @@ -0,0 +1,47 @@ +import asyncio + +from textwrap import dedent +from typing import AnyStr + + +class RawClient: + CRLF = b"\r\n" + + def __init__(self, host: str, port: int): + self.reader = None + self.writer = None + self.host = host + self.port = port + + async def connect(self): + self.reader, self.writer = await asyncio.open_connection( + self.host, self.port + ) + + async def close(self): + self.writer.close() + await self.writer.wait_closed() + + async def send(self, message: AnyStr): + if isinstance(message, str): + msg = self._clean(message).encode("utf-8") + else: + msg = message + await self._send(msg) + + async def _send(self, message: bytes): + if not self.writer: + raise Exception("No open write stream") + self.writer.write(message) + + async def recv(self, nbytes: int = -1) -> bytes: + if not self.reader: + raise Exception("No open read stream") + return await self.reader.read(nbytes) + + def _clean(self, message: str) -> str: + return ( + dedent(message) + .lstrip("\n") + .replace("\n", self.CRLF.decode("utf-8")) + ) diff --git a/tests/conftest.py b/tests/conftest.py index fe4ba47d62..b10ed736c7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -150,6 +150,7 @@ def app(request): yield app for target, method_name in TouchUp._registry: setattr(target, method_name, CACHE[method_name]) + Sanic._app_registry.clear() @pytest.fixture(scope="function") diff --git a/tests/http3/__init__.py b/tests/http3/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/http3/test_http_receiver.py b/tests/http3/test_http_receiver.py new file mode 100644 index 0000000000..2784e87b30 --- /dev/null +++ b/tests/http3/test_http_receiver.py @@ -0,0 +1,294 @@ +from unittest.mock import Mock + +import pytest + +from aioquic.h3.connection import H3Connection +from aioquic.h3.events import DataReceived, HeadersReceived +from aioquic.quic.configuration import QuicConfiguration +from aioquic.quic.connection import QuicConnection +from aioquic.quic.events import ProtocolNegotiated + +from sanic import Request, Sanic +from sanic.compat import Header +from sanic.config import DEFAULT_CONFIG +from sanic.exceptions import PayloadTooLarge +from sanic.http.constants import Stage +from sanic.http.http3 import Http3, HTTPReceiver +from sanic.models.server_types import ConnInfo +from sanic.response import empty, json +from sanic.server.protocols.http_protocol import Http3Protocol + + +try: + from unittest.mock import AsyncMock +except ImportError: + from tests.asyncmock import AsyncMock # type: ignore + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture(autouse=True) +async def setup(app: Sanic): + @app.get("/") + async def handler(*_): + return empty() + + app.router.finalize() + app.signal_router.finalize() + app.signal_router.allow_fail_builtin = False + + +@pytest.fixture +def http_request(app): + return Request(b"/", Header({}), "3", "GET", Mock(), app) + + +def generate_protocol(app): + connection = QuicConnection(configuration=QuicConfiguration()) + connection._ack_delay = 0 + connection._loss = Mock() + connection._loss.spaces = [] + connection._loss.get_loss_detection_time = lambda: None + connection.datagrams_to_send = Mock(return_value=[]) # type: ignore + return Http3Protocol( + connection, + app=app, + stream_handler=None, + ) + + +def generate_http_receiver(app, http_request) -> HTTPReceiver: + protocol = generate_protocol(app) + receiver = HTTPReceiver( + protocol.transmit, + protocol, + http_request, + ) + http_request.stream = receiver + return receiver + + +def test_http_receiver_init(app: Sanic, http_request: Request): + receiver = generate_http_receiver(app, http_request) + assert receiver.request_body is None + assert receiver.stage is Stage.IDLE + assert receiver.headers_sent is False + assert receiver.response is None + assert receiver.request_max_size == DEFAULT_CONFIG["REQUEST_MAX_SIZE"] + assert receiver.request_bytes == 0 + + +async def test_http_receiver_run_request(app: Sanic, http_request: Request): + handler = AsyncMock() + + class mock_handle(Sanic): + handle_request = handler + + app.__class__ = mock_handle + receiver = generate_http_receiver(app, http_request) + receiver.protocol.quic_event_received( + ProtocolNegotiated(alpn_protocol="h3") + ) + await receiver.run() + handler.assert_awaited_once_with(receiver.request) + + +async def test_http_receiver_run_exception(app: Sanic, http_request: Request): + handler = AsyncMock() + + class mock_handle(Sanic): + handle_exception = handler + + app.__class__ = mock_handle + receiver = generate_http_receiver(app, http_request) + receiver.protocol.quic_event_received( + ProtocolNegotiated(alpn_protocol="h3") + ) + exception = Exception("Oof") + await receiver.run(exception) + handler.assert_awaited_once_with(receiver.request, exception) + + handler.reset_mock() + receiver.stage = Stage.REQUEST + await receiver.run(exception) + handler.assert_awaited_once_with(receiver.request, exception) + + +def test_http_receiver_respond(app: Sanic, http_request: Request): + receiver = generate_http_receiver(app, http_request) + response = empty() + + receiver.stage = Stage.RESPONSE + with pytest.raises(RuntimeError, match="Response already started"): + receiver.respond(response) + + receiver.stage = Stage.HANDLER + receiver.response = Mock() + resp = receiver.respond(response) + + assert receiver.response is resp + assert resp is response + assert response.stream is receiver + + +def test_http_receiver_receive_body(app: Sanic, http_request: Request): + receiver = generate_http_receiver(app, http_request) + receiver.request_max_size = 4 + + receiver.receive_body(b"..") + assert receiver.request.body == b".." + + receiver.receive_body(b"..") + assert receiver.request.body == b"...." + + with pytest.raises( + PayloadTooLarge, match="Request body exceeds the size limit" + ): + receiver.receive_body(b"..") + + +def test_http3_events(app): + protocol = generate_protocol(app) + http3 = Http3(protocol, protocol.transmit) + http3.http_event_received( + HeadersReceived( + [ + (b":method", b"GET"), + (b":path", b"/location"), + (b":scheme", b"https"), + (b":authority", b"localhost:8443"), + (b"foo", b"bar"), + ], + 1, + False, + ) + ) + http3.http_event_received(DataReceived(b"foobar", 1, False)) + receiver = http3.receivers[1] + + assert len(http3.receivers) == 1 + assert receiver.request.stream_id == 1 + assert receiver.request.path == "/location" + assert receiver.request.method == "GET" + assert receiver.request.headers["foo"] == "bar" + assert receiver.request.body == b"foobar" + + +async def test_send_headers(app: Sanic, http_request: Request): + send_headers_mock = Mock() + existing_send_headers = H3Connection.send_headers + receiver = generate_http_receiver(app, http_request) + receiver.protocol.quic_event_received( + ProtocolNegotiated(alpn_protocol="h3") + ) + + http_request._protocol = receiver.protocol + + def send_headers(*args, **kwargs): + send_headers_mock(*args, **kwargs) + return existing_send_headers( + receiver.protocol.connection, *args, **kwargs + ) + + receiver.protocol.connection.send_headers = send_headers + receiver.head_only = False + response = json({}, status=201, headers={"foo": "bar"}) + + with pytest.raises(RuntimeError, match="no response"): + receiver.send_headers() + + receiver.response = response + receiver.send_headers() + + assert receiver.headers_sent + assert receiver.stage is Stage.RESPONSE + send_headers_mock.assert_called_once_with( + stream_id=0, + headers=[ + (b":status", b"201"), + (b"foo", b"bar"), + (b"content-length", b"2"), + (b"content-type", b"application/json"), + ], + ) + + +def test_multiple_streams(app): + protocol = generate_protocol(app) + http3 = Http3(protocol, protocol.transmit) + http3.http_event_received( + HeadersReceived( + [ + (b":method", b"GET"), + (b":path", b"/location"), + (b":scheme", b"https"), + (b":authority", b"localhost:8443"), + (b"foo", b"bar"), + ], + 1, + False, + ) + ) + http3.http_event_received( + HeadersReceived( + [ + (b":method", b"GET"), + (b":path", b"/location"), + (b":scheme", b"https"), + (b":authority", b"localhost:8443"), + (b"foo", b"bar"), + ], + 2, + False, + ) + ) + + receiver1 = http3.get_receiver_by_stream_id(1) + receiver2 = http3.get_receiver_by_stream_id(2) + assert len(http3.receivers) == 2 + assert isinstance(receiver1, HTTPReceiver) + assert isinstance(receiver2, HTTPReceiver) + assert receiver1 is not receiver2 + + +def test_request_stream_id(app): + protocol = generate_protocol(app) + http3 = Http3(protocol, protocol.transmit) + http3.http_event_received( + HeadersReceived( + [ + (b":method", b"GET"), + (b":path", b"/location"), + (b":scheme", b"https"), + (b":authority", b"localhost:8443"), + (b"foo", b"bar"), + ], + 1, + False, + ) + ) + receiver = http3.get_receiver_by_stream_id(1) + + assert isinstance(receiver.request, Request) + assert receiver.request.stream_id == 1 + + +def test_request_conn_info(app): + protocol = generate_protocol(app) + http3 = Http3(protocol, protocol.transmit) + http3.http_event_received( + HeadersReceived( + [ + (b":method", b"GET"), + (b":path", b"/location"), + (b":scheme", b"https"), + (b":authority", b"localhost:8443"), + (b"foo", b"bar"), + ], + 1, + False, + ) + ) + receiver = http3.get_receiver_by_stream_id(1) + + assert isinstance(receiver.request.conn_info, ConnInfo) diff --git a/tests/http3/test_server.py b/tests/http3/test_server.py new file mode 100644 index 0000000000..bed2446a25 --- /dev/null +++ b/tests/http3/test_server.py @@ -0,0 +1,114 @@ +import logging +import sys + +from asyncio import Event +from pathlib import Path + +import pytest + +from sanic import Sanic +from sanic.compat import UVLOOP_INSTALLED +from sanic.http.constants import HTTP + + +parent_dir = Path(__file__).parent.parent +localhost_dir = parent_dir / "certs/localhost" + + +@pytest.mark.parametrize("version", (3, HTTP.VERSION_3)) +@pytest.mark.skipif( + sys.version_info < (3, 8) and not UVLOOP_INSTALLED, + reason="In 3.7 w/o uvloop the port is not always released", +) +def test_server_starts_http3(app: Sanic, version, caplog): + ev = Event() + + @app.after_server_start + def shutdown(*_): + ev.set() + app.stop() + + with caplog.at_level(logging.INFO): + app.run( + version=version, + ssl={ + "cert": localhost_dir / "fullchain.pem", + "key": localhost_dir / "privkey.pem", + }, + ) + + assert ev.is_set() + assert ( + "sanic.root", + logging.INFO, + "server: sanic, HTTP/3", + ) in caplog.record_tuples + + +@pytest.mark.skipif( + sys.version_info < (3, 8) and not UVLOOP_INSTALLED, + reason="In 3.7 w/o uvloop the port is not always released", +) +def test_server_starts_http1_and_http3(app: Sanic, caplog): + @app.after_server_start + def shutdown(*_): + app.stop() + + app.prepare( + version=3, + ssl={ + "cert": localhost_dir / "fullchain.pem", + "key": localhost_dir / "privkey.pem", + }, + ) + app.prepare( + version=1, + ssl={ + "cert": localhost_dir / "fullchain.pem", + "key": localhost_dir / "privkey.pem", + }, + ) + with caplog.at_level(logging.INFO): + Sanic.serve() + + assert ( + "sanic.root", + logging.INFO, + "server: sanic, HTTP/1.1", + ) in caplog.record_tuples + assert ( + "sanic.root", + logging.INFO, + "server: sanic, HTTP/3", + ) in caplog.record_tuples + + +@pytest.mark.skipif( + sys.version_info < (3, 8) and not UVLOOP_INSTALLED, + reason="In 3.7 w/o uvloop the port is not always released", +) +def test_server_starts_http1_and_http3_bad_order(app: Sanic, caplog): + @app.after_server_start + def shutdown(*_): + app.stop() + + app.prepare( + version=1, + ssl={ + "cert": localhost_dir / "fullchain.pem", + "key": localhost_dir / "privkey.pem", + }, + ) + message = ( + "Serving HTTP/3 instances as a secondary server is not supported. " + "There can only be a single HTTP/3 worker and it must be the first " + "instance prepared." + ) + with pytest.raises(RuntimeError, match=message): + app.prepare( + version=3, + ssl={ + "cert": localhost_dir / "fullchain.pem", + "key": localhost_dir / "privkey.pem", + }, + ) diff --git a/tests/http3/test_session_ticket_store.py b/tests/http3/test_session_ticket_store.py new file mode 100644 index 0000000000..c7fa29c486 --- /dev/null +++ b/tests/http3/test_session_ticket_store.py @@ -0,0 +1,46 @@ +from datetime import datetime + +from aioquic.tls import CipherSuite, SessionTicket + +from sanic.http.http3 import SessionTicketStore + + +def _generate_ticket(label): + return SessionTicket( + 1, + CipherSuite.AES_128_GCM_SHA256, + datetime.now(), + datetime.now(), + label, + label.decode(), + label, + None, + [], + ) + + +def test_session_ticket_store(): + store = SessionTicketStore() + + assert len(store.tickets) == 0 + + ticket1 = _generate_ticket(b"foo") + store.add(ticket1) + + assert len(store.tickets) == 1 + + ticket2 = _generate_ticket(b"bar") + store.add(ticket2) + + assert len(store.tickets) == 2 + assert len(store.tickets) == 2 + + popped2 = store.pop(ticket2.ticket) + + assert len(store.tickets) == 1 + assert popped2 is ticket2 + + popped1 = store.pop(ticket1.ticket) + + assert len(store.tickets) == 0 + assert popped1 is ticket1 diff --git a/tests/test_asgi.py b/tests/test_asgi.py index 61d36fa732..4c473047a0 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -417,7 +417,7 @@ async def test_request_class_custom(): class MyCustomRequest(Request): pass - app = Sanic(name=__name__, request_class=MyCustomRequest) + app = Sanic(name="Test", request_class=MyCustomRequest) @app.get("/custom") def custom_request(request): diff --git a/tests/test_cli.py b/tests/test_cli.py index 293965770d..24cc4777f1 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -148,8 +148,7 @@ def test_tls_wrong_options(cmd: Tuple[str]): assert not out lines = err.decode().split("\n") - errmsg = lines[6] - assert errmsg == "TLS certificates must be specified by either of:" + assert "TLS certificates must be specified by either of:" in lines @pytest.mark.parametrize( diff --git a/tests/test_config.py b/tests/test_config.py index d8a7bd85b3..764d7940b5 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,4 +1,5 @@ import logging +import os from contextlib import contextmanager from os import environ @@ -13,6 +14,7 @@ from sanic import Sanic from sanic.config import DEFAULT_CONFIG, Config +from sanic.constants import LocalCertCreator from sanic.exceptions import PyFileError @@ -49,7 +51,7 @@ def test_load_from_object(app: Sanic): def test_load_from_object_string(app: Sanic): - app.config.load("test_config.ConfigTest") + app.config.load("tests.test_config.ConfigTest") assert "CONFIG_VALUE" in app.config assert app.config.CONFIG_VALUE == "should be used" assert "not_for_config" not in app.config @@ -71,14 +73,14 @@ def test_load_from_object_string_exception(app: Sanic): def test_auto_env_prefix(): environ["SANIC_TEST_ANSWER"] = "42" - app = Sanic(name=__name__) + app = Sanic(name="Test") assert app.config.TEST_ANSWER == 42 del environ["SANIC_TEST_ANSWER"] def test_auto_bool_env_prefix(): environ["SANIC_TEST_ANSWER"] = "True" - app = Sanic(name=__name__) + app = Sanic(name="Test") assert app.config.TEST_ANSWER is True del environ["SANIC_TEST_ANSWER"] @@ -86,28 +88,28 @@ def test_auto_bool_env_prefix(): @pytest.mark.parametrize("env_prefix", [None, ""]) def test_empty_load_env_prefix(env_prefix): environ["SANIC_TEST_ANSWER"] = "42" - app = Sanic(name=__name__, env_prefix=env_prefix) + app = Sanic(name="Test", env_prefix=env_prefix) assert getattr(app.config, "TEST_ANSWER", None) is None del environ["SANIC_TEST_ANSWER"] def test_env_prefix(): environ["MYAPP_TEST_ANSWER"] = "42" - app = Sanic(name=__name__, env_prefix="MYAPP_") + app = Sanic(name="Test", env_prefix="MYAPP_") assert app.config.TEST_ANSWER == 42 del environ["MYAPP_TEST_ANSWER"] def test_env_prefix_float_values(): environ["MYAPP_TEST_ROI"] = "2.3" - app = Sanic(name=__name__, env_prefix="MYAPP_") + app = Sanic(name="Test", env_prefix="MYAPP_") assert app.config.TEST_ROI == 2.3 del environ["MYAPP_TEST_ROI"] def test_env_prefix_string_value(): environ["MYAPP_TEST_TOKEN"] = "somerandomtesttoken" - app = Sanic(name=__name__, env_prefix="MYAPP_") + app = Sanic(name="Test", env_prefix="MYAPP_") assert app.config.TEST_TOKEN == "somerandomtesttoken" del environ["MYAPP_TEST_TOKEN"] @@ -116,7 +118,7 @@ def test_env_w_custom_converter(): environ["SANIC_TEST_ANSWER"] = "42" config = Config(converters=[UltimateAnswer]) - app = Sanic(name=__name__, config=config) + app = Sanic(name="Test", config=config) assert isinstance(app.config.TEST_ANSWER, UltimateAnswer) assert app.config.TEST_ANSWER.answer == 42 del environ["SANIC_TEST_ANSWER"] @@ -125,7 +127,7 @@ def test_env_w_custom_converter(): def test_env_lowercase(): with pytest.warns(None) as record: environ["SANIC_test_answer"] = "42" - app = Sanic(name=__name__) + app = Sanic(name="Test") assert app.config.test_answer == 42 assert str(record[0].message) == ( "[DEPRECATION v22.9] Lowercase environment variables will not be " @@ -435,3 +437,21 @@ def test_negative_proxy_count(app: Sanic): ) with pytest.raises(ValueError, match=message): app.prepare() + + +@pytest.mark.parametrize( + "passed,expected", + ( + ("auto", LocalCertCreator.AUTO), + ("mkcert", LocalCertCreator.MKCERT), + ("trustme", LocalCertCreator.TRUSTME), + ("AUTO", LocalCertCreator.AUTO), + ("MKCERT", LocalCertCreator.MKCERT), + ("TRUSTME", LocalCertCreator.TRUSTME), + ), +) +def test_convert_local_cert_creator(passed, expected): + os.environ["SANIC_LOCAL_CERT_CREATOR"] = passed + app = Sanic("Test") + assert app.config.LOCAL_CERT_CREATOR is expected + del os.environ["SANIC_LOCAL_CERT_CREATOR"] diff --git a/tests/test_custom_request.py b/tests/test_custom_request.py index 758feebaa6..84ab705a81 100644 --- a/tests/test_custom_request.py +++ b/tests/test_custom_request.py @@ -17,7 +17,7 @@ async def receive_body(self): def test_custom_request(): - app = Sanic(name=__name__, request_class=CustomRequest) + app = Sanic(name="Test", request_class=CustomRequest) @app.route("/post", methods=["POST"]) async def post_handler(request): diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 5ab8678679..12008ee9d1 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -259,7 +259,7 @@ def tempest(_): def test_exception_in_ws_logged(caplog): - app = Sanic(__name__) + app = Sanic("Test") @app.websocket("/feed") async def feed(request, ws): @@ -279,7 +279,7 @@ async def feed(request, ws): @pytest.mark.parametrize("debug", (True, False)) def test_contextual_exception_context(debug): - app = Sanic(__name__) + app = Sanic("Test") class TeapotError(SanicException): status_code = 418 @@ -314,7 +314,7 @@ def fail(): @pytest.mark.parametrize("debug", (True, False)) def test_contextual_exception_extra(debug): - app = Sanic(__name__) + app = Sanic("Test") class TeapotError(SanicException): status_code = 418 @@ -361,7 +361,7 @@ def fail(): @pytest.mark.parametrize("override", (True, False)) def test_contextual_exception_functional_message(override): - app = Sanic(__name__) + app = Sanic("Test") class TeapotError(SanicException): status_code = 418 diff --git a/tests/test_http.py b/tests/test_http.py index 653857a12c..1e385449c5 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1,9 +1,7 @@ -import asyncio import json as stdjson from collections import namedtuple -from textwrap import dedent -from typing import AnyStr +from pathlib import Path import pytest @@ -11,52 +9,13 @@ from sanic import json, text from sanic.app import Sanic +from tests.client import RawClient -PORT = 1234 - - -class RawClient: - CRLF = b"\r\n" - - def __init__(self, host: str, port: int): - self.reader = None - self.writer = None - self.host = host - self.port = port - - async def connect(self): - self.reader, self.writer = await asyncio.open_connection( - self.host, self.port - ) +parent_dir = Path(__file__).parent +localhost_dir = parent_dir / "certs/localhost" - async def close(self): - self.writer.close() - await self.writer.wait_closed() - - async def send(self, message: AnyStr): - if isinstance(message, str): - msg = self._clean(message).encode("utf-8") - else: - msg = message - await self._send(msg) - - async def _send(self, message: bytes): - if not self.writer: - raise Exception("No open write stream") - self.writer.write(message) - - async def recv(self, nbytes: int = -1) -> bytes: - if not self.reader: - raise Exception("No open read stream") - return await self.reader.read(nbytes) - - def _clean(self, message: str) -> str: - return ( - dedent(message) - .lstrip("\n") - .replace("\n", self.CRLF.decode("utf-8")) - ) +PORT = 1234 @pytest.fixture @@ -115,7 +74,7 @@ def test_full_message(client): """ ) response = client.recv() - assert len(response) == 140 + assert len(response) == 151 assert b"200 OK" in response diff --git a/tests/test_http_alt_svc.py b/tests/test_http_alt_svc.py new file mode 100644 index 0000000000..62f2b02e51 --- /dev/null +++ b/tests/test_http_alt_svc.py @@ -0,0 +1,66 @@ +import sys + +from pathlib import Path + +import pytest + +from sanic.app import Sanic +from sanic.response import empty +from tests.client import RawClient + + +parent_dir = Path(__file__).parent +localhost_dir = parent_dir / "certs/localhost" + +PORT = 12344 + + +@pytest.mark.skipif(sys.version_info < (3, 9), reason="Not supported in 3.7") +def test_http1_response_has_alt_svc(): + Sanic._app_registry.clear() + app = Sanic("TestAltSvc") + app.config.TOUCHUP = True + response = b"" + + @app.get("/") + async def handler(*_): + return empty() + + @app.after_server_start + async def do_request(*_): + nonlocal response + + app.router.reset() + app.router.finalize() + + client = RawClient(app.state.host, app.state.port) + await client.connect() + await client.send( + """ + GET / HTTP/1.1 + host: localhost:7777 + + """ + ) + response = await client.recv() + await client.close() + + @app.after_server_start + def shutdown(*_): + app.stop() + + app.prepare( + version=3, + ssl={ + "cert": localhost_dir / "fullchain.pem", + "key": localhost_dir / "privkey.pem", + }, + port=PORT, + ) + app.prepare( + version=1, + port=PORT, + ) + Sanic.serve() + + assert f'alt-svc: h3=":{PORT}"\r\n'.encode() in response diff --git a/tests/test_logging.py b/tests/test_logging.py index 274b407c54..18d456660f 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -136,7 +136,7 @@ async def conn_lost(request): async def test_logger(caplog): rand_string = str(uuid.uuid4()) - app = Sanic(name=__name__) + app = Sanic(name="Test") @app.get("/") def log_info(request): @@ -163,7 +163,7 @@ def test_logging_modified_root_logger_config(): def test_access_log_client_ip_remote_addr(monkeypatch): access = Mock() - monkeypatch.setattr(sanic.http, "access_logger", access) + monkeypatch.setattr(sanic.http.http1, "access_logger", access) app = Sanic("test_logging") app.config.PROXIES_COUNT = 2 @@ -190,7 +190,7 @@ async def handler(request): def test_access_log_client_ip_reqip(monkeypatch): access = Mock() - monkeypatch.setattr(sanic.http, "access_logger", access) + monkeypatch.setattr(sanic.http.http1, "access_logger", access) app = Sanic("test_logging") diff --git a/tests/test_motd.py b/tests/test_motd.py index f3f95a2591..83c7e4bf8c 100644 --- a/tests/test_motd.py +++ b/tests/test_motd.py @@ -53,7 +53,7 @@ def test_motd_with_expected_info(app, run_startup): assert logs[1][2] == f"Sanic v{__version__}" assert logs[3][2] == "mode: debug, single worker" - assert logs[4][2] == "server: sanic" + assert logs[4][2] == "server: sanic, HTTP/1.1" assert logs[5][2] == f"python: {platform.python_version()}" assert logs[6][2] == f"platform: {platform.platform()}" diff --git a/tests/test_multi_serve.py b/tests/test_multi_serve.py index dde72b5c26..be7960e6b4 100644 --- a/tests/test_multi_serve.py +++ b/tests/test_multi_serve.py @@ -14,7 +14,7 @@ try: from unittest.mock import AsyncMock except ImportError: - from asyncmock import AsyncMock # type: ignore + from tests.asyncmock import AsyncMock # type: ignore @pytest.fixture diff --git a/tests/test_request.py b/tests/test_request.py index cb68325f46..e01cc1e724 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -231,3 +231,15 @@ async def get(request): _, resp = app.test_client.get("/") assert resp.json["same"] + + +def test_request_stream_id(app): + @app.get("/") + async def get(request): + try: + request.stream_id + except Exception as e: + return response.text(str(e)) + + _, resp = app.test_client.get("/") + assert resp.text == "Stream ID is only a property of a HTTP/3 request" diff --git a/tests/test_request_stream.py b/tests/test_request_stream.py index 37cbc04b54..8bf1fa5705 100644 --- a/tests/test_request_stream.py +++ b/tests/test_request_stream.py @@ -552,7 +552,7 @@ async def handler(request): def test_streaming_echo(): """2-way streaming chat between server and client.""" - app = Sanic(name=__name__) + app = Sanic(name="Test") @app.post("/echo", stream=True) async def handler(request): diff --git a/tests/test_requests.py b/tests/test_requests.py index 4d7fb0aa13..9a984bb46e 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -2050,7 +2050,7 @@ async def post(request): def test_endpoint_basic(): - app = Sanic(name=__name__) + app = Sanic(name="Test") @app.route("/") def my_unique_handler(request): @@ -2058,12 +2058,12 @@ def my_unique_handler(request): request, response = app.test_client.get("/") - assert request.endpoint == "test_requests.my_unique_handler" + assert request.endpoint == "Test.my_unique_handler" @pytest.mark.asyncio async def test_endpoint_basic_asgi(): - app = Sanic(name=__name__) + app = Sanic(name="Test") @app.route("/") def my_unique_handler(request): @@ -2071,7 +2071,7 @@ def my_unique_handler(request): request, response = await app.asgi_client.get("/") - assert request.endpoint == "test_requests.my_unique_handler" + assert request.endpoint == "Test.my_unique_handler" def test_endpoint_named_app(): diff --git a/tests/test_response.py b/tests/test_response.py index e1c19d2ab9..a25a1318e4 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -101,11 +101,12 @@ async def test(request: Request): return json({"ok": True}, headers={"CONTENT-TYPE": "application/json"}) request, response = app.test_client.get("/") - assert dict(response.headers) == { + for key, value in { "connection": "keep-alive", "content-length": "11", "content-type": "application/json", - } + }.items(): + assert response.headers[key] == value def test_response_content_length(app): diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 8396a9f964..1d5283197d 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -13,7 +13,7 @@ try: from unittest.mock import AsyncMock except ImportError: - from asyncmock import AsyncMock # type: ignore + from tests.asyncmock import AsyncMock # type: ignore pytestmark = pytest.mark.asyncio diff --git a/tests/test_tls.py b/tests/test_tls.py index b0674be49e..cc8eb3f92f 100644 --- a/tests/test_tls.py +++ b/tests/test_tls.py @@ -1,18 +1,30 @@ import logging import os import ssl -import uuid +import subprocess from contextlib import contextmanager +from pathlib import Path +from unittest.mock import Mock, patch from urllib.parse import urlparse import pytest -from sanic_testing.testing import HOST, PORT, SanicTestClient +from sanic_testing.testing import HOST, PORT + +import sanic.http.tls.creators from sanic import Sanic -from sanic.compat import OS_IS_WINDOWS -from sanic.log import logger +from sanic.application.constants import Mode +from sanic.constants import LocalCertCreator +from sanic.exceptions import SanicException +from sanic.helpers import _default +from sanic.http.tls.context import SanicSSLContext +from sanic.http.tls.creators import ( + MkcertCreator, + TrustmeCreator, + get_ssl_context, +) from sanic.response import text @@ -26,9 +38,63 @@ sanic_key = os.path.join(sanic_dir, "privkey.pem") +@pytest.fixture +def server_cert(): + return Mock() + + +@pytest.fixture +def issue_cert(server_cert): + mock = Mock(return_value=server_cert) + return mock + + +@pytest.fixture +def ca(issue_cert): + ca = Mock() + ca.issue_cert = issue_cert + return ca + + +@pytest.fixture +def trustme(ca): + module = Mock() + module.CA = Mock(return_value=ca) + return module + + +@pytest.fixture +def MockMkcertCreator(): + class Creator(MkcertCreator): + SUPPORTED = True + + def check_supported(self): + if not self.SUPPORTED: + raise SanicException("Nope") + + generate_cert = Mock() + + return Creator + + +@pytest.fixture +def MockTrustmeCreator(): + class Creator(TrustmeCreator): + SUPPORTED = True + + def check_supported(self): + if not self.SUPPORTED: + raise SanicException("Nope") + + generate_cert = Mock() + + return Creator + + @contextmanager def replace_server_name(hostname): - """Temporarily replace the server name sent with all TLS requests with a fake hostname.""" + """Temporarily replace the server name sent with all TLS requests with + a fake hostname.""" def hack_wrap_bio( self, @@ -69,8 +135,7 @@ async def handler(request): app.add_route(handler, path) - port = app.test_client.port - request, response = app.test_client.get( + request, _ = app.test_client.get( f"https://{HOST}:{PORT}" + path + f"?{query}", server_kwargs={"ssl": context}, ) @@ -100,7 +165,7 @@ async def handler(request): app.add_route(handler, path) - request, response = app.test_client.get( + request, _ = app.test_client.get( f"https://{HOST}:{PORT}" + path + f"?{query}", server_kwargs={"ssl": ssl_dict}, ) @@ -116,22 +181,22 @@ async def handler(request): def test_cert_sni_single(app): @app.get("/sni") - async def handler(request): + async def handler1(request): return text(request.conn_info.server_name) @app.get("/commonname") - async def handler(request): + async def handler2(request): return text(request.conn_info.cert.get("commonName")) port = app.test_client.port - request, response = app.test_client.get( + _, response = app.test_client.get( f"https://localhost:{port}/sni", server_kwargs={"ssl": localhost_dir}, ) assert response.status == 200 assert response.text == "localhost" - request, response = app.test_client.get( + _, response = app.test_client.get( f"https://localhost:{port}/commonname", server_kwargs={"ssl": localhost_dir}, ) @@ -143,16 +208,16 @@ def test_cert_sni_list(app): ssl_list = [sanic_dir, localhost_dir] @app.get("/sni") - async def handler(request): + async def handler1(request): return text(request.conn_info.server_name) @app.get("/commonname") - async def handler(request): + async def handler2(request): return text(request.conn_info.cert.get("commonName")) # This test should match the localhost cert port = app.test_client.port - request, response = app.test_client.get( + _, response = app.test_client.get( f"https://localhost:{port}/sni", server_kwargs={"ssl": ssl_list}, ) @@ -168,14 +233,14 @@ async def handler(request): # This part should use the sanic.example cert because it matches with replace_server_name("www.sanic.example"): - request, response = app.test_client.get( + _, response = app.test_client.get( f"https://127.0.0.1:{port}/sni", server_kwargs={"ssl": ssl_list}, ) assert response.status == 200 assert response.text == "www.sanic.example" - request, response = app.test_client.get( + _, response = app.test_client.get( f"https://127.0.0.1:{port}/commonname", server_kwargs={"ssl": ssl_list}, ) @@ -184,14 +249,14 @@ async def handler(request): # This part should use the sanic.example cert, that being the first listed with replace_server_name("invalid.test"): - request, response = app.test_client.get( + _, response = app.test_client.get( f"https://127.0.0.1:{port}/sni", server_kwargs={"ssl": ssl_list}, ) assert response.status == 200 assert response.text == "invalid.test" - request, response = app.test_client.get( + _, response = app.test_client.get( f"https://127.0.0.1:{port}/commonname", server_kwargs={"ssl": ssl_list}, ) @@ -200,7 +265,8 @@ async def handler(request): def test_missing_sni(app): - """The sanic cert does not list 127.0.0.1 and httpx does not send IP as SNI anyway.""" + """The sanic cert does not list 127.0.0.1 and httpx does not send + IP as SNI anyway.""" ssl_list = [None, sanic_dir] @app.get("/sni") @@ -209,7 +275,7 @@ async def handler(request): port = app.test_client.port with pytest.raises(Exception) as exc: - request, response = app.test_client.get( + app.test_client.get( f"https://127.0.0.1:{port}/sni", server_kwargs={"ssl": ssl_list}, ) @@ -217,7 +283,8 @@ async def handler(request): def test_no_matching_cert(app): - """The sanic cert does not list 127.0.0.1 and httpx does not send IP as SNI anyway.""" + """The sanic cert does not list 127.0.0.1 and httpx does not send + IP as SNI anyway.""" ssl_list = [None, sanic_dir] @app.get("/sni") @@ -227,7 +294,7 @@ async def handler(request): port = app.test_client.port with replace_server_name("invalid.test"): with pytest.raises(Exception) as exc: - request, response = app.test_client.get( + app.test_client.get( f"https://127.0.0.1:{port}/sni", server_kwargs={"ssl": ssl_list}, ) @@ -244,7 +311,7 @@ async def handler(request): port = app.test_client.port with replace_server_name("foo.sanic.test"): - request, response = app.test_client.get( + _, response = app.test_client.get( f"https://127.0.0.1:{port}/sni", server_kwargs={"ssl": ssl_list}, ) @@ -253,14 +320,14 @@ async def handler(request): with replace_server_name("sanic.test"): with pytest.raises(Exception) as exc: - request, response = app.test_client.get( + _, response = app.test_client.get( f"https://127.0.0.1:{port}/sni", server_kwargs={"ssl": ssl_list}, ) assert "Request and response object expected" in str(exc.value) with replace_server_name("sub.foo.sanic.test"): with pytest.raises(Exception) as exc: - request, response = app.test_client.get( + _, response = app.test_client.get( f"https://127.0.0.1:{port}/sni", server_kwargs={"ssl": ssl_list}, ) @@ -275,9 +342,7 @@ async def handler(request): ssl_dict = {"cert": None, "key": None} with pytest.raises(ValueError) as excinfo: - request, response = app.test_client.get( - "/test", server_kwargs={"ssl": ssl_dict} - ) + app.test_client.get("/test", server_kwargs={"ssl": ssl_dict}) assert str(excinfo.value) == "SSL dict needs filenames for cert and key." @@ -288,9 +353,7 @@ async def handler(request): return text("ssl test") with pytest.raises(ValueError) as excinfo: - request, response = app.test_client.get( - "/test", server_kwargs={"ssl": False} - ) + app.test_client.get("/test", server_kwargs={"ssl": False}) assert "Invalid ssl argument" in str(excinfo.value) @@ -303,9 +366,7 @@ async def handler(request): ssl_list = [sanic_cert] with pytest.raises(ValueError) as excinfo: - request, response = app.test_client.get( - "/test", server_kwargs={"ssl": ssl_list} - ) + app.test_client.get("/test", server_kwargs={"ssl": ssl_list}) assert "folder expected" in str(excinfo.value) assert sanic_cert in str(excinfo.value) @@ -319,9 +380,7 @@ async def handler(request): ssl_list = [invalid_dir] with pytest.raises(ValueError) as excinfo: - request, response = app.test_client.get( - "/test", server_kwargs={"ssl": ssl_list} - ) + app.test_client.get("/test", server_kwargs={"ssl": ssl_list}) assert "not found" in str(excinfo.value) assert invalid_dir + "/privkey.pem" in str(excinfo.value) @@ -336,9 +395,7 @@ async def handler(request): ssl_list = [invalid2] with pytest.raises(ValueError) as excinfo: - request, response = app.test_client.get( - "/test", server_kwargs={"ssl": ssl_list} - ) + app.test_client.get("/test", server_kwargs={"ssl": ssl_list}) assert "not found" in str(excinfo.value) assert invalid2 + "/fullchain.pem" in str(excinfo.value) @@ -352,15 +409,13 @@ async def handler(request): ssl_list = [None] with pytest.raises(ValueError) as excinfo: - request, response = app.test_client.get( - "/test", server_kwargs={"ssl": ssl_list} - ) + app.test_client.get("/test", server_kwargs={"ssl": ssl_list}) assert "No certificates" in str(excinfo.value) def test_logger_vhosts(caplog): - app = Sanic(name=__name__) + app = Sanic(name="test_logger_vhosts") @app.after_server_start def stop(*args): @@ -374,5 +429,210 @@ def stop(*args): ][0] assert logmsg == ( - "Certificate vhosts: localhost, 127.0.0.1, 0:0:0:0:0:0:0:1, sanic.example, www.sanic.example, *.sanic.test, 2001:DB8:0:0:0:0:0:541C" + "Certificate vhosts: localhost, 127.0.0.1, 0:0:0:0:0:0:0:1, " + "sanic.example, www.sanic.example, *.sanic.test, " + "2001:DB8:0:0:0:0:0:541C" + ) + + +def test_mk_cert_creator_default(app: Sanic): + cert_creator = MkcertCreator(app, _default, _default) + assert isinstance(cert_creator.tmpdir, Path) + assert cert_creator.tmpdir.exists() + + +def test_mk_cert_creator_is_supported(app): + cert_creator = MkcertCreator(app, _default, _default) + with patch("subprocess.run") as run: + cert_creator.check_supported() + run.assert_called_once_with( + ["mkcert", "-help"], + check=True, + stderr=subprocess.DEVNULL, + stdout=subprocess.DEVNULL, + ) + + +def test_mk_cert_creator_is_not_supported(app): + cert_creator = MkcertCreator(app, _default, _default) + with patch("subprocess.run") as run: + run.side_effect = Exception("") + with pytest.raises( + SanicException, match="Sanic is attempting to use mkcert" + ): + cert_creator.check_supported() + + +def test_mk_cert_creator_generate_cert_default(app): + cert_creator = MkcertCreator(app, _default, _default) + with patch("subprocess.run") as run: + with patch("sanic.http.tls.creators.CertSimple"): + retval = Mock() + retval.stdout = "foo" + run.return_value = retval + cert_creator.generate_cert("localhost") + run.assert_called_once() + + +def test_mk_cert_creator_generate_cert_localhost(app): + cert_creator = MkcertCreator(app, localhost_key, localhost_cert) + with patch("subprocess.run") as run: + with patch("sanic.http.tls.creators.CertSimple"): + cert_creator.generate_cert("localhost") + run.assert_not_called() + + +def test_trustme_creator_default(app: Sanic): + cert_creator = TrustmeCreator(app, _default, _default) + assert isinstance(cert_creator.tmpdir, Path) + assert cert_creator.tmpdir.exists() + + +def test_trustme_creator_is_supported(app, monkeypatch): + monkeypatch.setattr(sanic.http.tls.creators, "TRUSTME_INSTALLED", True) + cert_creator = TrustmeCreator(app, _default, _default) + cert_creator.check_supported() + + +def test_trustme_creator_is_not_supported(app, monkeypatch): + monkeypatch.setattr(sanic.http.tls.creators, "TRUSTME_INSTALLED", False) + cert_creator = TrustmeCreator(app, _default, _default) + with pytest.raises( + SanicException, match="Sanic is attempting to use trustme" + ): + cert_creator.check_supported() + + +def test_trustme_creator_generate_cert_default( + app, monkeypatch, trustme, issue_cert, server_cert, ca +): + monkeypatch.setattr(sanic.http.tls.creators, "trustme", trustme) + cert_creator = TrustmeCreator(app, _default, _default) + cert = cert_creator.generate_cert("localhost") + + assert isinstance(cert, SanicSSLContext) + trustme.CA.assert_called_once_with() + issue_cert.assert_called_once_with("localhost") + server_cert.configure_cert.assert_called_once() + ca.configure_trust.assert_called_once() + ca.cert_pem.write_to_path.assert_called_once_with(str(cert.sanic["cert"])) + write_to_path = server_cert.private_key_and_cert_chain_pem.write_to_path + write_to_path.assert_called_once_with(str(cert.sanic["key"])) + + +def test_trustme_creator_generate_cert_localhost( + app, monkeypatch, trustme, server_cert, ca +): + monkeypatch.setattr(sanic.http.tls.creators, "trustme", trustme) + cert_creator = TrustmeCreator(app, localhost_key, localhost_cert) + cert_creator.generate_cert("localhost") + + ca.cert_pem.write_to_path.assert_called_once_with(localhost_cert) + write_to_path = server_cert.private_key_and_cert_chain_pem.write_to_path + write_to_path.assert_called_once_with(localhost_key) + + +def test_get_ssl_context_with_ssl_context(app): + mock_context = Mock() + context = get_ssl_context(app, mock_context) + assert context is mock_context + + +def test_get_ssl_context_in_production(app): + app.state.mode = Mode.PRODUCTION + with pytest.raises( + SanicException, + match="Cannot run Sanic as an HTTPS server in PRODUCTION mode", + ): + get_ssl_context(app, None) + + +@pytest.mark.parametrize( + "requirement,mk_supported,trustme_supported,mk_called,trustme_called,err", + ( + (LocalCertCreator.AUTO, True, False, True, False, None), + (LocalCertCreator.AUTO, True, True, True, False, None), + (LocalCertCreator.AUTO, False, True, False, True, None), + ( + LocalCertCreator.AUTO, + False, + False, + False, + False, + "Sanic could not find package to create a TLS certificate", + ), + (LocalCertCreator.MKCERT, True, False, True, False, None), + (LocalCertCreator.MKCERT, True, True, True, False, None), + (LocalCertCreator.MKCERT, False, True, False, False, "Nope"), + (LocalCertCreator.MKCERT, False, False, False, False, "Nope"), + (LocalCertCreator.TRUSTME, True, False, False, False, "Nope"), + (LocalCertCreator.TRUSTME, True, True, False, True, None), + (LocalCertCreator.TRUSTME, False, True, False, True, None), + (LocalCertCreator.TRUSTME, False, False, False, False, "Nope"), + ), +) +def test_get_ssl_context_only_mkcert( + app, + monkeypatch, + MockMkcertCreator, + MockTrustmeCreator, + requirement, + mk_supported, + trustme_supported, + mk_called, + trustme_called, + err, +): + app.state.mode = Mode.DEBUG + app.config.LOCAL_CERT_CREATOR = requirement + monkeypatch.setattr( + sanic.http.tls.creators, "MkcertCreator", MockMkcertCreator + ) + monkeypatch.setattr( + sanic.http.tls.creators, "TrustmeCreator", MockTrustmeCreator + ) + MockMkcertCreator.SUPPORTED = mk_supported + MockTrustmeCreator.SUPPORTED = trustme_supported + + if err: + with pytest.raises(SanicException, match=err): + get_ssl_context(app, None) + else: + get_ssl_context(app, None) + + if mk_called: + MockMkcertCreator.generate_cert.assert_called_once_with("localhost") + else: + MockMkcertCreator.generate_cert.assert_not_called() + if trustme_called: + MockTrustmeCreator.generate_cert.assert_called_once_with("localhost") + else: + MockTrustmeCreator.generate_cert.assert_not_called() + + +def test_no_http3_with_trustme( + app, + monkeypatch, + MockTrustmeCreator, +): + monkeypatch.setattr( + sanic.http.tls.creators, "TrustmeCreator", MockTrustmeCreator ) + MockTrustmeCreator.SUPPORTED = True + app.config.LOCAL_CERT_CREATOR = "TRUSTME" + with pytest.raises( + SanicException, + match=( + "Sorry, you cannot currently use trustme as a local certificate " + "generator for an HTTP/3 server" + ), + ): + app.run(version=3, debug=True) + + +def test_sanic_ssl_context_create(): + context = ssl.SSLContext() + sanic_context = SanicSSLContext.create_from_ssl_context(context) + + assert sanic_context is context + assert isinstance(sanic_context, SanicSSLContext) diff --git a/tests/test_unix_socket.py b/tests/test_unix_socket.py index 12b286b400..aa4cd68560 100644 --- a/tests/test_unix_socket.py +++ b/tests/test_unix_socket.py @@ -53,7 +53,7 @@ def test_unix_socket_creation(caplog): assert os.path.exists(SOCKPATH) ino = os.stat(SOCKPATH).st_ino - app = Sanic(name=__name__) + app = Sanic(name="test") @app.listener("after_server_start") def running(app, loop): @@ -74,7 +74,7 @@ def running(app, loop): @pytest.mark.parametrize("path", (".", "no-such-directory/sanictest.sock")) def test_invalid_paths(path): - app = Sanic(name=__name__) + app = Sanic(name="test") with pytest.raises((FileExistsError, FileNotFoundError)): app.run(unix=path) @@ -84,7 +84,7 @@ def test_dont_replace_file(): with open(SOCKPATH, "w") as f: f.write("File, not socket") - app = Sanic(name=__name__) + app = Sanic(name="test") @app.listener("after_server_start") def stop(app, loop): @@ -101,7 +101,7 @@ def test_dont_follow_symlink(): sock.bind(SOCKPATH2) os.symlink(SOCKPATH2, SOCKPATH) - app = Sanic(name=__name__) + app = Sanic(name="test") @app.listener("after_server_start") def stop(app, loop): @@ -112,7 +112,7 @@ def stop(app, loop): def test_socket_deleted_while_running(): - app = Sanic(name=__name__) + app = Sanic(name="test") @app.listener("after_server_start") async def hack(app, loop): @@ -123,7 +123,7 @@ async def hack(app, loop): def test_socket_replaced_with_file(): - app = Sanic(name=__name__) + app = Sanic(name="test") @app.listener("after_server_start") async def hack(app, loop): @@ -136,7 +136,7 @@ async def hack(app, loop): def test_unix_connection(): - app = Sanic(name=__name__) + app = Sanic(name="test") @app.get("/") def handler(request): @@ -159,7 +159,7 @@ async def client(app, loop): app.run(host="myhost.invalid", unix=SOCKPATH) -app_multi = Sanic(name=__name__) +app_multi = Sanic(name="test") def handler(request): diff --git a/tests/test_versioning.py b/tests/test_versioning.py index 396629b301..26f4ca4c22 100644 --- a/tests/test_versioning.py +++ b/tests/test_versioning.py @@ -19,7 +19,7 @@ def test_route(app, handler): def test_bp(app, handler): - bp = Blueprint(__name__, version=1) + bp = Blueprint("Test", version=1) bp.route("/")(handler) app.blueprint(bp) @@ -28,7 +28,7 @@ def test_bp(app, handler): def test_bp_use_route(app, handler): - bp = Blueprint(__name__, version=1) + bp = Blueprint("Test", version=1) bp.route("/", version=1.1)(handler) app.blueprint(bp) @@ -37,7 +37,7 @@ def test_bp_use_route(app, handler): def test_bp_group(app, handler): - bp = Blueprint(__name__) + bp = Blueprint("Test") bp.route("/")(handler) group = Blueprint.group(bp, version=1) app.blueprint(group) @@ -47,7 +47,7 @@ def test_bp_group(app, handler): def test_bp_group_use_bp(app, handler): - bp = Blueprint(__name__, version=1.1) + bp = Blueprint("Test", version=1.1) bp.route("/")(handler) group = Blueprint.group(bp, version=1) app.blueprint(group) @@ -57,7 +57,7 @@ def test_bp_group_use_bp(app, handler): def test_bp_group_use_registration(app, handler): - bp = Blueprint(__name__, version=1.1) + bp = Blueprint("Test", version=1.1) bp.route("/")(handler) group = Blueprint.group(bp, version=1) app.blueprint(group, version=1.2) @@ -67,7 +67,7 @@ def test_bp_group_use_registration(app, handler): def test_bp_group_use_route(app, handler): - bp = Blueprint(__name__, version=1.1) + bp = Blueprint("Test", version=1.1) bp.route("/", version=1.3)(handler) group = Blueprint.group(bp, version=1) app.blueprint(group, version=1.2) @@ -84,7 +84,7 @@ def test_version_prefix_route(app, handler): def test_version_prefix_bp(app, handler): - bp = Blueprint(__name__, version=1, version_prefix="/api/v") + bp = Blueprint("Test", version=1, version_prefix="/api/v") bp.route("/")(handler) app.blueprint(bp) @@ -93,7 +93,7 @@ def test_version_prefix_bp(app, handler): def test_version_prefix_bp_use_route(app, handler): - bp = Blueprint(__name__, version=1, version_prefix="/ignore/v") + bp = Blueprint("Test", version=1, version_prefix="/ignore/v") bp.route("/", version=1.1, version_prefix="/api/v")(handler) app.blueprint(bp) @@ -102,7 +102,7 @@ def test_version_prefix_bp_use_route(app, handler): def test_version_prefix_bp_group(app, handler): - bp = Blueprint(__name__) + bp = Blueprint("Test") bp.route("/")(handler) group = Blueprint.group(bp, version=1, version_prefix="/api/v") app.blueprint(group) @@ -112,7 +112,7 @@ def test_version_prefix_bp_group(app, handler): def test_version_prefix_bp_group_use_bp(app, handler): - bp = Blueprint(__name__, version=1.1, version_prefix="/api/v") + bp = Blueprint("Test", version=1.1, version_prefix="/api/v") bp.route("/")(handler) group = Blueprint.group(bp, version=1, version_prefix="/ignore/v") app.blueprint(group) @@ -122,7 +122,7 @@ def test_version_prefix_bp_group_use_bp(app, handler): def test_version_prefix_bp_group_use_registration(app, handler): - bp = Blueprint(__name__, version=1.1, version_prefix="/alsoignore/v") + bp = Blueprint("Test", version=1.1, version_prefix="/alsoignore/v") bp.route("/")(handler) group = Blueprint.group(bp, version=1, version_prefix="/ignore/v") app.blueprint(group, version=1.2, version_prefix="/api/v") @@ -132,7 +132,7 @@ def test_version_prefix_bp_group_use_registration(app, handler): def test_version_prefix_bp_group_use_route(app, handler): - bp = Blueprint(__name__, version=1.1, version_prefix="/alsoignore/v") + bp = Blueprint("Test", version=1.1, version_prefix="/alsoignore/v") bp.route("/", version=1.3, version_prefix="/api/v")(handler) group = Blueprint.group(bp, version=1, version_prefix="/ignore/v") app.blueprint(group, version=1.2, version_prefix="/stillignoring/v") diff --git a/tests/test_websockets.py b/tests/test_websockets.py index 329eff455f..dd8413b981 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -14,7 +14,7 @@ try: from unittest.mock import AsyncMock except ImportError: - from asyncmock import AsyncMock # type: ignore + from tests.asyncmock import AsyncMock # type: ignore @pytest.mark.asyncio diff --git a/tox.ini b/tox.ini index 6c4bbdbe7a..2a044f9ea3 100644 --- a/tox.ini +++ b/tox.ini @@ -6,7 +6,9 @@ usedevelop = true setenv = {py37,py38,py39,py310,pyNightly}-no-ext: SANIC_NO_UJSON=1 {py37,py38,py39,py310,pyNightly}-no-ext: SANIC_NO_UVLOOP=1 -extras = test +extras = test, http3 +deps = + httpx==0.23 allowlist_externals = pytest coverage @@ -46,7 +48,7 @@ commands = [testenv:docs] platform = linux|linux2|darwin allowlist_externals = make -extras = docs +extras = docs, http3 commands = make docs-test