Skip to content

Commit

Permalink
Run lifespans of mounted applications
Browse files Browse the repository at this point in the history
  • Loading branch information
daanvdk committed Dec 23, 2022
1 parent d755851 commit c402c22
Show file tree
Hide file tree
Showing 3 changed files with 333 additions and 8 deletions.
2 changes: 2 additions & 0 deletions starlette/responses.py
Expand Up @@ -157,6 +157,8 @@ def delete_cookie(
)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
assert scope["type"] == "http"

await send(
{
"type": "http.response.start",
Expand Down
112 changes: 105 additions & 7 deletions starlette/routing.py
Expand Up @@ -6,9 +6,11 @@
import types
import typing
import warnings
from contextlib import asynccontextmanager
from contextlib import AsyncExitStack, asynccontextmanager
from enum import Enum

from anyio import Event, create_task_group

from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.convertors import CONVERTOR_TYPES, Convertor
Expand All @@ -17,7 +19,7 @@
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse, RedirectResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketClose


Expand Down Expand Up @@ -572,6 +574,91 @@ def __call__(self: _T, app: object) -> _T:
return self


class LifespanException(Exception):
def __init__(self, message: str):
self.message = message


@asynccontextmanager
async def _app_lifespan(scope: Scope, app: ASGIApp) -> typing.AsyncIterator[None]:
startup_sent = Event()
startup_done = Event()

shutdown_init = Event()
shutdown_sent = Event()
shutdown_done = Event()

lifespan_supported = False
exception = None

async def receive() -> Message:
nonlocal lifespan_supported

lifespan_supported = True

if not startup_sent.is_set():
startup_sent.set()
return {"type": "lifespan.startup"}

elif startup_done.is_set() and not shutdown_sent.is_set():
await shutdown_init.wait()
shutdown_sent.set()
return {"type": "lifespan.shutdown"}

else:
raise RuntimeError("unexpected receive")

async def send(message: Message) -> None:
nonlocal exception, lifespan_supported

lifespan_supported = True

if startup_sent.is_set() and not startup_done.is_set():
if message["type"] == "lifespan.startup.complete":
pass
elif message["type"] == "lifespan.startup.failed":
exception = message.get("message", "")
else:
raise ValueError(f"unexpected type: {message['type']}")
startup_done.set()

elif shutdown_sent.is_set() and not shutdown_done.is_set():
if message["type"] == "lifespan.shutdown.complete":
pass
elif message["type"] == "lifespan.shutdown.failed":
exception = message.get("message", "")
else:
raise ValueError(f"unexpected type: {message['type']}")
shutdown_done.set()

else:
raise RuntimeError("unexpected send")

# This wrapper is needed because TaskGroup.start_soon does not like that
# App returns Awaitable instead of Coroutine
async def coro_app(scope: Scope, receive: Receive, send: Send) -> None:
await app(scope, receive, send)

try:
async with create_task_group() as tg:
tg.start_soon(coro_app, {**scope, "app": app}, receive, send)
await startup_done.wait()
if exception:
raise LifespanException(exception)
try:
yield
finally:
shutdown_init.set()
await shutdown_done.wait()
if exception:
raise LifespanException(exception)
except Exception:
if lifespan_supported:
raise
else:
yield


class Router:
def __init__(
self,
Expand Down Expand Up @@ -659,6 +746,14 @@ async def shutdown(self) -> None:
else:
handler()

@asynccontextmanager
async def mount_lifespans(self, scope: Scope) -> typing.AsyncIterator[None]:
async with AsyncExitStack() as stack:
for route in self.routes:
if isinstance(route, Mount):
await stack.enter_async_context(_app_lifespan(scope, route.app))
yield

async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
"""
Handle ASGI lifespan messages, which allows us to manage application
Expand All @@ -668,16 +763,19 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
app = scope.get("app")
await receive()
try:
async with self.lifespan_context(app):
async with self.lifespan_context(app), self.mount_lifespans(scope):
await send({"type": "lifespan.startup.complete"})
started = True
await receive()
except BaseException:
exc_text = traceback.format_exc()
except BaseException as e:
if isinstance(e, LifespanException):
message = e.message
else:
message = traceback.format_exc()
if started:
await send({"type": "lifespan.shutdown.failed", "message": exc_text})
await send({"type": "lifespan.shutdown.failed", "message": message})
else:
await send({"type": "lifespan.startup.failed", "message": exc_text})
await send({"type": "lifespan.startup.failed", "message": message})
raise
else:
await send({"type": "lifespan.shutdown.complete"})
Expand Down

0 comments on commit c402c22

Please sign in to comment.