Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run lifespans of mounted applications #1988

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions starlette/responses.py
Expand Up @@ -157,6 +157,8 @@ def delete_cookie(
)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
assert scope["type"] == "http"

await send(
{
"type": "http.response.start",
Expand Down
112 changes: 105 additions & 7 deletions starlette/routing.py
Expand Up @@ -6,9 +6,11 @@
import types
import typing
import warnings
from contextlib import asynccontextmanager
from contextlib import AsyncExitStack, asynccontextmanager
from enum import Enum

from anyio import Event, create_task_group

from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.convertors import CONVERTOR_TYPES, Convertor
Expand All @@ -17,7 +19,7 @@
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse, RedirectResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketClose


Expand Down Expand Up @@ -572,6 +574,91 @@ def __call__(self: _T, app: object) -> _T:
return self


class LifespanException(Exception):
def __init__(self, message: str):
self.message = message


@asynccontextmanager
async def _app_lifespan(scope: Scope, app: ASGIApp) -> typing.AsyncIterator[None]:
startup_sent = Event()
startup_done = Event()

shutdown_init = Event()
shutdown_sent = Event()
shutdown_done = Event()

lifespan_supported = False
exception = None

async def receive() -> Message:
nonlocal lifespan_supported

lifespan_supported = True

if not startup_sent.is_set():
startup_sent.set()
return {"type": "lifespan.startup"}

elif startup_done.is_set() and not shutdown_sent.is_set():
await shutdown_init.wait()
shutdown_sent.set()
return {"type": "lifespan.shutdown"}

else:
raise RuntimeError("unexpected receive")

async def send(message: Message) -> None:
nonlocal exception, lifespan_supported

lifespan_supported = True

if startup_sent.is_set() and not startup_done.is_set():
if message["type"] == "lifespan.startup.complete":
pass
elif message["type"] == "lifespan.startup.failed":
exception = message.get("message", "")
else:
raise ValueError(f"unexpected type: {message['type']}")
startup_done.set()

elif shutdown_sent.is_set() and not shutdown_done.is_set():
if message["type"] == "lifespan.shutdown.complete":
pass
elif message["type"] == "lifespan.shutdown.failed":
exception = message.get("message", "")
else:
raise ValueError(f"unexpected type: {message['type']}")
shutdown_done.set()

else:
raise RuntimeError("unexpected send")

# This wrapper is needed because TaskGroup.start_soon does not like that
# App returns Awaitable instead of Coroutine
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup that is the case. If we use this somewhere else in the library (I don't recall) it might be worth doing something more generic since the issue relates to awaitables/coroutines and not specifically ASGI apps.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah makes sense, for now this wrapper turned out to be a good thing because I added some code at the end to check if the app did not return in a situation where that would block the context manager.

async def coro_app(scope: Scope, receive: Receive, send: Send) -> None:
await app(scope, receive, send)

try:
async with create_task_group() as tg:
tg.start_soon(coro_app, {**scope, "app": app}, receive, send)
Comment on lines +643 to +644
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After all of the issues caused by using tasks in BaseHTTPMiddelware I'd like to avoid using them if possible. And I think that in this case it is possible: https://github.com/adriangb/asgi-routing/blob/main/asgi_routing/_lifespan_dispatcher.py#L11-L101. It's not really more LOC either. But maybe there's a bug in there that I didn't catch.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think the main thing is that you need a task to be able to wrap a lifespan in an async context manager, which does feel like a nice abstraction to have. But I definitely understand your reservations.

await startup_done.wait()
if exception:
raise LifespanException(exception)
try:
yield
finally:
shutdown_init.set()
await shutdown_done.wait()
if exception:
raise LifespanException(exception)
except Exception:
if lifespan_supported:
raise
else:
yield


class Router:
def __init__(
self,
Expand Down Expand Up @@ -659,6 +746,14 @@ async def shutdown(self) -> None:
else:
handler()

@asynccontextmanager
async def mount_lifespans(self, scope: Scope) -> typing.AsyncIterator[None]:
async with AsyncExitStack() as stack:
for route in self.routes:
if isinstance(route, Mount):
await stack.enter_async_context(_app_lifespan(scope, route.app))
yield

async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
"""
Handle ASGI lifespan messages, which allows us to manage application
Expand All @@ -668,16 +763,19 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
app = scope.get("app")
await receive()
try:
async with self.lifespan_context(app):
async with self.lifespan_context(app), self.mount_lifespans(scope):
await send({"type": "lifespan.startup.complete"})
started = True
await receive()
except BaseException:
exc_text = traceback.format_exc()
except BaseException as e:
if isinstance(e, LifespanException):
message = e.message
else:
message = traceback.format_exc()
if started:
await send({"type": "lifespan.shutdown.failed", "message": exc_text})
await send({"type": "lifespan.shutdown.failed", "message": message})
else:
await send({"type": "lifespan.startup.failed", "message": exc_text})
await send({"type": "lifespan.startup.failed", "message": message})
raise
else:
await send({"type": "lifespan.shutdown.complete"})
Expand Down