Skip to content

Commit

Permalink
Scale workers (#2617)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahopkins committed Dec 13, 2022
1 parent 13e9ab7 commit db39e12
Show file tree
Hide file tree
Showing 9 changed files with 237 additions and 102 deletions.
26 changes: 21 additions & 5 deletions sanic/cli/app.py
Expand Up @@ -98,16 +98,32 @@ def run(self, parse_args=None):
except ValueError as e:
error_logger.exception(f"Failed to run app: {e}")
else:
if self.args.inspect or self.args.inspect_raw or self.args.trigger:
if (
self.args.inspect
or self.args.inspect_raw
or self.args.trigger
or self.args.scale is not None
):
os.environ["SANIC_IGNORE_PRODUCTION_WARNING"] = "true"
else:
for http_version in self.args.http:
app.prepare(**kwargs, version=http_version)

if self.args.inspect or self.args.inspect_raw or self.args.trigger:
action = self.args.trigger or (
"raw" if self.args.inspect_raw else "pretty"
)
if (
self.args.inspect
or self.args.inspect_raw
or self.args.trigger
or self.args.scale is not None
):
if self.args.scale is not None:
if self.args.scale <= 0:
error_logger.error("There must be at least 1 worker")
sys.exit(1)
action = f"scale={self.args.scale}"
else:
action = self.args.trigger or (
"raw" if self.args.inspect_raw else "pretty"
)
inspect(
app.config.INSPECTOR_HOST,
app.config.INSPECTOR_PORT,
Expand Down
6 changes: 6 additions & 0 deletions sanic/cli/arguments.py
Expand Up @@ -115,6 +115,12 @@ def attach(self):
const="shutdown",
help=("Trigger all processes to shutdown"),
)
group.add_argument(
"--scale",
dest="scale",
type=int,
help=("Scale number of workers"),
)


class HTTPVersionGroup(Group):
Expand Down
15 changes: 11 additions & 4 deletions sanic/worker/inspector.py
Expand Up @@ -55,17 +55,20 @@ def __call__(self) -> None:
else:
action = conn.recv(64)
if action == b"reload":
conn.send(b"\n")
self.reload()
elif action == b"shutdown":
conn.send(b"\n")
self.shutdown()
elif action.startswith(b"scale"):
num_workers = int(action.split(b"=", 1)[-1])
logger.info("Scaling to %s", num_workers)
self.scale(num_workers)
else:
data = dumps(self.state_to_json())
conn.send(data.encode())
conn.close()
conn.send(b"\n")
conn.close()
finally:
logger.debug("Inspector closing")
logger.info("Inspector closing")
sock.close()

def stop(self, *_):
Expand All @@ -80,6 +83,10 @@ def reload(self):
message = "__ALL_PROCESSES__:"
self._publisher.send(message)

def scale(self, num_workers: int):
    """Send a ``__SCALE__:<num_workers>`` message through the publisher."""
    self._publisher.send(f"__SCALE__:{num_workers}")

def shutdown(self):
message = "__TERMINATE__"
self._publisher.send(message)
Expand Down
87 changes: 71 additions & 16 deletions sanic/worker/manager.py
@@ -1,8 +1,10 @@
import os

from itertools import count
from random import choice
from signal import SIGINT, SIGTERM, Signals
from signal import signal as signal_func
from typing import List, Optional
from typing import Dict, List, Optional

from sanic.compat import OS_IS_WINDOWS
from sanic.exceptions import ServerKilled
Expand Down Expand Up @@ -30,33 +32,61 @@ def __init__(
):
self.num_server = number
self.context = context
self.transient: List[Worker] = []
self.durable: List[Worker] = []
self.transient: Dict[str, Worker] = {}
self.durable: Dict[str, Worker] = {}
self.monitor_publisher, self.monitor_subscriber = monitor_pubsub
self.worker_state = worker_state
self.worker_state["Sanic-Main"] = {"pid": self.pid}
self.terminated = False
self._serve = serve
self._server_settings = server_settings
self._server_count = count()

if number == 0:
raise RuntimeError("Cannot serve with no workers")

for i in range(number):
self.manage(
f"{WorkerProcess.SERVER_LABEL}-{i}",
serve,
server_settings,
transient=True,
)
for _ in range(number):
self.create_server()

signal_func(SIGINT, self.shutdown_signal)
signal_func(SIGTERM, self.shutdown_signal)

def manage(self, ident, func, kwargs, transient=False):
def manage(self, ident, func, kwargs, transient=False) -> Worker:
container = self.transient if transient else self.durable
container.append(
Worker(ident, func, kwargs, self.context, self.worker_state)
worker = Worker(ident, func, kwargs, self.context, self.worker_state)
container[worker.ident] = worker
return worker

def create_server(self) -> Worker:
    """Create the next server worker via ``manage`` and return it.

    The ident is ``WorkerProcess.SERVER_LABEL`` followed by the next value
    drawn from ``self._server_count``; the worker is managed as transient
    and uses the stored ``_serve`` callable and ``_server_settings``.
    """
    ident = f"{WorkerProcess.SERVER_LABEL}-{next(self._server_count)}"
    return self.manage(
        ident,
        self._serve,
        self._server_settings,
        transient=True,
    )

def shutdown_server(self, ident: Optional[str] = None) -> None:
    """Terminate one server worker and remove it from ``self.transient``.

    When ``ident`` is given, that entry of ``self.transient`` is used (a
    missing key raises ``KeyError``). When it is falsy, a worker is picked
    at random among the transient workers whose ident starts with
    ``WorkerProcess.SERVER_LABEL``; if there is none, an error is logged
    and nothing else happens.
    """
    if ident:
        target = self.transient[ident]
    else:
        candidates = [
            candidate
            for candidate in self.transient.values()
            if candidate.ident.startswith(WorkerProcess.SERVER_LABEL)
        ]
        if not candidates:
            error_logger.error(
                "Server shutdown failed because a server was not found."
            )
            return
        # Random pick is not security sensitive, hence the bandit waiver.
        target = choice(candidates)  # nosec B311

    for process in target.processes:
        process.terminate()

    del self.transient[target.ident]

def run(self):
self.start()
self.monitor()
Expand Down Expand Up @@ -94,6 +124,28 @@ def restart(self, process_names: Optional[List[str]] = None, **kwargs):
if not process_names or process.name in process_names:
process.restart(**kwargs)

def scale(self, num_worker: int):
    """Bring the number of server workers to ``num_worker``.

    Raises ``ValueError`` when ``num_worker`` is zero or negative. If the
    requested count equals ``self.num_server`` this only logs and returns.
    Otherwise servers are created (and their processes started) or shut
    down one at a time until the difference is covered, and
    ``self.num_server`` is updated to the new count.
    """
    if num_worker <= 0:
        raise ValueError("Cannot scale to 0 workers.")

    delta = num_worker - self.num_server
    if not delta:
        logger.info(
            f"No change needed. There are already {num_worker} workers."
        )
        return

    logger.info(f"Scaling from {self.num_server} to {num_worker} workers")
    if delta < 0:
        # Scaling down: stop one server per missing slot.
        for _ in range(-delta):
            self.shutdown_server()
    else:
        # Scaling up: create each server and start its processes right away.
        for _ in range(delta):
            for process in self.create_server().processes:
                process.start()
    self.num_server = num_worker

def monitor(self):
self.wait_for_ack()
while True:
Expand All @@ -109,6 +161,9 @@ def monitor(self):
self.shutdown()
break
split_message = message.split(":", 1)
if message.startswith("__SCALE__"):
self.scale(int(split_message[-1]))
continue
processes = split_message[0]
reloaded_files = (
split_message[1] if len(split_message) > 1 else None
Expand Down Expand Up @@ -161,8 +216,8 @@ def wait_for_ack(self): # no cov
self.kill()

@property
def workers(self):
return self.transient + self.durable
def workers(self) -> List[Worker]:
return list(self.transient.values()) + list(self.durable.values())

@property
def processes(self):
Expand All @@ -172,7 +227,7 @@ def processes(self):

@property
def transient_processes(self):
for worker in self.transient:
for worker in self.transient.values():
for process in worker.processes:
yield process

Expand Down
4 changes: 4 additions & 0 deletions sanic/worker/multiplexer.py
Expand Up @@ -33,6 +33,10 @@ def restart(self, name: str = "", all_workers: bool = False):

reload = restart # no cov

def scale(self, num_workers: int):
    """Send a ``__SCALE__:<num_workers>`` message on the monitor publisher."""
    self._monitor_publisher.send(f"__SCALE__:{num_workers}")

def terminate(self, early: bool = False):
message = "__TERMINATE_EARLY__" if early else "__TERMINATE__"
self._monitor_publisher.send(message)
Expand Down
4 changes: 3 additions & 1 deletion sanic/worker/process.py
Expand Up @@ -133,6 +133,8 @@ def pid(self):


class Worker:
WORKER_PREFIX = "Sanic-"

def __init__(
self,
ident: str,
Expand All @@ -152,7 +154,7 @@ def __init__(
def create_process(self) -> WorkerProcess:
process = WorkerProcess(
factory=self.context.Process,
name=f"Sanic-{self.ident}-{len(self.processes)}",
name=f"{self.WORKER_PREFIX}{self.ident}-{len(self.processes)}",
target=self.serve,
kwargs={**self.server_settings},
worker_state=self.worker_state,
Expand Down
25 changes: 21 additions & 4 deletions tests/worker/test_inspector.py
Expand Up @@ -74,7 +74,9 @@ def test_send_inspect_conn_refused(socket: Mock, sys: Mock, caplog):


@patch("sanic.worker.inspector.configure_socket")
@pytest.mark.parametrize("action", (b"reload", b"shutdown", b"foo"))
@pytest.mark.parametrize(
"action", (b"reload", b"shutdown", b"scale=5", b"foo")
)
def test_run_inspector(configure_socket: Mock, action: bytes):
sock = Mock()
conn = Mock()
Expand All @@ -83,6 +85,7 @@ def test_run_inspector(configure_socket: Mock, action: bytes):
inspector = Inspector(Mock(), {}, {}, "localhost", 9999)
inspector.reload = Mock() # type: ignore
inspector.shutdown = Mock() # type: ignore
inspector.scale = Mock() # type: ignore
inspector.state_to_json = Mock(return_value="foo") # type: ignore

def accept():
Expand All @@ -98,20 +101,26 @@ def accept():
)
conn.recv.assert_called_with(64)

conn.send.assert_called_with(b"\n")
if action == b"reload":
conn.send.assert_called_with(b"\n")
inspector.reload.assert_called()
inspector.shutdown.assert_not_called()
inspector.scale.assert_not_called()
inspector.state_to_json.assert_not_called()
elif action == b"shutdown":
conn.send.assert_called_with(b"\n")
inspector.reload.assert_not_called()
inspector.shutdown.assert_called()
inspector.scale.assert_not_called()
inspector.state_to_json.assert_not_called()
elif action.startswith(b"scale"):
inspector.reload.assert_not_called()
inspector.shutdown.assert_not_called()
inspector.scale.assert_called_once_with(5)
inspector.state_to_json.assert_not_called()
else:
conn.send.assert_called_with(b'"foo"')
inspector.reload.assert_not_called()
inspector.shutdown.assert_not_called()
inspector.scale.assert_not_called()
inspector.state_to_json.assert_called()


Expand Down Expand Up @@ -165,3 +174,11 @@ def test_shutdown():
inspector.shutdown()

publisher.send.assert_called_once_with("__TERMINATE__")


def test_scale():
    """``Inspector.scale(3)`` sends exactly one ``__SCALE__:3`` message."""
    mock_publisher = Mock()
    Inspector(mock_publisher, {}, {}, "", 0).scale(3)

    mock_publisher.send.assert_called_once_with("__SCALE__:3")

0 comments on commit db39e12

Please sign in to comment.