diff --git a/sanic/app.py b/sanic/app.py index 41f065800f..c951009b7d 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -140,6 +140,7 @@ class Sanic(BaseSanic, StartupMixin, metaclass=TouchUpMeta): "configure_logging", "ctx", "error_handler", + "inspector_class", "go_fast", "listeners", "multiplexer", @@ -176,6 +177,7 @@ def __init__( dumps: Optional[Callable[..., AnyStr]] = None, loads: Optional[Callable[..., Any]] = None, inspector: bool = False, + inspector_class: Optional[Type[Inspector]] = None, ) -> None: super().__init__(name=name) # logging @@ -211,6 +213,7 @@ def __init__( self.configure_logging: bool = configure_logging self.ctx: Any = ctx or SimpleNamespace() self.error_handler: ErrorHandler = error_handler or ErrorHandler() + self.inspector_class: Type[Inspector] = inspector_class or Inspector self.listeners: Dict[str, List[ListenerType[Any]]] = defaultdict(list) self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {} self.named_response_middleware: Dict[str, Deque[MiddlewareType]] = {} diff --git a/sanic/cli/app.py b/sanic/cli/app.py index 24825e25de..444de0f688 100644 --- a/sanic/cli/app.py +++ b/sanic/cli/app.py @@ -3,23 +3,21 @@ import shutil import sys -from argparse import ArgumentParser, RawTextHelpFormatter +from argparse import Namespace from functools import partial from textwrap import indent -from typing import Any, List, Union +from typing import List, Union, cast from sanic.app import Sanic from sanic.application.logo import get_logo from sanic.cli.arguments import Group -from sanic.log import error_logger -from sanic.worker.inspector import inspect +from sanic.cli.base import SanicArgumentParser, SanicHelpFormatter +from sanic.cli.inspector import make_inspector_parser +from sanic.cli.inspector_client import InspectorClient +from sanic.log import Colors, error_logger from sanic.worker.loader import AppLoader -class SanicArgumentParser(ArgumentParser): - ... - - class SanicCLI: DESCRIPTION = indent( f""" @@ -46,7 +44,7 @@ def __init__(self) -> None: self.parser = SanicArgumentParser( prog="sanic", description=self.DESCRIPTION, - formatter_class=lambda prog: RawTextHelpFormatter( + formatter_class=lambda prog: SanicHelpFormatter( prog, max_help_position=36 if width > 96 else 24, indent_increment=4, @@ -58,16 +56,27 @@ def __init__(self) -> None: self.main_process = ( os.environ.get("SANIC_RELOADER_PROCESS", "") != "true" ) - self.args: List[Any] = [] + self.args: Namespace = Namespace() self.groups: List[Group] = [] + self.inspecting = False def attach(self): + if sys.argv[1] == "inspect": + self.inspecting = True + self.parser.description = get_logo(True) + make_inspector_parser(self.parser) + return + for group in Group._registry: instance = group.create(self.parser) instance.attach() self.groups.append(instance) def run(self, parse_args=None): + if self.inspecting: + self._inspector() + return + legacy_version = False if not parse_args: # This is to provide backwards compat -v to display version @@ -86,52 +95,21 @@ def run(self, parse_args=None): self.args = self.parser.parse_args(args=parse_args) self._precheck() app_loader = AppLoader( - self.args.module, - self.args.factory, - self.args.simple, - self.args, + self.args.module, self.args.factory, self.args.simple, self.args ) + if self.args.inspect or self.args.inspect_raw or self.args.trigger: + self._inspector_legacy(app_loader) + return + try: app = self._get_app(app_loader) kwargs = self._build_run_kwargs() except ValueError as e: error_logger.exception(f"Failed to run app: {e}") else: - if ( - self.args.inspect - or self.args.inspect_raw - or self.args.trigger - or self.args.scale is not None - ): - os.environ["SANIC_IGNORE_PRODUCTION_WARNING"] = "true" - else: - for http_version in self.args.http: - app.prepare(**kwargs, version=http_version) - - if ( - self.args.inspect - or self.args.inspect_raw - or self.args.trigger - or self.args.scale is not None - ): - if self.args.scale is not None: - if self.args.scale <= 0: - error_logger.error("There must be at least 1 worker") - sys.exit(1) - action = f"scale={self.args.scale}" - else: - action = self.args.trigger or ( - "raw" if self.args.inspect_raw else "pretty" - ) - inspect( - app.config.INSPECTOR_HOST, - app.config.INSPECTOR_PORT, - action, - ) - del os.environ["SANIC_IGNORE_PRODUCTION_WARNING"] - return - + for http_version in self.args.http: + app.prepare(**kwargs, version=http_version) if self.args.single: serve = Sanic.serve_single elif self.args.legacy: @@ -140,6 +118,53 @@ def run(self, parse_args=None): serve = partial(Sanic.serve, app_loader=app_loader) serve(app) + def _inspector_legacy(self, app_loader: AppLoader): + host = port = None + module = cast(str, self.args.module) + if ":" in module: + maybe_host, maybe_port = module.rsplit(":", 1) + if maybe_port.isnumeric(): + host, port = maybe_host, int(maybe_port) + if not host: + app = self._get_app(app_loader) + host, port = app.config.INSPECTOR_HOST, app.config.INSPECTOR_PORT + + action = self.args.trigger or "info" + + InspectorClient( + str(host), int(port or 6457), False, self.args.inspect_raw, "" + ).do(action) + sys.stdout.write( + f"\n{Colors.BOLD}{Colors.YELLOW}WARNING:{Colors.END} " + "You are using the legacy CLI command that will be removed in " + f"{Colors.RED}v23.3{Colors.END}. See ___ or checkout the new " + "style commands:\n\n\t" + f"{Colors.YELLOW}sanic inspect --help{Colors.END}\n" + ) + + def _inspector(self): + args = sys.argv[2:] + self.args, unknown = self.parser.parse_known_args(args=args) + if unknown: + for arg in unknown: + if arg.startswith("--"): + key, value = arg.split("=") + setattr(self.args, key.lstrip("-"), value) + + kwargs = {**self.args.__dict__} + host = kwargs.pop("host") + port = kwargs.pop("port") + secure = kwargs.pop("secure") + raw = kwargs.pop("raw") + action = kwargs.pop("action") or "info" + api_key = kwargs.pop("api_key") + positional = kwargs.pop("positional", None) + if action == "" and positional: + action = positional[0] + if len(positional) > 1: + kwargs["args"] = positional[1:] + InspectorClient(host, port, secure, raw, api_key).do(action, **kwargs) + def _precheck(self): # Custom TLS mismatch handling for better diagnostics if self.main_process and ( diff --git a/sanic/cli/arguments.py b/sanic/cli/arguments.py index f01dda9b56..e1fe905adf 100644 --- a/sanic/cli/arguments.py +++ b/sanic/cli/arguments.py @@ -115,12 +115,6 @@ def attach(self): const="shutdown", help=("Trigger all processes to shutdown"), ) - group.add_argument( - "--scale", - dest="scale", - type=int, - help=("Scale number of workers"), - ) class HTTPVersionGroup(Group): diff --git a/sanic/cli/base.py b/sanic/cli/base.py new file mode 100644 index 0000000000..270c4118b2 --- /dev/null +++ b/sanic/cli/base.py @@ -0,0 +1,35 @@ +from argparse import ( + SUPPRESS, + Action, + ArgumentParser, + RawTextHelpFormatter, + _SubParsersAction, +) +from typing import Any + + +class SanicArgumentParser(ArgumentParser): + def _check_value(self, action: Action, value: Any) -> None: + if isinstance(action, SanicSubParsersAction): + return + super()._check_value(action, value) + + +class SanicHelpFormatter(RawTextHelpFormatter): + def add_usage(self, usage, actions, groups, prefix=None): + if not usage: + usage = SUPPRESS + # Add one linebreak, but not two + self.add_text("\x1b[1A'") + super().add_usage(usage, actions, groups, prefix) + + +class SanicSubParsersAction(_SubParsersAction): + def __call__(self, parser, namespace, values, option_string=None): + self._name_parser_map + parser_name = values[0] + if parser_name not in self._name_parser_map: + self._name_parser_map[parser_name] = parser + values = ["", *values] + + super().__call__(parser, namespace, values, option_string) diff --git a/sanic/cli/inspector.py b/sanic/cli/inspector.py new file mode 100644 index 0000000000..5a1719b49f --- /dev/null +++ b/sanic/cli/inspector.py @@ -0,0 +1,72 @@ +from argparse import ArgumentParser + +from sanic.application.logo import get_logo +from sanic.cli.base import SanicHelpFormatter, SanicSubParsersAction + + +def _add_shared(parser: ArgumentParser) -> None: + parser.add_argument("--host", "-H", default="localhost") + parser.add_argument("--port", "-p", default=6457, type=int) + parser.add_argument("--secure", "-s", action="store_true") + parser.add_argument("--api-key", "-k") + parser.add_argument( + "--raw", + action="store_true", + help="Whether to output the raw response information", + ) + + +class InspectorSubParser(ArgumentParser): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + _add_shared(self) + if not self.description: + self.description = "" + self.description = get_logo(True) + self.description + + +def make_inspector_parser(parser: ArgumentParser) -> None: + _add_shared(parser) + subparsers = parser.add_subparsers( + action=SanicSubParsersAction, + dest="action", + description=( + "Run one of the below subcommands. If you have created a custom " + "Inspector instance, then you can run custom commands.\nSee ___ " + "for more details." + ), + title="Required\n========\n Subcommands", + parser_class=InspectorSubParser, + ) + subparsers.add_parser( + "reload", + help="Trigger a reload of the server workers", + formatter_class=SanicHelpFormatter, + ) + subparsers.add_parser( + "shutdown", + help="Shutdown the application and all processes", + formatter_class=SanicHelpFormatter, + ) + scale = subparsers.add_parser( + "scale", + help="Scale the number of workers", + formatter_class=SanicHelpFormatter, + ) + scale.add_argument("replicas", type=int) + + custom = subparsers.add_parser( + "", + help="Run a custom command", + description=( + "keyword arguments:\n When running a custom command, you can " + "add keyword arguments by appending them to your command\n\n" + "\tsanic inspect foo --one=1 --two=2" + ), + formatter_class=SanicHelpFormatter, + ) + custom.add_argument( + "positional", + nargs="*", + help="Add one or more non-keyword args to your custom command", + ) diff --git a/sanic/cli/inspector_client.py b/sanic/cli/inspector_client.py new file mode 100644 index 0000000000..fd22bbd869 --- /dev/null +++ b/sanic/cli/inspector_client.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import sys + +from http.client import RemoteDisconnected +from textwrap import indent +from typing import Any, Dict, Optional +from urllib.error import URLError +from urllib.request import Request as URequest +from urllib.request import urlopen + +from sanic.application.logo import get_logo +from sanic.application.motd import MOTDTTY +from sanic.log import Colors + + +try: # no cov + from ujson import dumps, loads +except ModuleNotFoundError: # no cov + from json import dumps, loads # type: ignore + + +class InspectorClient: + def __init__( + self, + host: str, + port: int, + secure: bool, + raw: bool, + api_key: Optional[str], + ) -> None: + self.scheme = "https" if secure else "http" + self.host = host + self.port = port + self.raw = raw + self.api_key = api_key + + for scheme in ("http", "https"): + full = f"{scheme}://" + if self.host.startswith(full): + self.scheme = scheme + self.host = self.host[len(full) :] # noqa E203 + + def do(self, action: str, **kwargs: Any) -> None: + if action == "info": + self.info() + return + result = self.request(action, **kwargs).get("result") + if result: + out = ( + dumps(result) + if isinstance(result, (list, dict)) + else str(result) + ) + sys.stdout.write(out + "\n") + + def info(self) -> None: + out = sys.stdout.write + response = self.request("", "GET") + if self.raw or not response: + return + data = response["result"] + display = data.pop("info") + extra = display.pop("extra", {}) + display["packages"] = ", ".join(display["packages"]) + MOTDTTY(get_logo(), self.base_url, display, extra).display( + version=False, + action="Inspecting", + out=out, + ) + for name, info in data["workers"].items(): + info = "\n".join( + f"\t{key}: {Colors.BLUE}{value}{Colors.END}" + for key, value in info.items() + ) + out( + "\n" + + indent( + "\n".join( + [ + f"{Colors.BOLD}{Colors.SANIC}{name}{Colors.END}", + info, + ] + ), + " ", + ) + + "\n" + ) + + def request(self, action: str, method: str = "POST", **kwargs: Any) -> Any: + url = f"{self.base_url}/{action}" + params: Dict[str, Any] = {"method": method, "headers": {}} + if kwargs: + params["data"] = dumps(kwargs).encode() + params["headers"]["content-type"] = "application/json" + if self.api_key: + params["headers"]["authorization"] = f"Bearer {self.api_key}" + request = URequest(url, **params) + + try: + with urlopen(request) as response: # nosec B310 + raw = response.read() + loaded = loads(raw) + if self.raw: + sys.stdout.write(dumps(loaded.get("result")) + "\n") + return {} + return loaded + except (URLError, RemoteDisconnected) as e: + sys.stderr.write( + f"{Colors.RED}Could not connect to inspector at: " + f"{Colors.YELLOW}{self.base_url}{Colors.END}\n" + "Either the application is not running, or it did not start " + f"an inspector instance.\n{e}\n" + ) + sys.exit(1) + + @property + def base_url(self): + return f"{self.scheme}://{self.host}:{self.port}" diff --git a/sanic/config.py b/sanic/config.py index dc14d71017..0d5eabf80b 100644 --- a/sanic/config.py +++ b/sanic/config.py @@ -46,6 +46,9 @@ "INSPECTOR": False, "INSPECTOR_HOST": "localhost", "INSPECTOR_PORT": 6457, + "INSPECTOR_TLS_KEY": _default, + "INSPECTOR_TLS_CERT": _default, + "INSPECTOR_API_KEY": "", "KEEP_ALIVE_TIMEOUT": 5, # 5 seconds "KEEP_ALIVE": True, "LOCAL_CERT_CREATOR": LocalCertCreator.AUTO, @@ -93,6 +96,9 @@ class Config(dict, metaclass=DescriptorMeta): INSPECTOR: bool INSPECTOR_HOST: str INSPECTOR_PORT: int + INSPECTOR_TLS_KEY: Union[Path, str, Default] + INSPECTOR_TLS_CERT: Union[Path, str, Default] + INSPECTOR_API_KEY: str KEEP_ALIVE_TIMEOUT: int KEEP_ALIVE: bool LOCAL_CERT_CREATOR: Union[str, LocalCertCreator] diff --git a/sanic/http/tls/context.py b/sanic/http/tls/context.py index f77fa56051..98c090bb34 100644 --- a/sanic/http/tls/context.py +++ b/sanic/http/tls/context.py @@ -24,13 +24,15 @@ def create_context( certfile: Optional[str] = None, keyfile: Optional[str] = None, password: Optional[str] = None, + purpose: ssl.Purpose = ssl.Purpose.CLIENT_AUTH, ) -> ssl.SSLContext: """Create a context with secure crypto and HTTP/1.1 in protocols.""" - context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH) + context = ssl.create_default_context(purpose=purpose) context.minimum_version = ssl.TLSVersion.TLSv1_2 context.set_ciphers(":".join(CIPHERS_TLS12)) context.set_alpn_protocols(["http/1.1"]) - context.sni_callback = server_name_callback + if purpose is ssl.Purpose.CLIENT_AUTH: + context.sni_callback = server_name_callback if certfile and keyfile: context.load_cert_chain(certfile, keyfile, password) return context diff --git a/sanic/mixins/startup.py b/sanic/mixins/startup.py index 140ecd2258..860d34a090 100644 --- a/sanic/mixins/startup.py +++ b/sanic/mixins/startup.py @@ -27,6 +27,7 @@ Callable, Dict, List, + Mapping, Optional, Set, Tuple, @@ -58,7 +59,6 @@ from sanic.server.protocols.websocket_protocol import WebSocketProtocol from sanic.server.runners import serve, serve_multiple, serve_single from sanic.server.socket import configure_socket, remove_unix_socket -from sanic.worker.inspector import Inspector from sanic.worker.loader import AppLoader from sanic.worker.manager import WorkerManager from sanic.worker.multiplexer import WorkerMultiplexer @@ -126,7 +126,7 @@ def run( register_sys_signals: bool = True, access_log: Optional[bool] = None, unix: Optional[str] = None, - loop: AbstractEventLoop = None, + loop: Optional[AbstractEventLoop] = None, reload_dir: Optional[Union[List[str], str]] = None, noisy_exceptions: Optional[bool] = None, motd: bool = True, @@ -225,7 +225,7 @@ def prepare( register_sys_signals: bool = True, access_log: Optional[bool] = None, unix: Optional[str] = None, - loop: AbstractEventLoop = None, + loop: Optional[AbstractEventLoop] = None, reload_dir: Optional[Union[List[str], str]] = None, noisy_exceptions: Optional[bool] = None, motd: bool = True, @@ -355,12 +355,12 @@ async def create_server( debug: bool = False, ssl: Union[None, SSLContext, dict, str, list, tuple] = None, sock: Optional[socket] = None, - protocol: Type[Protocol] = None, + protocol: Optional[Type[Protocol]] = None, backlog: int = 100, access_log: Optional[bool] = None, unix: Optional[str] = None, return_asyncio_server: bool = False, - asyncio_server_kwargs: Dict[str, Any] = None, + asyncio_server_kwargs: Optional[Dict[str, Any]] = None, noisy_exceptions: Optional[bool] = None, ) -> Optional[AsyncioServer]: """ @@ -481,7 +481,7 @@ def _helper( sock: Optional[socket] = None, unix: Optional[str] = None, workers: int = 1, - loop: AbstractEventLoop = None, + loop: Optional[AbstractEventLoop] = None, protocol: Type[Protocol] = HttpProtocol, backlog: int = 100, register_sys_signals: bool = True, @@ -769,7 +769,7 @@ def serve( ] primary_server_info.settings["run_multiple"] = True monitor_sub, monitor_pub = Pipe(True) - worker_state: Dict[str, Any] = sync_manager.dict() + worker_state: Mapping[str, Any] = sync_manager.dict() kwargs: Dict[str, Any] = { **primary_server_info.settings, "monitor_publisher": monitor_pub, @@ -841,14 +841,17 @@ def serve( "packages": [sanic_version, *packages], "extra": extra, } - inspector = Inspector( + inspector = primary.inspector_class( monitor_pub, app_info, worker_state, primary.config.INSPECTOR_HOST, primary.config.INSPECTOR_PORT, + primary.config.INSPECTOR_API_KEY, + primary.config.INSPECTOR_TLS_KEY, + primary.config.INSPECTOR_TLS_CERT, ) - manager.manage("Inspector", inspector, {}, transient=False) + manager.manage("Inspector", inspector, {}, transient=True) primary._inspector = inspector primary._manager = manager diff --git a/sanic/worker/inspector.py b/sanic/worker/inspector.py index 769bc784fd..d60c5ae3e7 100644 --- a/sanic/worker/inspector.py +++ b/sanic/worker/inspector.py @@ -1,23 +1,17 @@ -import sys +from __future__ import annotations from datetime import datetime +from inspect import isawaitable from multiprocessing.connection import Connection -from signal import SIGINT, SIGTERM -from signal import signal as signal_func -from socket import AF_INET, SOCK_STREAM, socket, timeout -from textwrap import indent -from typing import Any, Dict +from os import environ +from pathlib import Path +from typing import Any, Dict, Mapping, Union -from sanic.application.logo import get_logo -from sanic.application.motd import MOTDTTY -from sanic.log import Colors, error_logger, logger -from sanic.server.socket import configure_socket - - -try: # no cov - from ujson import dumps, loads -except ModuleNotFoundError: # no cov - from json import dumps, loads # type: ignore +from sanic.exceptions import Unauthorized +from sanic.helpers import Default +from sanic.log import logger +from sanic.request import Request +from sanic.response import json class Inspector: @@ -25,125 +19,102 @@ def __init__( self, publisher: Connection, app_info: Dict[str, Any], - worker_state: Dict[str, Any], + worker_state: Mapping[str, Any], host: str, port: int, + api_key: str, + tls_key: Union[Path, str, Default], + tls_cert: Union[Path, str, Default], ): self._publisher = publisher - self.run = True self.app_info = app_info self.worker_state = worker_state self.host = host self.port = port - - def __call__(self) -> None: - sock = configure_socket( - {"host": self.host, "port": self.port, "unix": None, "backlog": 1} + self.api_key = api_key + self.tls_key = tls_key + self.tls_cert = tls_cert + + def __call__(self, run=True, **_) -> Inspector: + from sanic import Sanic + + self.app = Sanic("Inspector") + self._setup() + if run: + self.app.run( + host=self.host, + port=self.port, + single_process=True, + ssl={"key": self.tls_key, "cert": self.tls_cert} + if not isinstance(self.tls_key, Default) + and not isinstance(self.tls_cert, Default) + else None, + ) + return self + + def _setup(self): + self.app.get("/")(self._info) + self.app.post("/")(self._action) + if self.api_key: + self.app.on_request(self._authentication) + environ["SANIC_IGNORE_PRODUCTION_WARNING"] = "true" + + def _authentication(self, request: Request) -> None: + if request.token != self.api_key: + raise Unauthorized("Bad API key") + + async def _action(self, request: Request, action: str): + logger.info("Incoming inspector action: %s", action) + output: Any = None + method = getattr(self, action, None) + if method: + kwargs = {} + if request.body: + kwargs = request.json + output = method(**kwargs) + if isawaitable(output): + output = await output + + return await self._respond(request, output) + + async def _info(self, request: Request): + return await self._respond(request, self._state_to_json()) + + async def _respond(self, request: Request, output: Any): + name = request.match_info.get("action", "info") + return json( + {"meta": {"action": name}, "result": output}, + escape_forward_slashes=False, ) - assert sock - signal_func(SIGINT, self.stop) - signal_func(SIGTERM, self.stop) - - logger.info(f"Inspector started on: {sock.getsockname()}") - sock.settimeout(0.5) - try: - while self.run: - try: - conn, _ = sock.accept() - except timeout: - continue - else: - action = conn.recv(64) - if action == b"reload": - self.reload() - elif action == b"shutdown": - self.shutdown() - elif action.startswith(b"scale"): - num_workers = int(action.split(b"=", 1)[-1]) - logger.info("Scaling to %s", num_workers) - self.scale(num_workers) - else: - data = dumps(self.state_to_json()) - conn.send(data.encode()) - conn.send(b"\n") - conn.close() - finally: - logger.info("Inspector closing") - sock.close() - - def stop(self, *_): - self.run = False - - def state_to_json(self): + + def _state_to_json(self) -> Dict[str, Any]: output = {"info": self.app_info} - output["workers"] = self.make_safe(dict(self.worker_state)) + output["workers"] = self._make_safe(dict(self.worker_state)) return output - def reload(self): - message = "__ALL_PROCESSES__:" - self._publisher.send(message) - - def scale(self, num_workers: int): - message = f"__SCALE__:{num_workers}" - self._publisher.send(message) - - def shutdown(self): - message = "__TERMINATE__" - self._publisher.send(message) - @staticmethod - def make_safe(obj: Dict[str, Any]) -> Dict[str, Any]: + def _make_safe(obj: Dict[str, Any]) -> Dict[str, Any]: for key, value in obj.items(): if isinstance(value, dict): - obj[key] = Inspector.make_safe(value) + obj[key] = Inspector._make_safe(value) elif isinstance(value, datetime): obj[key] = value.isoformat() return obj + def reload(self) -> None: + message = "__ALL_PROCESSES__:" + self._publisher.send(message) -def inspect(host: str, port: int, action: str): - out = sys.stdout.write - with socket(AF_INET, SOCK_STREAM) as sock: - try: - sock.connect((host, port)) - except ConnectionRefusedError: - error_logger.error( - f"{Colors.RED}Could not connect to inspector at: " - f"{Colors.YELLOW}{(host, port)}{Colors.END}\n" - "Either the application is not running, or it did not start " - "an inspector instance." - ) - sock.close() - sys.exit(1) - sock.sendall(action.encode()) - data = sock.recv(4096) - if action == "raw": - out(data.decode()) - elif action == "pretty": - loaded = loads(data) - display = loaded.pop("info") - extra = display.pop("extra", {}) - display["packages"] = ", ".join(display["packages"]) - MOTDTTY(get_logo(), f"{host}:{port}", display, extra).display( - version=False, - action="Inspecting", - out=out, - ) - for name, info in loaded["workers"].items(): - info = "\n".join( - f"\t{key}: {Colors.BLUE}{value}{Colors.END}" - for key, value in info.items() - ) - out( - "\n" - + indent( - "\n".join( - [ - f"{Colors.BOLD}{Colors.SANIC}{name}{Colors.END}", - info, - ] - ), - " ", - ) - + "\n" - ) + def scale(self, replicas) -> str: + num_workers = 1 + if replicas: + num_workers = int(replicas) + log_msg = f"Scaling to {num_workers}" + logger.info(log_msg) + message = f"__SCALE__:{num_workers}" + self._publisher.send(message) + return log_msg + + def shutdown(self) -> None: + message = "__TERMINATE__" + self._publisher.send(message) diff --git a/sanic/worker/serve.py b/sanic/worker/serve.py index 8d233069ec..39c647b2b2 100644 --- a/sanic/worker/serve.py +++ b/sanic/worker/serve.py @@ -17,6 +17,7 @@ from sanic.server.runners import _serve_http_1, _serve_http_3 from sanic.worker.loader import AppLoader, CertLoader from sanic.worker.multiplexer import WorkerMultiplexer +from sanic.worker.process import Worker, WorkerProcess def worker_serve( @@ -79,7 +80,10 @@ def worker_serve( info.settings["ssl"] = ssl # When in a worker process, do some init - if os.environ.get("SANIC_WORKER_NAME"): + worker_name = os.environ.get("SANIC_WORKER_NAME") + if worker_name and worker_name.startswith( + Worker.WORKER_PREFIX + WorkerProcess.SERVER_LABEL + ): # Hydrate apps with any passed server info if monitor_publisher is None: diff --git a/tests/conftest.py b/tests/conftest.py index a84c34dc7e..22082fdffc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,7 @@ from contextlib import suppress from logging import LogRecord from typing import Any, Dict, List, Tuple -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock, patch import pytest @@ -221,3 +221,14 @@ def sanic_ext(ext_instance): # noqa yield sanic_ext with suppress(KeyError): del sys.modules["sanic_ext"] + + +@pytest.fixture +def urlopen(): + urlopen = Mock() + urlopen.return_value = urlopen + urlopen.__enter__ = Mock(return_value=urlopen) + urlopen.__exit__ = Mock() + urlopen.read = Mock() + with patch("sanic.cli.inspector_client.urlopen", urlopen): + yield urlopen diff --git a/tests/test_cli.py b/tests/test_cli.py index 47979fa264..c8e0c79ba9 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import List, Optional, Tuple +from unittest.mock import patch import pytest @@ -11,6 +12,7 @@ from sanic import __version__ from sanic.__main__ import main +from sanic.cli.inspector_client import InspectorClient @pytest.fixture(scope="module", autouse=True) @@ -292,3 +294,47 @@ def test_noisy_exceptions(cmd: str, expected: bool, caplog): info = read_app_info(lines) assert info["noisy_exceptions"] is expected + + +def test_inspector_inspect(urlopen, caplog, capsys): + urlopen.read.return_value = json.dumps( + { + "result": { + "info": { + "packages": ["foo"], + }, + "extra": { + "more": "data", + }, + "workers": {"Worker-Name": {"some": "state"}}, + } + } + ).encode() + with patch("sys.argv", ["sanic", "inspect"]): + capture(["inspect"], caplog) + captured = capsys.readouterr() + assert "Inspecting @ http://localhost:6457" in captured.out + assert "Worker-Name" in captured.out + assert captured.err == "" + + +@pytest.mark.parametrize( + "command,params", + ( + (["reload"], {}), + (["shutdown"], {}), + (["scale", "9"], {"replicas": 9}), + (["foo", "--bar=something"], {"bar": "something"}), + (["foo", "positional"], {"args": ["positional"]}), + ( + ["foo", "positional", "--bar=something"], + {"args": ["positional"], "bar": "something"}, + ), + ), +) +def test_inspector_command(command, params): + with patch.object(InspectorClient, "request") as client: + with patch("sys.argv", ["sanic", "inspect", *command]): + main() + + client.assert_called_once_with(command[0], **params) diff --git a/tests/worker/test_inspector.py b/tests/worker/test_inspector.py index 0c9eb90654..bc9dd2d527 100644 --- a/tests/worker/test_inspector.py +++ b/tests/worker/test_inspector.py @@ -1,14 +1,20 @@ -import json +try: # no cov + from ujson import dumps +except ModuleNotFoundError: # no cov + from json import dumps # type: ignore from datetime import datetime -from logging import ERROR, INFO -from socket import AF_INET, SOCK_STREAM, timeout from unittest.mock import Mock, patch +from urllib.error import URLError import pytest +from sanic_testing import TestManager + +from sanic.cli.inspector_client import InspectorClient +from sanic.helpers import Default from sanic.log import Colors -from sanic.worker.inspector import Inspector, inspect +from sanic.worker.inspector import Inspector DATA = { @@ -20,130 +26,84 @@ }, "workers": {"Worker-Name": {"some": "state"}}, } -SERIALIZED = json.dumps(DATA) - - -def test_inspector_stop(): - inspector = Inspector(Mock(), {}, {}, "", 1) - assert inspector.run is True - inspector.stop() - assert inspector.run is False - - -@patch("sanic.worker.inspector.sys.stdout.write") -@patch("sanic.worker.inspector.socket") -@pytest.mark.parametrize("command", ("foo", "raw", "pretty")) -def test_send_inspect(socket: Mock, write: Mock, command: str): - socket.return_value = socket - socket.__enter__.return_value = socket - socket.recv.return_value = SERIALIZED.encode() - inspect("localhost", 9999, command) - - socket.sendall.assert_called_once_with(command.encode()) - socket.recv.assert_called_once_with(4096) - socket.connect.assert_called_once_with(("localhost", 9999)) - socket.assert_called_once_with(AF_INET, SOCK_STREAM) - - if command == "raw": - write.assert_called_once_with(SERIALIZED) - elif command == "pretty": - write.assert_called() - else: - write.assert_not_called() - - -@patch("sanic.worker.inspector.sys") -@patch("sanic.worker.inspector.socket") -def test_send_inspect_conn_refused(socket: Mock, sys: Mock, caplog): - with caplog.at_level(INFO): - socket.return_value = socket - socket.__enter__.return_value = socket - socket.connect.side_effect = ConnectionRefusedError() - inspect("localhost", 9999, "foo") - - socket.close.assert_called_once() - sys.exit.assert_called_once_with(1) +FULL_SERIALIZED = dumps({"result": DATA}) +OUT_SERIALIZED = dumps(DATA) + + +class FooInspector(Inspector): + async def foo(self, bar): + return f"bar is {bar}" + + +@pytest.fixture +def publisher(): + publisher = Mock() + return publisher + + +@pytest.fixture +def inspector(publisher): + inspector = FooInspector( + publisher, {}, {}, "localhost", 9999, "", Default(), Default() + ) + inspector(False) + return inspector + + +@pytest.fixture +def http_client(inspector): + manager = TestManager(inspector.app) + return manager.test_client + + +@pytest.mark.parametrize("command", ("info",)) +@patch("sanic.cli.inspector_client.sys.stdout.write") +def test_send_inspect(write, urlopen, command: str): + urlopen.read.return_value = FULL_SERIALIZED.encode() + InspectorClient("localhost", 9999, False, False, None).do(command) + write.assert_called() + write.reset_mock() + InspectorClient("localhost", 9999, False, True, None).do(command) + write.assert_called_with(OUT_SERIALIZED + "\n") + + +@patch("sanic.cli.inspector_client.sys") +def test_send_inspect_conn_refused(sys: Mock, urlopen): + urlopen.side_effect = URLError("") + InspectorClient("localhost", 9999, False, False, None).do("info") message = ( f"{Colors.RED}Could not connect to inspector at: " - f"{Colors.YELLOW}('localhost', 9999){Colors.END}\n" + f"{Colors.YELLOW}http://localhost:9999{Colors.END}\n" "Either the application is not running, or it did not start " - "an inspector instance." + "an inspector instance.\n\n" ) - assert ("sanic.error", ERROR, message) in caplog.record_tuples - - -@patch("sanic.worker.inspector.configure_socket") -@pytest.mark.parametrize( - "action", (b"reload", b"shutdown", b"scale=5", b"foo") -) -def test_run_inspector(configure_socket: Mock, action: bytes): - sock = Mock() - conn = Mock() - conn.recv.return_value = action - configure_socket.return_value = sock - inspector = Inspector(Mock(), {}, {}, "localhost", 9999) - inspector.reload = Mock() # type: ignore - inspector.shutdown = Mock() # type: ignore - inspector.scale = Mock() # type: ignore - inspector.state_to_json = Mock(return_value="foo") # type: ignore - - def accept(): - inspector.run = False - return conn, ... - - sock.accept = accept - - inspector() - - configure_socket.assert_called_once_with( - {"host": "localhost", "port": 9999, "unix": None, "backlog": 1} - ) - conn.recv.assert_called_with(64) - - conn.send.assert_called_with(b"\n") - if action == b"reload": - inspector.reload.assert_called() - inspector.shutdown.assert_not_called() - inspector.scale.assert_not_called() - inspector.state_to_json.assert_not_called() - elif action == b"shutdown": - inspector.reload.assert_not_called() - inspector.shutdown.assert_called() - inspector.scale.assert_not_called() - inspector.state_to_json.assert_not_called() - elif action.startswith(b"scale"): - inspector.reload.assert_not_called() - inspector.shutdown.assert_not_called() - inspector.scale.assert_called_once_with(5) - inspector.state_to_json.assert_not_called() - else: - inspector.reload.assert_not_called() - inspector.shutdown.assert_not_called() - inspector.scale.assert_not_called() - inspector.state_to_json.assert_called() - - -@patch("sanic.worker.inspector.configure_socket") -def test_accept_timeout(configure_socket: Mock): - sock = Mock() - configure_socket.return_value = sock - inspector = Inspector(Mock(), {}, {}, "localhost", 9999) - inspector.reload = Mock() # type: ignore - inspector.shutdown = Mock() # type: ignore - inspector.state_to_json = Mock(return_value="foo") # type: ignore - - def accept(): - inspector.run = False - raise timeout - - sock.accept = accept - - inspector() - - inspector.reload.assert_not_called() - inspector.shutdown.assert_not_called() - inspector.state_to_json.assert_not_called() + sys.exit.assert_called_once_with(1) + sys.stderr.write.assert_called_once_with(message) + + +def test_run_inspector_reload(publisher, http_client): + _, response = http_client.post("/reload") + assert response.status == 200 + publisher.send.assert_called_once_with("__ALL_PROCESSES__:") + + +def test_run_inspector_shutdown(publisher, http_client): + _, response = http_client.post("/shutdown") + assert response.status == 200 + publisher.send.assert_called_once_with("__TERMINATE__") + + +def test_run_inspector_scale(publisher, http_client): + _, response = http_client.post("/scale", json={"replicas": 4}) + assert response.status == 200 + publisher.send.assert_called_once_with("__SCALE__:4") + + +def test_run_inspector_arbitrary(http_client): + _, response = http_client.post("/foo", json={"bar": 99}) + assert response.status == 200 + assert response.json == {"meta": {"action": "foo"}, "result": "bar is 99"} def test_state_to_json(): @@ -151,8 +111,10 @@ def test_state_to_json(): now_iso = now.isoformat() app_info = {"app": "hello"} worker_state = {"Test": {"now": now, "nested": {"foo": now}}} - inspector = Inspector(Mock(), app_info, worker_state, "", 0) - state = inspector.state_to_json() + inspector = Inspector( + Mock(), app_info, worker_state, "", 0, "", Default(), Default() + ) + state = inspector._state_to_json() assert state == { "info": app_info, @@ -160,25 +122,14 @@ def test_state_to_json(): } -def test_reload(): - publisher = Mock() - inspector = Inspector(publisher, {}, {}, "", 0) - inspector.reload() - - publisher.send.assert_called_once_with("__ALL_PROCESSES__:") - - -def test_shutdown(): - publisher = Mock() - inspector = Inspector(publisher, {}, {}, "", 0) - inspector.shutdown() - - publisher.send.assert_called_once_with("__TERMINATE__") - - -def test_scale(): - publisher = Mock() - inspector = Inspector(publisher, {}, {}, "", 0) - inspector.scale(3) - - publisher.send.assert_called_once_with("__SCALE__:3") +def test_run_inspector_authentication(): + inspector = Inspector( + Mock(), {}, {}, "", 0, "super-secret", Default(), Default() + )(False) + manager = TestManager(inspector.app) + _, response = manager.test_client.get("/") + assert response.status == 401 + _, response = manager.test_client.get( + "/", headers={"Authorization": "Bearer super-secret"} + ) + assert response.status == 200 diff --git a/tests/worker/test_worker_serve.py b/tests/worker/test_worker_serve.py index b24f7e8f42..a33e3cacc6 100644 --- a/tests/worker/test_worker_serve.py +++ b/tests/worker/test_worker_serve.py @@ -8,6 +8,7 @@ from sanic.app import Sanic from sanic.worker.loader import AppLoader from sanic.worker.multiplexer import WorkerMultiplexer +from sanic.worker.process import Worker, WorkerProcess from sanic.worker.serve import worker_serve @@ -40,7 +41,9 @@ def test_config_app(mock_app: Mock): def test_bad_process(mock_app: Mock, caplog): - environ["SANIC_WORKER_NAME"] = "FOO" + environ["SANIC_WORKER_NAME"] = ( + Worker.WORKER_PREFIX + WorkerProcess.SERVER_LABEL + "-FOO" + ) message = "No restart publisher found in worker process" with pytest.raises(RuntimeError, match=message): @@ -58,7 +61,9 @@ def test_bad_process(mock_app: Mock, caplog): def test_has_multiplexer(app: Sanic): - environ["SANIC_WORKER_NAME"] = "FOO" + environ["SANIC_WORKER_NAME"] = ( + Worker.WORKER_PREFIX + WorkerProcess.SERVER_LABEL + "-FOO" + ) Sanic.register_app(app) with patch("sanic.worker.serve._serve_http_1"): @@ -97,12 +102,13 @@ def test_serve_app_factory(wm: Mock, mock_app): @patch("sanic.mixins.startup.WorkerManager") -@patch("sanic.mixins.startup.Inspector") @pytest.mark.parametrize("config", (True, False)) def test_serve_with_inspector( - Inspector: Mock, WorkerManager: Mock, mock_app: Mock, config: bool + WorkerManager: Mock, mock_app: Mock, config: bool ): + Inspector = Mock() mock_app.config.INSPECTOR = config + mock_app.inspector_class = Inspector inspector = Mock() Inspector.return_value = inspector WorkerManager.return_value = WorkerManager @@ -112,7 +118,7 @@ def test_serve_with_inspector( if config: Inspector.assert_called_once() WorkerManager.manage.assert_called_once_with( - "Inspector", inspector, {}, transient=False + "Inspector", inspector, {}, transient=True ) else: Inspector.assert_not_called()