Skip to content

Commit

Permalink
Implement restart ordering (#2632)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahopkins committed Dec 18, 2022
1 parent 518152d commit f7040cc
Show file tree
Hide file tree
Showing 19 changed files with 374 additions and 70 deletions.
1 change: 1 addition & 0 deletions sanic/app.py
Expand Up @@ -1534,6 +1534,7 @@ async def _startup(self):

self.state.is_started = True

def ack(self):
if hasattr(self, "multiplexer"):
self.multiplexer.ack()

Expand Down
2 changes: 1 addition & 1 deletion sanic/cli/base.py
Expand Up @@ -20,7 +20,7 @@ def add_usage(self, usage, actions, groups, prefix=None):
if not usage:
usage = SUPPRESS
# Add one linebreak, but not two
self.add_text("\x1b[1A'")
self.add_text("\x1b[1A")
super().add_usage(usage, actions, groups, prefix)


Expand Down
42 changes: 35 additions & 7 deletions sanic/cli/inspector.py
Expand Up @@ -5,10 +5,26 @@


def _add_shared(parser: ArgumentParser) -> None:
parser.add_argument("--host", "-H", default="localhost")
parser.add_argument("--port", "-p", default=6457, type=int)
parser.add_argument("--secure", "-s", action="store_true")
parser.add_argument("--api-key", "-k")
parser.add_argument(
"--host",
"-H",
default="localhost",
help="Inspector host address [default 127.0.0.1]",
)
parser.add_argument(
"--port",
"-p",
default=6457,
type=int,
help="Inspector port [default 6457]",
)
parser.add_argument(
"--secure",
"-s",
action="store_true",
help="Whether to access the Inspector via TLS encryption",
)
parser.add_argument("--api-key", "-k", help="Inspector authentication key")
parser.add_argument(
"--raw",
action="store_true",
Expand All @@ -32,17 +48,25 @@ def make_inspector_parser(parser: ArgumentParser) -> None:
dest="action",
description=(
"Run one of the below subcommands. If you have created a custom "
"Inspector instance, then you can run custom commands.\nSee ___ "
"Inspector instance, then you can run custom commands. See ___ "
"for more details."
),
title="Required\n========\n Subcommands",
parser_class=InspectorSubParser,
)
subparsers.add_parser(
reloader = subparsers.add_parser(
"reload",
help="Trigger a reload of the server workers",
formatter_class=SanicHelpFormatter,
)
reloader.add_argument(
"--zero-downtime",
action="store_true",
help=(
"Whether to wait for the new process to be online before "
"terminating the old"
),
)
subparsers.add_parser(
"shutdown",
help="Shutdown the application and all processes",
Expand All @@ -53,7 +77,11 @@ def make_inspector_parser(parser: ArgumentParser) -> None:
help="Scale the number of workers",
formatter_class=SanicHelpFormatter,
)
scale.add_argument("replicas", type=int)
scale.add_argument(
"replicas",
type=int,
help="Number of workers requested",
)

custom = subparsers.add_parser(
"<custom>",
Expand Down
25 changes: 25 additions & 0 deletions sanic/compat.py
Expand Up @@ -4,6 +4,7 @@
import sys

from contextlib import contextmanager
from enum import Enum
from typing import Awaitable, Union

from multidict import CIMultiDict # type: ignore
Expand All @@ -30,6 +31,30 @@
except ImportError:
pass

# Python 3.11 changed the way Enum formatting works for mixed-in types.
if sys.version_info < (3, 11, 0):

class StrEnum(str, Enum):
pass

else:
from enum import StrEnum # type: ignore # noqa


class UpperStrEnum(StrEnum):
def _generate_next_value_(name, start, count, last_values):
return name.upper()

def __eq__(self, value: object) -> bool:
value = str(value).upper()
return super().__eq__(value)

def __hash__(self) -> int:
return hash(self.value)

def __str__(self) -> str:
return self.value


@contextmanager
def use_context(method: StartMethod):
Expand Down
20 changes: 4 additions & 16 deletions sanic/constants.py
@@ -1,19 +1,9 @@
from enum import Enum, auto
from enum import auto

from sanic.compat import UpperStrEnum

class HTTPMethod(str, Enum):
def _generate_next_value_(name, start, count, last_values):
return name.upper()

def __eq__(self, value: object) -> bool:
value = str(value).upper()
return super().__eq__(value)

def __hash__(self) -> int:
return hash(self.value)

def __str__(self) -> str:
return self.value
class HTTPMethod(UpperStrEnum):

GET = auto()
POST = auto()
Expand All @@ -24,9 +14,7 @@ def __str__(self) -> str:
DELETE = auto()


class LocalCertCreator(str, Enum):
def _generate_next_value_(name, start, count, last_values):
return name.upper()
class LocalCertCreator(UpperStrEnum):

AUTO = auto()
TRUSTME = auto()
Expand Down
4 changes: 2 additions & 2 deletions sanic/mixins/signals.py
Expand Up @@ -20,7 +20,7 @@ def signal(
event: Union[str, Enum],
*,
apply: bool = True,
condition: Dict[str, Any] = None,
condition: Optional[Dict[str, Any]] = None,
exclusive: bool = True,
) -> Callable[[SignalHandler], SignalHandler]:
"""
Expand Down Expand Up @@ -64,7 +64,7 @@ def add_signal(
self,
handler: Optional[Callable[..., Any]],
event: str,
condition: Dict[str, Any] = None,
condition: Optional[Dict[str, Any]] = None,
exclusive: bool = True,
):
if not handler:
Expand Down
2 changes: 1 addition & 1 deletion sanic/mixins/startup.py
Expand Up @@ -851,7 +851,7 @@ def serve(
primary.config.INSPECTOR_TLS_KEY,
primary.config.INSPECTOR_TLS_CERT,
)
manager.manage("Inspector", inspector, {}, transient=True)
manager.manage("Inspector", inspector, {}, transient=False)

primary._inspector = inspector
primary._manager = manager
Expand Down
2 changes: 2 additions & 0 deletions sanic/server/runners.py
Expand Up @@ -229,6 +229,7 @@ def _serve_http_1(

loop.run_until_complete(app._startup())
loop.run_until_complete(app._server_event("init", "before"))
app.ack()

try:
http_server = loop.run_until_complete(server_coroutine)
Expand Down Expand Up @@ -306,6 +307,7 @@ def _serve_http_3(
server = AsyncioServer(app, loop, coro, [])
loop.run_until_complete(server.startup())
loop.run_until_complete(server.before_start())
app.ack()
loop.run_until_complete(server)
_setup_system_signals(app, run_multiple, register_sys_signals, loop)
loop.run_until_complete(server.after_start())
Expand Down
18 changes: 18 additions & 0 deletions sanic/worker/constants.py
@@ -0,0 +1,18 @@
from enum import IntEnum, auto

from sanic.compat import UpperStrEnum


class RestartOrder(UpperStrEnum):
SHUTDOWN_FIRST = auto()
STARTUP_FIRST = auto()


class ProcessState(IntEnum):
IDLE = auto()
RESTARTING = auto()
STARTING = auto()
STARTED = auto()
ACKED = auto()
JOINED = auto()
TERMINATED = auto()
4 changes: 3 additions & 1 deletion sanic/worker/inspector.py
Expand Up @@ -101,8 +101,10 @@ def _make_safe(obj: Dict[str, Any]) -> Dict[str, Any]:
obj[key] = value.isoformat()
return obj

def reload(self) -> None:
def reload(self, zero_downtime: bool = False) -> None:
message = "__ALL_PROCESSES__:"
if zero_downtime:
message += ":STARTUP_FIRST"
self._publisher.send(message)

def scale(self, replicas) -> str:
Expand Down
35 changes: 30 additions & 5 deletions sanic/worker/manager.py
Expand Up @@ -10,6 +10,7 @@
from sanic.compat import OS_IS_WINDOWS
from sanic.exceptions import ServerKilled
from sanic.log import error_logger, logger
from sanic.worker.constants import RestartOrder
from sanic.worker.process import ProcessState, Worker, WorkerProcess


Expand All @@ -20,7 +21,8 @@


class WorkerManager:
THRESHOLD = 300 # == 30 seconds
THRESHOLD = WorkerProcess.THRESHOLD
MAIN_IDENT = "Sanic-Main"

def __init__(
self,
Expand All @@ -37,7 +39,7 @@ def __init__(
self.durable: Dict[str, Worker] = {}
self.monitor_publisher, self.monitor_subscriber = monitor_pubsub
self.worker_state = worker_state
self.worker_state["Sanic-Main"] = {"pid": self.pid}
self.worker_state[self.MAIN_IDENT] = {"pid": self.pid}
self.terminated = False
self._serve = serve
self._server_settings = server_settings
Expand Down Expand Up @@ -119,10 +121,15 @@ def terminate(self):
process.terminate()
self.terminated = True

def restart(self, process_names: Optional[List[str]] = None, **kwargs):
def restart(
self,
process_names: Optional[List[str]] = None,
restart_order=RestartOrder.SHUTDOWN_FIRST,
**kwargs,
):
for process in self.transient_processes:
if not process_names or process.name in process_names:
process.restart(**kwargs)
process.restart(restart_order=restart_order, **kwargs)

def scale(self, num_worker: int):
if num_worker <= 0:
Expand Down Expand Up @@ -160,7 +167,12 @@ def monitor(self):
elif message == "__TERMINATE__":
self.shutdown()
break
split_message = message.split(":", 1)
logger.debug(
"Incoming monitor message: %s",
message,
extra={"verbosity": 1},
)
split_message = message.split(":", 2)
if message.startswith("__SCALE__"):
self.scale(int(split_message[-1]))
continue
Expand All @@ -173,10 +185,17 @@ def monitor(self):
]
if "__ALL_PROCESSES__" in process_names:
process_names = None
order = (
RestartOrder.STARTUP_FIRST
if "STARTUP_FIRST" in split_message
else RestartOrder.SHUTDOWN_FIRST
)
self.restart(
process_names=process_names,
reloaded_files=reloaded_files,
restart_order=order,
)
self._sync_states()
except InterruptedError:
if not OS_IS_WINDOWS:
raise
Expand Down Expand Up @@ -263,3 +282,9 @@ def _all_workers_ack(self):
if worker_state.get("server")
]
return all(acked) and len(acked) == self.num_server

def _sync_states(self):
for process in self.processes:
state = self.worker_state[process.name].get("state")
if state and process.state.name != state:
process.set_state(ProcessState[state], True)
18 changes: 17 additions & 1 deletion sanic/worker/multiplexer.py
Expand Up @@ -2,6 +2,7 @@
from os import environ, getpid
from typing import Any, Dict

from sanic.log import Colors, logger
from sanic.worker.process import ProcessState
from sanic.worker.state import WorkerState

Expand All @@ -16,19 +17,34 @@ def __init__(
self._state = WorkerState(worker_state, self.name)

def ack(self):
logger.debug(
f"{Colors.BLUE}Process ack: {Colors.BOLD}{Colors.SANIC}"
f"%s {Colors.BLUE}[%s]{Colors.END}",
self.name,
self.pid,
)
self._state._state[self.name] = {
**self._state._state[self.name],
"state": ProcessState.ACKED.name,
}

def restart(self, name: str = "", all_workers: bool = False):
def restart(
self,
name: str = "",
all_workers: bool = False,
zero_downtime: bool = False,
):
if name and all_workers:
raise ValueError(
"Ambiguous restart with both a named process and"
" all_workers=True"
)
if not name:
name = "__ALL_PROCESSES__:" if all_workers else self.name
if not name.endswith(":"):
name += ":"
if zero_downtime:
name += ":STARTUP_FIRST"
self._monitor_publisher.send(name)

reload = restart # no cov
Expand Down

0 comments on commit f7040cc

Please sign in to comment.