Skip to content

Commit

Permalink
use an async context manager factory for lifespan (#1227)
Browse files Browse the repository at this point in the history
  • Loading branch information
graingert committed Jul 3, 2021
1 parent 254d0d9 commit 537ab6a
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 36 deletions.
1 change: 1 addition & 0 deletions setup.py
Expand Up @@ -40,6 +40,7 @@ def get_long_description():
install_requires=[
"anyio>=3.0.0,<4",
"typing_extensions; python_version < '3.8'",
"contextlib2 >= 21.6.0; python_version < '3.7'",
],
extras_require={
"full": [
Expand Down
2 changes: 1 addition & 1 deletion starlette/applications.py
Expand Up @@ -46,7 +46,7 @@ def __init__(
] = None,
on_startup: typing.Sequence[typing.Callable] = None,
on_shutdown: typing.Sequence[typing.Callable] = None,
lifespan: typing.Callable[["Starlette"], typing.AsyncGenerator] = None,
lifespan: typing.Callable[["Starlette"], typing.AsyncContextManager] = None,
) -> None:
# The lifespan context function is a newer style that replaces
# on_startup / on_shutdown handlers. Use one or the other, not both.
Expand Down
109 changes: 88 additions & 21 deletions starlette/routing.py
@@ -1,9 +1,13 @@
import asyncio
import contextlib
import functools
import inspect
import re
import sys
import traceback
import types
import typing
import warnings
from enum import Enum

from starlette.concurrency import run_in_threadpool
Expand All @@ -15,6 +19,11 @@
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketClose

if sys.version_info >= (3, 7):
from contextlib import asynccontextmanager # pragma: no cover
else:
from contextlib2 import asynccontextmanager # pragma: no cover


class NoMatchFound(Exception):
"""
Expand Down Expand Up @@ -470,6 +479,51 @@ def __eq__(self, other: typing.Any) -> bool:
)


_T = typing.TypeVar("_T")


class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
def __init__(self, cm: typing.ContextManager[_T]):
self._cm = cm

async def __aenter__(self) -> _T:
return self._cm.__enter__()

async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]],
exc_value: typing.Optional[BaseException],
traceback: typing.Optional[types.TracebackType],
) -> typing.Optional[bool]:
return self._cm.__exit__(exc_type, exc_value, traceback)


def _wrap_gen_lifespan_context(
lifespan_context: typing.Callable[[typing.Any], typing.Generator]
) -> typing.Callable[[typing.Any], typing.AsyncContextManager]:
cmgr = contextlib.contextmanager(lifespan_context)

@functools.wraps(cmgr)
def wrapper(app: typing.Any) -> _AsyncLiftContextManager:
return _AsyncLiftContextManager(cmgr(app))

return wrapper


class _DefaultLifespan:
def __init__(self, router: "Router"):
self._router = router

async def __aenter__(self) -> None:
await self._router.startup()

async def __aexit__(self, *exc_info: object) -> None:
await self._router.shutdown()

def __call__(self: _T, app: object) -> _T:
return self


class Router:
def __init__(
self,
Expand All @@ -478,20 +532,39 @@ def __init__(
default: ASGIApp = None,
on_startup: typing.Sequence[typing.Callable] = None,
on_shutdown: typing.Sequence[typing.Callable] = None,
lifespan: typing.Callable[[typing.Any], typing.AsyncGenerator] = None,
lifespan: typing.Callable[[typing.Any], typing.AsyncContextManager] = None,
) -> None:
self.routes = [] if routes is None else list(routes)
self.redirect_slashes = redirect_slashes
self.default = self.not_found if default is None else default
self.on_startup = [] if on_startup is None else list(on_startup)
self.on_shutdown = [] if on_shutdown is None else list(on_shutdown)

async def default_lifespan(app: typing.Any) -> typing.AsyncGenerator:
await self.startup()
yield
await self.shutdown()
if lifespan is None:
self.lifespan_context: typing.Callable[
[typing.Any], typing.AsyncContextManager
] = _DefaultLifespan(self)

self.lifespan_context = default_lifespan if lifespan is None else lifespan
elif inspect.isasyncgenfunction(lifespan):
warnings.warn(
"async generator function lifespans are deprecated, "
"use an @contextlib.asynccontextmanager function instead",
DeprecationWarning,
)
self.lifespan_context = asynccontextmanager(
lifespan, # type: ignore[arg-type]
)
elif inspect.isgeneratorfunction(lifespan):
warnings.warn(
"generator function lifespans are deprecated, "
"use an @contextlib.asynccontextmanager function instead",
DeprecationWarning,
)
self.lifespan_context = _wrap_gen_lifespan_context(
lifespan, # type: ignore[arg-type]
)
else:
self.lifespan_context = lifespan

async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "websocket":
Expand Down Expand Up @@ -541,25 +614,19 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
Handle ASGI lifespan messages, which allows us to manage application
startup and shutdown events.
"""
first = True
started = False
app = scope.get("app")
await receive()
try:
if inspect.isasyncgenfunction(self.lifespan_context):
async for item in self.lifespan_context(app):
assert first, "Lifespan context yielded multiple times."
first = False
await send({"type": "lifespan.startup.complete"})
await receive()
else:
for item in self.lifespan_context(app): # type: ignore
assert first, "Lifespan context yielded multiple times."
first = False
await send({"type": "lifespan.startup.complete"})
await receive()
async with self.lifespan_context(app):
await send({"type": "lifespan.startup.complete"})
started = True
await receive()
except BaseException:
if first:
exc_text = traceback.format_exc()
exc_text = traceback.format_exc()
if started:
await send({"type": "lifespan.shutdown.failed", "message": exc_text})
else:
await send({"type": "lifespan.startup.failed", "message": exc_text})
raise
else:
Expand Down
28 changes: 20 additions & 8 deletions starlette/testclient.py
Expand Up @@ -543,22 +543,34 @@ async def lifespan(self) -> None:

async def wait_startup(self) -> None:
await self.stream_receive.send({"type": "lifespan.startup"})
message = await self.stream_send.receive()
if message is None:
self.task.result()

async def receive() -> typing.Any:
message = await self.stream_send.receive()
if message is None:
self.task.result()
return message

message = await receive()
assert message["type"] in (
"lifespan.startup.complete",
"lifespan.startup.failed",
)
if message["type"] == "lifespan.startup.failed":
await receive()

async def wait_shutdown(self) -> None:
async def receive() -> typing.Any:
message = await self.stream_send.receive()
if message is None:
self.task.result()
return message

async def wait_shutdown(self) -> None:
async with self.stream_send:
await self.stream_receive.send({"type": "lifespan.shutdown"})
message = await self.stream_send.receive()
if message is None:
self.task.result()
assert message["type"] == "lifespan.shutdown.complete"
message = await receive()
assert message["type"] in (
"lifespan.shutdown.complete",
"lifespan.shutdown.failed",
)
if message["type"] == "lifespan.shutdown.failed":
await receive()
43 changes: 41 additions & 2 deletions tests/test_applications.py
@@ -1,4 +1,5 @@
import os
import sys

import pytest

Expand All @@ -10,6 +11,11 @@
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles

if sys.version_info >= (3, 7):
from contextlib import asynccontextmanager # pragma: no cover
else:
from contextlib2 import asynccontextmanager # pragma: no cover

app = Starlette()


Expand Down Expand Up @@ -286,7 +292,39 @@ def run_cleanup():
assert cleanup_complete


def test_app_async_lifespan(test_client_factory):
def test_app_async_cm_lifespan(test_client_factory):
startup_complete = False
cleanup_complete = False

@asynccontextmanager
async def lifespan(app):
nonlocal startup_complete, cleanup_complete
startup_complete = True
yield
cleanup_complete = True

app = Starlette(lifespan=lifespan)

assert not startup_complete
assert not cleanup_complete
with test_client_factory(app):
assert startup_complete
assert not cleanup_complete
assert startup_complete
assert cleanup_complete


deprecated_lifespan = pytest.mark.filterwarnings(
r"ignore"
r":(async )?generator function lifespans are deprecated, use an "
r"@contextlib\.asynccontextmanager function instead"
r":DeprecationWarning"
r":starlette.routing"
)


@deprecated_lifespan
def test_app_async_gen_lifespan(test_client_factory):
startup_complete = False
cleanup_complete = False

Expand All @@ -307,7 +345,8 @@ async def lifespan(app):
assert cleanup_complete


def test_app_sync_lifespan(test_client_factory):
@deprecated_lifespan
def test_app_sync_gen_lifespan(test_client_factory):
startup_complete = False
cleanup_complete = False

Expand Down
11 changes: 7 additions & 4 deletions tests/test_testclient.py
Expand Up @@ -12,10 +12,12 @@
from starlette.responses import JSONResponse
from starlette.websockets import WebSocket, WebSocketDisconnect

if sys.version_info >= (3, 7):
from asyncio import current_task as asyncio_current_task # pragma: no cover
else:
asyncio_current_task = asyncio.Task.current_task # pragma: no cover
if sys.version_info >= (3, 7): # pragma: no cover
from asyncio import current_task as asyncio_current_task
from contextlib import asynccontextmanager
else: # pragma: no cover
asyncio_current_task = asyncio.Task.current_task
from contextlib2 import asynccontextmanager

mock_service = Starlette()

Expand Down Expand Up @@ -90,6 +92,7 @@ def get_identity():
shutdown_task = object()
shutdown_loop = None

@asynccontextmanager
async def lifespan_context(app):
nonlocal startup_task, startup_loop, shutdown_task, shutdown_loop

Expand Down

0 comments on commit 537ab6a

Please sign in to comment.