Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use an async context manager factory for lifespan #1227

Merged
Merged
1 change: 1 addition & 0 deletions setup.py
Expand Up @@ -40,6 +40,7 @@ def get_long_description():
install_requires=[
"anyio>=3.0.0,<4",
"typing_extensions; python_version < '3.8'",
"contextlib2; python_version < '3.10'",
graingert marked this conversation as resolved.
Show resolved Hide resolved
],
extras_require={
"full": [
Expand Down
2 changes: 1 addition & 1 deletion starlette/applications.py
Expand Up @@ -46,7 +46,7 @@ def __init__(
] = None,
on_startup: typing.Sequence[typing.Callable] = None,
on_shutdown: typing.Sequence[typing.Callable] = None,
lifespan: typing.Callable[["Starlette"], typing.AsyncGenerator] = None,
lifespan: typing.Callable[["Starlette"], typing.AsyncContextManager] = None,
) -> None:
# The lifespan context function is a newer style that replaces
# on_startup / on_shutdown handlers. Use one or the other, not both.
Expand Down
105 changes: 84 additions & 21 deletions starlette/routing.py
@@ -1,9 +1,12 @@
import asyncio
import contextlib
import functools
import inspect
import re
import sys
import traceback
import typing
import warnings
from enum import Enum

from starlette.concurrency import run_in_threadpool
Expand All @@ -15,6 +18,16 @@
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketClose

if sys.version_info >= (3, 7):
from contextlib import asynccontextmanager
else:
from contextlib2 import asynccontextmanager

if sys.version_info >= (3, 10):
from contextlib import aclosing
else:
from contextlib2 import aclosing


class NoMatchFound(Exception):
"""
Expand Down Expand Up @@ -470,6 +483,54 @@ def __eq__(self, other: typing.Any) -> bool:
)


def _wrap_agen_lifespan_context(
lifespan_context: typing.Callable[[typing.Any], typing.AsyncGenerator]
) -> typing.Callable[[typing.Any], typing.AsyncContextManager]:
@functools.wraps(lifespan_context)
@asynccontextmanager
async def agen_wrapper(
app: typing.Any,
) -> typing.AsyncGenerator[None, None]:
async with aclosing(lifespan_context(app)) as agen: # type: ignore
async for _ in agen:
yield

return agen_wrapper


def _wrap_gen_lifespan_context(
lifespan_context: typing.Callable[[typing.Any], typing.Generator]
) -> typing.Callable[[typing.Any], typing.AsyncContextManager]:
@functools.wraps(lifespan_context)
@asynccontextmanager
async def gen_wrapper(
app: typing.Any,
) -> typing.AsyncGenerator[None, None]:
with contextlib.closing(lifespan_context(app)) as gen:
for _ in gen:
yield

return gen_wrapper


def _wrap_lifespan_context(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this stuff is a bit of a huge mess I think it's better to just make this a breaking change and make people add contextlib.asynccontextmanager to their functions

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tomchristie are you happy for this to be a breaking change?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain what would be necessary for a user to convert their existing lifespan generator to an asynccontextmanager? i.e. an example?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You just add the @contextlib.asynccontextmanager decorator

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JayH5 thank you so much for this comment! Because of this I worked out I could get rid of all my ugly _wrap_lifespan_context stuff and just use asynccontextmanager(lifespan_context)

lifespan_context: typing.Union[
typing.Callable[[typing.Any], typing.AsyncGenerator],
typing.Callable[[typing.Any], typing.Generator],
typing.Callable[[typing.Any], typing.AsyncContextManager],
]
) -> typing.Callable[[typing.Any], typing.AsyncContextManager]:
if inspect.isasyncgenfunction(lifespan_context):
warnings.warn("lifespan must be an AsyncContextManager factory")
return _wrap_agen_lifespan_context(lifespan_context) # type: ignore[arg-type]

if inspect.isgeneratorfunction(lifespan_context):
warnings.warn("lifespan must be an AsyncContextManager factory")
return _wrap_gen_lifespan_context(lifespan_context) # type: ignore[arg-type]

return lifespan_context # type: ignore


class Router:
def __init__(
self,
Expand All @@ -478,20 +539,29 @@ def __init__(
default: ASGIApp = None,
on_startup: typing.Sequence[typing.Callable] = None,
on_shutdown: typing.Sequence[typing.Callable] = None,
lifespan: typing.Callable[[typing.Any], typing.AsyncGenerator] = None,
lifespan: typing.Callable[[typing.Any], typing.AsyncContextManager] = None,
) -> None:
self.routes = [] if routes is None else list(routes)
self.redirect_slashes = redirect_slashes
self.default = self.not_found if default is None else default
self.on_startup = [] if on_startup is None else list(on_startup)
self.on_shutdown = [] if on_shutdown is None else list(on_shutdown)

@asynccontextmanager
async def default_lifespan(app: typing.Any) -> typing.AsyncGenerator:
await self.startup()
yield
await self.shutdown()

self.lifespan_context = default_lifespan if lifespan is None else lifespan
try:
yield
finally:
await self.shutdown()

self.lifespan_context: typing.Callable[
[typing.Any], typing.AsyncContextManager
] = (
default_lifespan # type: ignore
if lifespan is None
else _wrap_lifespan_context(lifespan)
)

async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "websocket":
Expand Down Expand Up @@ -541,25 +611,18 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
Handle ASGI lifespan messages, which allows us to manage application
startup and shutdown events.
"""
first = True
started = False
app = scope.get("app")
await receive()
try:
if inspect.isasyncgenfunction(self.lifespan_context):
async for item in self.lifespan_context(app):
assert first, "Lifespan context yielded multiple times."
first = False
await send({"type": "lifespan.startup.complete"})
await receive()
else:
for item in self.lifespan_context(app): # type: ignore
assert first, "Lifespan context yielded multiple times."
first = False
await send({"type": "lifespan.startup.complete"})
await receive()
async with self.lifespan_context(app):
await send({"type": "lifespan.startup.complete"})
started = True
await receive()
except BaseException:
if first:
exc_text = traceback.format_exc()
exc_text = traceback.format_exc()
if started:
await send({"type": "lifespan.shutdown.failed", "message": exc_text})
else:
await send({"type": "lifespan.startup.failed", "message": exc_text})
raise
else:
Expand Down