Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scale workers #2617

Merged
merged 6 commits into from Dec 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 21 additions & 5 deletions sanic/cli/app.py
Expand Up @@ -98,16 +98,32 @@ def run(self, parse_args=None):
except ValueError as e:
error_logger.exception(f"Failed to run app: {e}")
else:
if self.args.inspect or self.args.inspect_raw or self.args.trigger:
if (
self.args.inspect
or self.args.inspect_raw
or self.args.trigger
or self.args.scale is not None
):
os.environ["SANIC_IGNORE_PRODUCTION_WARNING"] = "true"
else:
for http_version in self.args.http:
app.prepare(**kwargs, version=http_version)

if self.args.inspect or self.args.inspect_raw or self.args.trigger:
action = self.args.trigger or (
"raw" if self.args.inspect_raw else "pretty"
)
if (
self.args.inspect
or self.args.inspect_raw
or self.args.trigger
or self.args.scale is not None
):
if self.args.scale is not None:
if self.args.scale <= 0:
error_logger.error("There must be at least 1 worker")
sys.exit(1)
action = f"scale={self.args.scale}"
else:
action = self.args.trigger or (
"raw" if self.args.inspect_raw else "pretty"
)
inspect(
app.config.INSPECTOR_HOST,
app.config.INSPECTOR_PORT,
Expand Down
6 changes: 6 additions & 0 deletions sanic/cli/arguments.py
Expand Up @@ -115,6 +115,12 @@ def attach(self):
const="shutdown",
help=("Trigger all processes to shutdown"),
)
group.add_argument(
"--scale",
dest="scale",
type=int,
help=("Scale number of workers"),
)


class HTTPVersionGroup(Group):
Expand Down
15 changes: 11 additions & 4 deletions sanic/worker/inspector.py
Expand Up @@ -55,17 +55,20 @@ def __call__(self) -> None:
else:
action = conn.recv(64)
if action == b"reload":
conn.send(b"\n")
self.reload()
elif action == b"shutdown":
conn.send(b"\n")
self.shutdown()
elif action.startswith(b"scale"):
num_workers = int(action.split(b"=", 1)[-1])
logger.info("Scaling to %s", num_workers)
self.scale(num_workers)
else:
data = dumps(self.state_to_json())
conn.send(data.encode())
conn.close()
conn.send(b"\n")
conn.close()
finally:
logger.debug("Inspector closing")
logger.info("Inspector closing")
sock.close()

def stop(self, *_):
Expand All @@ -80,6 +83,10 @@ def reload(self):
message = "__ALL_PROCESSES__:"
self._publisher.send(message)

def scale(self, num_workers: int):
    """Publish a request to change the number of workers.

    Sends the string ``__SCALE__:<num_workers>`` through
    ``self._publisher``.

    :param num_workers: Requested worker count; it is embedded verbatim in
        the published message.
    """
    self._publisher.send(f"__SCALE__:{num_workers}")

def shutdown(self):
    """Publish the ``__TERMINATE__`` message through ``self._publisher``."""
    self._publisher.send("__TERMINATE__")
Expand Down
87 changes: 71 additions & 16 deletions sanic/worker/manager.py
@@ -1,8 +1,10 @@
import os

from itertools import count
from random import choice
from signal import SIGINT, SIGTERM, Signals
from signal import signal as signal_func
from typing import List, Optional
from typing import Dict, List, Optional

from sanic.compat import OS_IS_WINDOWS
from sanic.exceptions import ServerKilled
Expand Down Expand Up @@ -30,33 +32,61 @@ def __init__(
):
self.num_server = number
self.context = context
self.transient: List[Worker] = []
self.durable: List[Worker] = []
self.transient: Dict[str, Worker] = {}
self.durable: Dict[str, Worker] = {}
self.monitor_publisher, self.monitor_subscriber = monitor_pubsub
self.worker_state = worker_state
self.worker_state["Sanic-Main"] = {"pid": self.pid}
self.terminated = False
self._serve = serve
self._server_settings = server_settings
self._server_count = count()

if number == 0:
raise RuntimeError("Cannot serve with no workers")

for i in range(number):
self.manage(
f"{WorkerProcess.SERVER_LABEL}-{i}",
serve,
server_settings,
transient=True,
)
for _ in range(number):
self.create_server()

signal_func(SIGINT, self.shutdown_signal)
signal_func(SIGTERM, self.shutdown_signal)

def manage(self, ident, func, kwargs, transient=False):
def manage(self, ident, func, kwargs, transient=False) -> Worker:
container = self.transient if transient else self.durable
container.append(
Worker(ident, func, kwargs, self.context, self.worker_state)
worker = Worker(ident, func, kwargs, self.context, self.worker_state)
container[worker.ident] = worker
return worker

def create_server(self) -> Worker:
    """Register one more transient server worker and return it.

    The worker ident is ``WorkerProcess.SERVER_LABEL`` followed by ``-`` and
    the next value drawn from ``self._server_count``. The worker is handed
    to ``self.manage`` together with the stored serve callable and server
    settings. No start call is made here.

    :return: Whatever ``self.manage`` returns for the new worker.
    """
    # Draw the number first so the counter advances exactly once per call.
    next_number = next(self._server_count)
    ident = f"{WorkerProcess.SERVER_LABEL}-{next_number}"
    return self.manage(
        ident,
        self._serve,
        self._server_settings,
        transient=True,
    )

def shutdown_server(self, ident: Optional[str] = None) -> None:
    """Terminate one transient server worker and drop it from the registry.

    :param ident: Key of the worker in ``self.transient``. When omitted (or
        otherwise falsy), a worker is chosen at random among the transient
        workers whose ident starts with ``WorkerProcess.SERVER_LABEL``. If
        there is no such worker, an error is logged and nothing else
        happens.
    :raises KeyError: If ``ident`` is given but is not a key of
        ``self.transient``.
    """
    if ident:
        target = self.transient[ident]
    else:
        candidates = [
            candidate
            for candidate in self.transient.values()
            if candidate.ident.startswith(WorkerProcess.SERVER_LABEL)
        ]
        if not candidates:
            error_logger.error(
                "Server shutdown failed because a server was not found."
            )
            return
        # "nosec B311" silences Bandit's warning about the random module.
        target = choice(candidates)  # nosec B311

    for process in target.processes:
        process.terminate()

    # Forget the worker only after all of its processes were told to stop.
    del self.transient[target.ident]

def run(self):
self.start()
self.monitor()
Expand Down Expand Up @@ -94,6 +124,28 @@ def restart(self, process_names: Optional[List[str]] = None, **kwargs):
if not process_names or process.name in process_names:
process.restart(**kwargs)

def scale(self, num_worker: int):
    """Grow or shrink the set of server workers to ``num_worker``.

    When the target exceeds ``self.num_server``, new workers are created
    with ``self.create_server`` and each of their processes is started.
    When it is lower, ``self.shutdown_server`` is called once per surplus
    worker. ``self.num_server`` is then set to the target.

    :param num_worker: Desired number of server workers; must be positive.
    :raises ValueError: If ``num_worker`` is zero or negative.
    """
    if num_worker <= 0:
        raise ValueError("Cannot scale to 0 workers.")

    delta = num_worker - self.num_server
    if not delta:
        logger.info(
            f"No change needed. There are already {num_worker} workers."
        )
        return

    logger.info(f"Scaling from {self.num_server} to {num_worker} workers")
    if delta < 0:
        # Too many workers: stop one per unit of surplus.
        for _ in range(-delta):
            self.shutdown_server()
    else:
        # Too few workers: create each one and start it right away.
        for _ in range(delta):
            created = self.create_server()
            for process in created.processes:
                process.start()
    self.num_server = num_worker

def monitor(self):
self.wait_for_ack()
while True:
Expand All @@ -109,6 +161,9 @@ def monitor(self):
self.shutdown()
break
split_message = message.split(":", 1)
if message.startswith("__SCALE__"):
self.scale(int(split_message[-1]))
continue
processes = split_message[0]
reloaded_files = (
split_message[1] if len(split_message) > 1 else None
Expand Down Expand Up @@ -161,8 +216,8 @@ def wait_for_ack(self): # no cov
self.kill()

@property
def workers(self):
return self.transient + self.durable
def workers(self) -> List[Worker]:
return list(self.transient.values()) + list(self.durable.values())

@property
def processes(self):
Expand All @@ -172,7 +227,7 @@ def processes(self):

@property
def transient_processes(self):
for worker in self.transient:
for worker in self.transient.values():
for process in worker.processes:
yield process

Expand Down
4 changes: 4 additions & 0 deletions sanic/worker/multiplexer.py
Expand Up @@ -33,6 +33,10 @@ def restart(self, name: str = "", all_workers: bool = False):

reload = restart # no cov

def scale(self, num_workers: int):
    """Publish a request to change the number of workers.

    Sends the string ``__SCALE__:<num_workers>`` through
    ``self._monitor_publisher``.

    :param num_workers: Requested worker count; it is embedded verbatim in
        the published message.
    """
    self._monitor_publisher.send(f"__SCALE__:{num_workers}")

def terminate(self, early: bool = False):
    """Publish a termination message through ``self._monitor_publisher``.

    :param early: When truthy, ``__TERMINATE_EARLY__`` is sent; otherwise
        ``__TERMINATE__`` is sent.
    """
    if early:
        message = "__TERMINATE_EARLY__"
    else:
        message = "__TERMINATE__"
    self._monitor_publisher.send(message)
Expand Down
4 changes: 3 additions & 1 deletion sanic/worker/process.py
Expand Up @@ -133,6 +133,8 @@ def pid(self):


class Worker:
WORKER_PREFIX = "Sanic-"

def __init__(
self,
ident: str,
Expand All @@ -152,7 +154,7 @@ def __init__(
def create_process(self) -> WorkerProcess:
process = WorkerProcess(
factory=self.context.Process,
name=f"Sanic-{self.ident}-{len(self.processes)}",
name=f"{self.WORKER_PREFIX}{self.ident}-{len(self.processes)}",
target=self.serve,
kwargs={**self.server_settings},
worker_state=self.worker_state,
Expand Down
25 changes: 21 additions & 4 deletions tests/worker/test_inspector.py
Expand Up @@ -74,7 +74,9 @@ def test_send_inspect_conn_refused(socket: Mock, sys: Mock, caplog):


@patch("sanic.worker.inspector.configure_socket")
@pytest.mark.parametrize("action", (b"reload", b"shutdown", b"foo"))
@pytest.mark.parametrize(
"action", (b"reload", b"shutdown", b"scale=5", b"foo")
)
def test_run_inspector(configure_socket: Mock, action: bytes):
sock = Mock()
conn = Mock()
Expand All @@ -83,6 +85,7 @@ def test_run_inspector(configure_socket: Mock, action: bytes):
inspector = Inspector(Mock(), {}, {}, "localhost", 9999)
inspector.reload = Mock() # type: ignore
inspector.shutdown = Mock() # type: ignore
inspector.scale = Mock() # type: ignore
inspector.state_to_json = Mock(return_value="foo") # type: ignore

def accept():
Expand All @@ -98,20 +101,26 @@ def accept():
)
conn.recv.assert_called_with(64)

conn.send.assert_called_with(b"\n")
if action == b"reload":
conn.send.assert_called_with(b"\n")
inspector.reload.assert_called()
inspector.shutdown.assert_not_called()
inspector.scale.assert_not_called()
inspector.state_to_json.assert_not_called()
elif action == b"shutdown":
conn.send.assert_called_with(b"\n")
inspector.reload.assert_not_called()
inspector.shutdown.assert_called()
inspector.scale.assert_not_called()
inspector.state_to_json.assert_not_called()
elif action.startswith(b"scale"):
inspector.reload.assert_not_called()
inspector.shutdown.assert_not_called()
inspector.scale.assert_called_once_with(5)
inspector.state_to_json.assert_not_called()
else:
conn.send.assert_called_with(b'"foo"')
inspector.reload.assert_not_called()
inspector.shutdown.assert_not_called()
inspector.scale.assert_not_called()
inspector.state_to_json.assert_called()


Expand Down Expand Up @@ -165,3 +174,11 @@ def test_shutdown():
inspector.shutdown()

publisher.send.assert_called_once_with("__TERMINATE__")


def test_scale():
    """``Inspector.scale(3)`` publishes exactly one ``__SCALE__:3`` message."""
    fake_publisher = Mock()
    Inspector(fake_publisher, {}, {}, "", 0).scale(3)

    fake_publisher.send.assert_called_once_with("__SCALE__:3")