Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for TLS on inspector #2620

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
44 changes: 26 additions & 18 deletions sanic/cli/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from sanic.application.logo import get_logo
from sanic.cli.arguments import Group
from sanic.log import error_logger
from sanic.worker.inspector import inspect
from sanic.worker.inspector import InspectorClient
from sanic.worker.loader import AppLoader


Expand Down Expand Up @@ -92,29 +92,20 @@ def run(self, parse_args=None):
self.args,
)

inspector_command = (
self.args.inspect or self.args.inspect_raw or self.args.trigger
)
if inspector_command:
self._inspector(app_loader)
return
try:
app = self._get_app(app_loader)
kwargs = self._build_run_kwargs()
except ValueError as e:
error_logger.exception(f"Failed to run app: {e}")
else:
if self.args.inspect or self.args.inspect_raw or self.args.trigger:
os.environ["SANIC_IGNORE_PRODUCTION_WARNING"] = "true"
else:
for http_version in self.args.http:
app.prepare(**kwargs, version=http_version)

if self.args.inspect or self.args.inspect_raw or self.args.trigger:
action = self.args.trigger or (
"raw" if self.args.inspect_raw else "pretty"
)
inspect(
app.config.INSPECTOR_HOST,
app.config.INSPECTOR_PORT,
action,
)
del os.environ["SANIC_IGNORE_PRODUCTION_WARNING"]
return
for http_version in self.args.http:
app.prepare(**kwargs, version=http_version)

if self.args.single:
serve = Sanic.serve_single
Expand All @@ -124,6 +115,23 @@ def run(self, parse_args=None):
serve = partial(Sanic.serve, app_loader=app_loader)
serve(app)

def _inspector(self, app_loader: AppLoader):
os.environ["SANIC_IGNORE_PRODUCTION_WARNING"] = "true"
host = port = None
if ":" in self.args.module:
maybe_host, maybe_port = self.args.module.split(":", 1)
if maybe_port.isnumeric():
host, port = maybe_host, int(maybe_port)
if not host:
app = self._get_app(app_loader)
host, port = app.config.INSPECTOR_HOST, app.config.INSPECTOR_PORT

action = self.args.trigger or (
"raw" if self.args.inspect_raw else "pretty"
)
InspectorClient(host, port, self.args.secure).run(action)
del os.environ["SANIC_IGNORE_PRODUCTION_WARNING"]

def _precheck(self):
# Custom TLS mismatch handling for better diagnostics
if self.main_process and (
Expand Down
6 changes: 6 additions & 0 deletions sanic/cli/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ def attach(self):
const="shutdown",
help=("Trigger all processes to shutdown"),
)
self.container.add_argument(
"--secure",
dest="secure",
action="store_true",
help=("Whether to connect to the inspector over a TLS connection"),
)


class HTTPVersionGroup(Group):
Expand Down
4 changes: 4 additions & 0 deletions sanic/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
"INSPECTOR": False,
"INSPECTOR_HOST": "localhost",
"INSPECTOR_PORT": 6457,
"INSPECTOR_TLS_KEY": _default,
"INSPECTOR_TLS_CERT": _default,
"KEEP_ALIVE_TIMEOUT": 5, # 5 seconds
"KEEP_ALIVE": True,
"LOCAL_CERT_CREATOR": LocalCertCreator.AUTO,
Expand Down Expand Up @@ -93,6 +95,8 @@ class Config(dict, metaclass=DescriptorMeta):
INSPECTOR: bool
INSPECTOR_HOST: str
INSPECTOR_PORT: int
INSPECTOR_TLS_KEY: Union[Path, str, Default]
INSPECTOR_TLS_CERT: Union[Path, str, Default]
KEEP_ALIVE_TIMEOUT: int
KEEP_ALIVE: bool
LOCAL_CERT_CREATOR: Union[str, LocalCertCreator]
Expand Down
2 changes: 2 additions & 0 deletions sanic/mixins/startup.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,8 @@ def serve(
worker_state,
primary.config.INSPECTOR_HOST,
primary.config.INSPECTOR_PORT,
primary.config.INSPECTOR_TLS_KEY,
primary.config.INSPECTOR_TLS_CERT,
)
manager.manage("Inspector", inspector, {}, transient=False)

Expand Down
164 changes: 111 additions & 53 deletions sanic/worker/inspector.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
import logging
import ssl
import sys

from contextlib import contextmanager
from datetime import datetime
from multiprocessing.connection import Connection
from pathlib import Path
from signal import SIGINT, SIGTERM
from signal import signal as signal_func
from socket import AF_INET, SOCK_STREAM, socket, timeout
from socket import create_connection, timeout
from textwrap import indent
from typing import Any, Dict
from typing import Any, Dict, Union

from sanic.application.logo import get_logo
from sanic.application.motd import MOTDTTY
from sanic.log import Colors, error_logger, logger
from sanic.helpers import Default
from sanic.log import LOGGING_CONFIG_DEFAULTS, Colors, error_logger, logger
from sanic.server.socket import configure_socket


Expand All @@ -19,6 +24,9 @@
except ModuleNotFoundError: # no cov
from json import dumps, loads # type: ignore

OK = b"OK"
ER = b"ER"


class Inspector:
def __init__(
Expand All @@ -28,31 +36,41 @@ def __init__(
worker_state: Dict[str, Any],
host: str,
port: int,
tls_key: Union[Path, str, Default],
tls_cert: Union[Path, str, Default],
):
self._publisher = publisher
self.run = True
self.app_info = app_info
self.worker_state = worker_state
self.host = host
self.port = port
self.tls_key = tls_key
self.tls_cert = tls_cert

def __call__(self) -> None:
sock = configure_socket(
{"host": self.host, "port": self.port, "unix": None, "backlog": 1}
)
assert sock
signal_func(SIGINT, self.stop)
signal_func(SIGTERM, self.stop)

logger.info(f"Inspector started on: {sock.getsockname()}")
sock.settimeout(0.5)
try:
with self.socket() as sock:
host, port = sock.getsockname()
logger.info(f"Inspector started @ {host}:{port}")

while self.run:
try:
conn, _ = sock.accept()
except timeout:
continue
except ssl.SSLError as e:
print("SSL error: ", e)
continue
else:
okay = conn.recv(2)
if okay != OK:
error_logger.error("Invalid start")
conn.close()
continue
conn.send(OK)
action = conn.recv(64)
if action == b"reload":
conn.send(b"\n")
Expand All @@ -63,10 +81,8 @@ def __call__(self) -> None:
else:
data = dumps(self.state_to_json())
conn.send(data.encode())
conn.close()
finally:
logger.debug("Inspector closing")
sock.close()
conn.send(b"\r\n\r\n")
conn.close()

def stop(self, *_):
self.run = False
Expand All @@ -84,6 +100,25 @@ def shutdown(self):
message = "__TERMINATE__"
self._publisher.send(message)

@contextmanager
def socket(self):
sock = configure_socket(
{"host": self.host, "port": self.port, "unix": None, "backlog": 1}
)
assert sock
sock.settimeout(15)

if not isinstance(self.tls_key, Default) and not isinstance(
self.tls_cert, Default
):
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain(self.tls_cert, self.tls_key)
yield context.wrap_socket(sock, server_side=True)

Check failure

Code scanning / CodeQL

Use of insecure SSL/TLS version

Insecure SSL/TLS protocol version TLSv1 allowed by [call to ssl.SSLContext](1). Insecure SSL/TLS protocol version TLSv1_1 allowed by [call to ssl.SSLContext](1).
else:
yield sock
logger.debug("Inspector closing")
sock.close()

@staticmethod
def make_safe(obj: Dict[str, Any]) -> Dict[str, Any]:
for key, value in obj.items():
Expand All @@ -94,49 +129,72 @@ def make_safe(obj: Dict[str, Any]) -> Dict[str, Any]:
return obj


def inspect(host: str, port: int, action: str):
out = sys.stdout.write
with socket(AF_INET, SOCK_STREAM) as sock:
class InspectorClient:
def __init__(self, host: str, port: int, secure: bool):
self.host = host
self.port = port
self.secure = secure

@contextmanager
def socket(self):
context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH)
error = sys.stderr.write
try:
sock.connect((host, port))
except ConnectionRefusedError:
error_logger.error(
sock = create_connection((self.host, self.port))
sock.settimeout(15)
if self.secure:
sock = context.wrap_socket(sock, server_hostname="localhost")

Check failure

Code scanning / CodeQL

Use of insecure SSL/TLS version

Insecure SSL/TLS protocol version TLSv1 allowed by [call to ssl.create_default_context](1). Insecure SSL/TLS protocol version TLSv1_1 allowed by [call to ssl.create_default_context](1).

sock.sendall(OK)
if sock.recv(2) != OK:
raise ValueError
yield sock
if not sock._closed:
sock.close()
except (
ConnectionRefusedError,
ConnectionResetError,
TimeoutError,
ValueError,
):
error(
f"{Colors.RED}Could not connect to inspector at: "
f"{Colors.YELLOW}{(host, port)}{Colors.END}\n"
f"{Colors.YELLOW}{self.host}:{self.port}{Colors.END}\n"
"Either the application is not running, or it did not start "
"an inspector instance."
"an inspector instance.\n"
)
sock.close()
sys.exit(1)
sock.sendall(action.encode())
data = sock.recv(4096)
if action == "raw":
out(data.decode())
elif action == "pretty":
loaded = loads(data)
display = loaded.pop("info")
extra = display.pop("extra", {})
display["packages"] = ", ".join(display["packages"])
MOTDTTY(get_logo(), f"{host}:{port}", display, extra).display(
version=False,
action="Inspecting",
out=out,
)
for name, info in loaded["workers"].items():
info = "\n".join(
f"\t{key}: {Colors.BLUE}{value}{Colors.END}"
for key, value in info.items()

def run(self, action: str):
out = sys.stdout.write
logging.config.dictConfig(LOGGING_CONFIG_DEFAULTS)
with self.socket() as sock:
sock.sendall(action.encode())
more = True
data = b""
while more:
received = sock.recv(4096)
if received.endswith(b"\r\n\r\n"):
more = False
data += received
if action == "raw":
out(data.decode())
elif action == "pretty":
loaded = loads(data)
display = loaded.pop("info")
extra = display.pop("extra", {})
display["packages"] = ", ".join(display["packages"])
MOTDTTY(
get_logo(), f"{self.host}:{self.port}", display, extra
).display(
version=False,
action="Inspecting",
out=out,
)
out(
"\n"
+ indent(
"\n".join(
[
f"{Colors.BOLD}{Colors.SANIC}{name}{Colors.END}",
info,
]
),
" ",
for name, info in loaded["workers"].items():
info = "\n".join(
f"\t{key}: {Colors.BLUE}{value}{Colors.END}"
for key, value in info.items()
)
+ "\n"
)
name = f"{Colors.BOLD}{Colors.SANIC}{name}{Colors.END}"
out("\n" + indent("\n".join([name, info]), " ") + "\n")