Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add named tasks #2304

Merged
merged 13 commits into from Dec 20, 2021
220 changes: 176 additions & 44 deletions sanic/app.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import logging
import logging.config
import os
Expand All @@ -11,6 +12,7 @@
AbstractEventLoop,
CancelledError,
Protocol,
Task,
ensure_future,
get_event_loop,
wait_for,
Expand Down Expand Up @@ -125,6 +127,7 @@ class Sanic(BaseSanic, metaclass=TouchUpMeta):
"_future_signals",
"_future_statics",
"_state",
"_task_registry",
"_test_client",
"_test_manager",
"asgi",
Expand Down Expand Up @@ -188,17 +191,22 @@ def __init__(
"load_env or env_prefix"
)

self.config: Config = config or Config(
load_env=load_env,
env_prefix=env_prefix,
)

self._asgi_client: Any = None
self._test_client: Any = None
self._test_manager: Any = None
self._blueprint_order: List[Blueprint] = []
self._delayed_tasks: List[str] = []
self._future_registry: FutureRegistry = FutureRegistry()
self._state: ApplicationState = ApplicationState(app=self)
self._task_registry: Dict[str, Task] = {}
self._test_client: Any = None
self._test_manager: Any = None
self.asgi = False
self.auto_reload = False
self.blueprints: Dict[str, Blueprint] = {}
self.config: Config = config or Config(
load_env=load_env, env_prefix=env_prefix
)
self.configure_logging: bool = configure_logging
self.ctx: Any = ctx or SimpleNamespace()
self.debug = False
Expand Down Expand Up @@ -250,32 +258,6 @@ def loop(self):
# Registration
# -------------------------------------------------------------------- #

def add_task(
self,
task: Union[Future[Any], Coroutine[Any, Any, Any], Awaitable[Any]],
) -> None:
"""
Schedule a task to run later, after the loop has started.
Different from asyncio.ensure_future in that it does not
also return a future, and the actual ensure_future call
is delayed until before server start.

`See user guide re: background tasks
<https://sanicframework.org/guide/basics/tasks.html#background-tasks>`__

:param task: future, couroutine or awaitable
"""
try:
loop = self.loop # Will raise SanicError if loop is not started
self._loop_add_task(task, self, loop)
except SanicException:
task_name = f"sanic.delayed_task.{hash(task)}"
if not self._delayed_tasks:
self.after_server_start(partial(self.dispatch_delayed_tasks))

self.signal(task_name)(partial(self.run_delayed_task, task=task))
self._delayed_tasks.append(task_name)

def register_listener(
self, listener: ListenerType[SanicVar], event: str
) -> ListenerType[SanicVar]:
Expand Down Expand Up @@ -1183,6 +1165,7 @@ def stop(self):
This kills the Sanic
"""
if not self.is_stopping:
self.shutdown_tasks(timeout=0)
self.is_stopping = True
get_event_loop().stop()

Expand Down Expand Up @@ -1456,7 +1439,29 @@ def _build_endpoint_name(self, *parts):
return ".".join(parts)

@classmethod
def _prep_task(cls, task, app, loop):
def _cancel_websocket_tasks(cls, app, loop):
for task in app.websocket_tasks:
task.cancel()

@staticmethod
async def _listener(
app: Sanic, loop: AbstractEventLoop, listener: ListenerType
):
maybe_coro = listener(app, loop)
if maybe_coro and isawaitable(maybe_coro):
await maybe_coro

# -------------------------------------------------------------------- #
# Task management
# -------------------------------------------------------------------- #

@classmethod
def _prep_task(
cls,
task,
app,
loop,
):
if callable(task):
try:
task = task(app)
Expand All @@ -1466,14 +1471,22 @@ def _prep_task(cls, task, app, loop):
return task

@classmethod
def _loop_add_task(cls, task, app, loop):
def _loop_add_task(
cls,
task,
app,
loop,
*,
name: Optional[str] = None,
register: bool = True,
) -> Task:
prepped = cls._prep_task(task, app, loop)
loop.create_task(prepped)
task = loop.create_task(prepped, name=name)

@classmethod
def _cancel_websocket_tasks(cls, app, loop):
for task in app.websocket_tasks:
task.cancel()
if name and register:
app._task_registry[name] = task

return task

@staticmethod
async def dispatch_delayed_tasks(app, loop):
Expand All @@ -1486,13 +1499,132 @@ async def run_delayed_task(app, loop, task):
prepped = app._prep_task(task, app, loop)
await prepped

@staticmethod
async def _listener(
app: Sanic, loop: AbstractEventLoop, listener: ListenerType
def add_task(
self,
task: Union[Future[Any], Coroutine[Any, Any, Any], Awaitable[Any]],
*,
name: Optional[str] = None,
register: bool = True,
) -> Optional[Task]:
"""
Schedule a task to run later, after the loop has started.
Different from asyncio.ensure_future in that it does not
also return a future, and the actual ensure_future call
is delayed until before server start.

`See user guide re: background tasks
<https://sanicframework.org/guide/basics/tasks.html#background-tasks>`__

:param task: future, couroutine or awaitable
"""
if name and sys.version_info == (3, 7):
name = None
error_logger.warning(
"Cannot set a name for a task when using Python 3.7. Your "
"task will be created without a name."
)
try:
loop = self.loop # Will raise SanicError if loop is not started
return self._loop_add_task(
task, self, loop, name=name, register=register
)
except SanicException:
task_name = f"sanic.delayed_task.{hash(task)}"
if not self._delayed_tasks:
self.after_server_start(partial(self.dispatch_delayed_tasks))

if name:
raise RuntimeError(
"Cannot name task outside of a running application"
)

self.signal(task_name)(partial(self.run_delayed_task, task=task))
self._delayed_tasks.append(task_name)
return None

def get_task(
self, name: str, *, raise_exception: bool = True
) -> Optional[Task]:
if sys.version_info == (3, 7):
raise RuntimeError(
"This feature is only supported on using Python 3.8+."
)
try:
return self._task_registry[name]
except KeyError:
if raise_exception:
raise SanicException(
f'Registered task named "{name}" not found.'
)
return None

async def cancel_task(
self,
name: str,
msg: Optional[str] = None,
*,
raise_exception: bool = True,
) -> None:
if sys.version_info == (3, 7):
raise RuntimeError(
"This feature is only supported on using Python 3.8+."
)
task = self.get_task(name, raise_exception=raise_exception)
if task and not task.cancelled():
args: Tuple[str, ...] = ()
if msg:
if sys.version_info >= (3, 9):
args = (msg,)
else:
raise RuntimeError(
"Cancelling a task with a message is only supported "
"on Python 3.9+."
)
task.cancel(*args)
try:
await task
except CancelledError:
...

def purge_tasks(self):
if sys.version_info == (3, 7):
raise RuntimeError(
"This feature is only supported on using Python 3.8+."
)
for task in self.tasks:
if task.done() or task.cancelled():
name = task.get_name()
self._task_registry[name] = None

self._task_registry = {
k: v for k, v in self._task_registry.items() if v is not None
}

def shutdown_tasks(
self, timeout: Optional[float] = None, increment: float = 0.1
):
maybe_coro = listener(app, loop)
if maybe_coro and isawaitable(maybe_coro):
await maybe_coro
if sys.version_info == (3, 7):
raise RuntimeError(
"This feature is only supported on using Python 3.8+."
)
for task in self.tasks:
task.cancel()

if timeout is None:
timeout = self.config.GRACEFUL_SHUTDOWN_TIMEOUT

while len(self._task_registry) and timeout:
self.loop.run_until_complete(asyncio.sleep(increment))
self.purge_tasks()
timeout -= increment

@property
def tasks(self):
if sys.version_info == (3, 7):
raise RuntimeError(
"This feature is only supported on using Python 3.8+."
)
return iter(self._task_registry.values())

# -------------------------------------------------------------------- #
# ASGI
Expand Down
5 changes: 5 additions & 0 deletions sanic/server/runners.py
@@ -1,5 +1,7 @@
from __future__ import annotations

import sys

from ssl import SSLContext
from typing import TYPE_CHECKING, Dict, Optional, Type, Union

Expand Down Expand Up @@ -174,6 +176,9 @@ def serve(
loop.run_until_complete(asyncio.sleep(0.1))
start_shutdown = start_shutdown + 0.1

if sys.version_info > (3, 7):
app.shutdown_tasks(graceful - start_shutdown)

# Force close non-idle connection after waiting for
# graceful_shutdown_timeout
for conn in connections:
Expand Down
42 changes: 42 additions & 0 deletions tests/test_create_task.py
@@ -1,7 +1,11 @@
import asyncio
import sys

from threading import Event

import pytest

from sanic.exceptions import SanicException
from sanic.response import text


Expand Down Expand Up @@ -48,3 +52,41 @@ async def coro(app):

_, response = app.test_client.get("/")
assert response.text == "test_create_task_with_app_arg"


@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7")
def test_create_named_task(app):
async def dummy():
...

@app.before_server_start
async def setup(app, _):
app.add_task(dummy, name="dummy_task")

@app.after_server_start
async def stop(app, _):
task = app.get_task("dummy_task")

assert app._task_registry
assert isinstance(task, asyncio.Task)

assert task.get_name() == "dummy_task"

app.stop()

app.run()


@pytest.mark.skipif(sys.version_info < (3, 8), reason="Not supported in 3.7")
def test_create_named_task_fails_outside_app(app):
async def dummy():
...

message = "Cannot name task outside of a running application"
with pytest.raises(RuntimeError, match=message):
app.add_task(dummy, name="dummy_task")
assert not app._task_registry

message = 'Registered task named "dummy_task" not found.'
with pytest.raises(SanicException, match=message):
app.get_task("dummy_task")