Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use an async context manager factory for lifespan #1227

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Expand Up @@ -40,6 +40,7 @@ def get_long_description():
install_requires=[
"anyio>=3.0.0,<4",
"typing_extensions; python_version < '3.8'",
"contextlib2 >= 21.6.0; python_version < '3.7'",
graingert marked this conversation as resolved.
Show resolved Hide resolved
],
extras_require={
"full": [
Expand Down
2 changes: 1 addition & 1 deletion starlette/applications.py
Expand Up @@ -46,7 +46,7 @@ def __init__(
] = None,
on_startup: typing.Sequence[typing.Callable] = None,
on_shutdown: typing.Sequence[typing.Callable] = None,
lifespan: typing.Callable[["Starlette"], typing.AsyncGenerator] = None,
lifespan: typing.Callable[["Starlette"], typing.AsyncContextManager] = None,
) -> None:
# The lifespan context function is a newer style that replaces
# on_startup / on_shutdown handlers. Use one or the other, not both.
Expand Down
105 changes: 84 additions & 21 deletions starlette/routing.py
@@ -1,9 +1,13 @@
import asyncio
import contextlib
import functools
import inspect
import re
import sys
import traceback
import types
import typing
import warnings
from enum import Enum

from starlette.concurrency import run_in_threadpool
Expand All @@ -15,6 +19,11 @@
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketClose

if sys.version_info >= (3, 7):
from contextlib import asynccontextmanager # pragma: no cover
else:
from contextlib2 import asynccontextmanager # pragma: no cover


class NoMatchFound(Exception):
"""
Expand Down Expand Up @@ -470,6 +479,51 @@ def __eq__(self, other: typing.Any) -> bool:
)


_T = typing.TypeVar("_T")


class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
def __init__(self, cm: typing.ContextManager[_T]):
self._cm = cm

async def __aenter__(self) -> _T:
return self._cm.__enter__()

async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]],
exc_value: typing.Optional[BaseException],
traceback: typing.Optional[types.TracebackType],
) -> typing.Optional[bool]:
return self._cm.__exit__(exc_type, exc_value, traceback)


def _wrap_gen_lifespan_context(
lifespan_context: typing.Callable[[typing.Any], typing.Generator]
) -> typing.Callable[[typing.Any], typing.AsyncContextManager]:
cmgr = contextlib.contextmanager(lifespan_context)

@functools.wraps(cmgr)
def wrapper(app: typing.Any) -> _AsyncLiftContextManager:
return _AsyncLiftContextManager(cmgr(app))

return wrapper


class _DefaultLifespan:
def __init__(self, router: "Router"):
self._router = router

async def __aenter__(self) -> None:
await self._router.startup()

async def __aexit__(self, *exc_info: object) -> None:
await self._router.shutdown()

def __call__(self: _T, app: object) -> _T:
return self


class Router:
def __init__(
self,
Expand All @@ -478,20 +532,35 @@ def __init__(
default: ASGIApp = None,
on_startup: typing.Sequence[typing.Callable] = None,
on_shutdown: typing.Sequence[typing.Callable] = None,
lifespan: typing.Callable[[typing.Any], typing.AsyncGenerator] = None,
lifespan: typing.Callable[[typing.Any], typing.AsyncContextManager] = None,
) -> None:
self.routes = [] if routes is None else list(routes)
self.redirect_slashes = redirect_slashes
self.default = self.not_found if default is None else default
self.on_startup = [] if on_startup is None else list(on_startup)
self.on_shutdown = [] if on_shutdown is None else list(on_shutdown)

async def default_lifespan(app: typing.Any) -> typing.AsyncGenerator:
await self.startup()
yield
await self.shutdown()
if lifespan is None:
self.lifespan_context: typing.Callable[
[typing.Any], typing.AsyncContextManager
] = _DefaultLifespan(self)

self.lifespan_context = default_lifespan if lifespan is None else lifespan
elif inspect.isasyncgenfunction(lifespan):
warnings.warn(
"lifespan must be an AsyncContextManager factory", DeprecationWarning
graingert marked this conversation as resolved.
Show resolved Hide resolved
)
self.lifespan_context = asynccontextmanager(
graingert marked this conversation as resolved.
Show resolved Hide resolved
lifespan, # type: ignore[arg-type]
)
elif inspect.isgeneratorfunction(lifespan):
warnings.warn(
"lifespan must be an AsyncContextManager factory", DeprecationWarning
)
self.lifespan_context = _wrap_gen_lifespan_context(
lifespan, # type: ignore[arg-type]
)
else:
self.lifespan_context = lifespan

async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "websocket":
Expand Down Expand Up @@ -541,25 +610,19 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
Handle ASGI lifespan messages, which allows us to manage application
startup and shutdown events.
"""
first = True
started = False
app = scope.get("app")
await receive()
try:
if inspect.isasyncgenfunction(self.lifespan_context):
async for item in self.lifespan_context(app):
assert first, "Lifespan context yielded multiple times."
first = False
await send({"type": "lifespan.startup.complete"})
await receive()
else:
for item in self.lifespan_context(app): # type: ignore
assert first, "Lifespan context yielded multiple times."
first = False
await send({"type": "lifespan.startup.complete"})
await receive()
async with self.lifespan_context(app):
await send({"type": "lifespan.startup.complete"})
started = True
await receive()
except BaseException:
if first:
exc_text = traceback.format_exc()
exc_text = traceback.format_exc()
if started:
await send({"type": "lifespan.shutdown.failed", "message": exc_text})
else:
await send({"type": "lifespan.startup.failed", "message": exc_text})
raise
else:
Expand Down
28 changes: 20 additions & 8 deletions starlette/testclient.py
Expand Up @@ -523,22 +523,34 @@ async def lifespan(self) -> None:

async def wait_startup(self) -> None:
await self.stream_receive.send({"type": "lifespan.startup"})
message = await self.stream_send.receive()
if message is None:
self.task.result()

async def receive() -> typing.Any:
message = await self.stream_send.receive()
if message is None:
self.task.result()
return message

message = await receive()
assert message["type"] in (
"lifespan.startup.complete",
"lifespan.startup.failed",
)
if message["type"] == "lifespan.startup.failed":
await receive()

async def wait_shutdown(self) -> None:
async def receive() -> typing.Any:
message = await self.stream_send.receive()
if message is None:
self.task.result()
return message

async def wait_shutdown(self) -> None:
async with self.stream_send:
await self.stream_receive.send({"type": "lifespan.shutdown"})
message = await self.stream_send.receive()
if message is None:
self.task.result()
assert message["type"] == "lifespan.shutdown.complete"
message = await receive()
assert message["type"] in (
"lifespan.shutdown.complete",
"lifespan.shutdown.failed",
)
if message["type"] == "lifespan.shutdown.failed":
await receive()
42 changes: 40 additions & 2 deletions tests/test_applications.py
@@ -1,4 +1,5 @@
import os
import sys

import pytest

Expand All @@ -10,6 +11,11 @@
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles

if sys.version_info >= (3, 7):
from contextlib import asynccontextmanager # pragma: no cover
else:
from contextlib2 import asynccontextmanager # pragma: no cover

app = Starlette()


Expand Down Expand Up @@ -286,7 +292,38 @@ def run_cleanup():
assert cleanup_complete


def test_app_async_lifespan(test_client_factory):
def test_app_async_cm_lifespan(test_client_factory):
startup_complete = False
cleanup_complete = False

@asynccontextmanager
async def lifespan(app):
nonlocal startup_complete, cleanup_complete
startup_complete = True
yield
cleanup_complete = True

app = Starlette(lifespan=lifespan)

assert not startup_complete
assert not cleanup_complete
with test_client_factory(app):
assert startup_complete
assert not cleanup_complete
assert startup_complete
assert cleanup_complete


deprecated_lifespan = pytest.mark.filterwarnings(
"ignore"
":lifespan must be an AsyncContextManager factory"
":DeprecationWarning"
":starlette.routing"
)


@deprecated_lifespan
def test_app_async_gen_lifespan(test_client_factory):
startup_complete = False
cleanup_complete = False

Expand All @@ -307,7 +344,8 @@ async def lifespan(app):
assert cleanup_complete


def test_app_sync_lifespan(test_client_factory):
@deprecated_lifespan
def test_app_sync_gen_lifespan(test_client_factory):
startup_complete = False
cleanup_complete = False

Expand Down