Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use an async context manager factory for lifespan #1227

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Expand Up @@ -40,6 +40,7 @@ def get_long_description():
install_requires=[
"anyio>=3.0.0,<4",
"typing_extensions; python_version < '3.8'",
"contextlib2 >= 21.6.0; python_version < '3.7'",
graingert marked this conversation as resolved.
Show resolved Hide resolved
],
extras_require={
"full": [
Expand Down
2 changes: 1 addition & 1 deletion starlette/applications.py
Expand Up @@ -46,7 +46,7 @@ def __init__(
] = None,
on_startup: typing.Sequence[typing.Callable] = None,
on_shutdown: typing.Sequence[typing.Callable] = None,
lifespan: typing.Callable[["Starlette"], typing.AsyncGenerator] = None,
lifespan: typing.Callable[["Starlette"], typing.AsyncContextManager] = None,
) -> None:
# The lifespan context function is a newer style that replaces
# on_startup / on_shutdown handlers. Use one or the other, not both.
Expand Down
109 changes: 88 additions & 21 deletions starlette/routing.py
@@ -1,9 +1,13 @@
import asyncio
import contextlib
import functools
import inspect
import re
import sys
import traceback
import types
import typing
import warnings
from enum import Enum

from starlette.concurrency import run_in_threadpool
Expand All @@ -15,6 +19,11 @@
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketClose

if sys.version_info >= (3, 7):
from contextlib import asynccontextmanager # pragma: no cover
else:
from contextlib2 import asynccontextmanager # pragma: no cover


class NoMatchFound(Exception):
"""
Expand Down Expand Up @@ -470,6 +479,51 @@ def __eq__(self, other: typing.Any) -> bool:
)


_T = typing.TypeVar("_T")


class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
def __init__(self, cm: typing.ContextManager[_T]):
self._cm = cm

async def __aenter__(self) -> _T:
return self._cm.__enter__()

async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]],
exc_value: typing.Optional[BaseException],
traceback: typing.Optional[types.TracebackType],
) -> typing.Optional[bool]:
return self._cm.__exit__(exc_type, exc_value, traceback)


def _wrap_gen_lifespan_context(
lifespan_context: typing.Callable[[typing.Any], typing.Generator]
) -> typing.Callable[[typing.Any], typing.AsyncContextManager]:
cmgr = contextlib.contextmanager(lifespan_context)

@functools.wraps(cmgr)
def wrapper(app: typing.Any) -> _AsyncLiftContextManager:
return _AsyncLiftContextManager(cmgr(app))

return wrapper


class _DefaultLifespan:
def __init__(self, router: "Router"):
self._router = router

async def __aenter__(self) -> None:
await self._router.startup()

async def __aexit__(self, *exc_info: object) -> None:
await self._router.shutdown()

def __call__(self: _T, app: object) -> _T:
return self


class Router:
def __init__(
self,
Expand All @@ -478,20 +532,39 @@ def __init__(
default: ASGIApp = None,
on_startup: typing.Sequence[typing.Callable] = None,
on_shutdown: typing.Sequence[typing.Callable] = None,
lifespan: typing.Callable[[typing.Any], typing.AsyncGenerator] = None,
lifespan: typing.Callable[[typing.Any], typing.AsyncContextManager] = None,
) -> None:
self.routes = [] if routes is None else list(routes)
self.redirect_slashes = redirect_slashes
self.default = self.not_found if default is None else default
self.on_startup = [] if on_startup is None else list(on_startup)
self.on_shutdown = [] if on_shutdown is None else list(on_shutdown)

async def default_lifespan(app: typing.Any) -> typing.AsyncGenerator:
await self.startup()
yield
await self.shutdown()
if lifespan is None:
self.lifespan_context: typing.Callable[
[typing.Any], typing.AsyncContextManager
] = _DefaultLifespan(self)

self.lifespan_context = default_lifespan if lifespan is None else lifespan
elif inspect.isasyncgenfunction(lifespan):
warnings.warn(
"async generator function lifespans are deprecated, "
"use an @contextlib.asynccontextmanager function instead",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I now read this like it should be "a contextlib..."

Suggested change
"use an @contextlib.asynccontextmanager function instead",
"use a @contextlib.asynccontextmanager function instead",

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I read this as an at contextlib.asynccontextmanager

DeprecationWarning,
)
self.lifespan_context = asynccontextmanager(
graingert marked this conversation as resolved.
Show resolved Hide resolved
lifespan, # type: ignore[arg-type]
)
elif inspect.isgeneratorfunction(lifespan):
warnings.warn(
"generator function lifespans are deprecated, "
"use an @contextlib.asynccontextmanager function instead",
DeprecationWarning,
)
self.lifespan_context = _wrap_gen_lifespan_context(
lifespan, # type: ignore[arg-type]
)
else:
self.lifespan_context = lifespan

async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "websocket":
Expand Down Expand Up @@ -541,25 +614,19 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
Handle ASGI lifespan messages, which allows us to manage application
startup and shutdown events.
"""
first = True
started = False
app = scope.get("app")
await receive()
try:
if inspect.isasyncgenfunction(self.lifespan_context):
async for item in self.lifespan_context(app):
assert first, "Lifespan context yielded multiple times."
first = False
await send({"type": "lifespan.startup.complete"})
await receive()
else:
for item in self.lifespan_context(app): # type: ignore
assert first, "Lifespan context yielded multiple times."
first = False
await send({"type": "lifespan.startup.complete"})
await receive()
async with self.lifespan_context(app):
await send({"type": "lifespan.startup.complete"})
started = True
await receive()
except BaseException:
if first:
exc_text = traceback.format_exc()
exc_text = traceback.format_exc()
if started:
await send({"type": "lifespan.shutdown.failed", "message": exc_text})
else:
await send({"type": "lifespan.startup.failed", "message": exc_text})
raise
else:
Expand Down
28 changes: 20 additions & 8 deletions starlette/testclient.py
Expand Up @@ -543,22 +543,34 @@ async def lifespan(self) -> None:

async def wait_startup(self) -> None:
await self.stream_receive.send({"type": "lifespan.startup"})
message = await self.stream_send.receive()
if message is None:
self.task.result()

async def receive() -> typing.Any:
message = await self.stream_send.receive()
if message is None:
self.task.result()
return message

message = await receive()
assert message["type"] in (
"lifespan.startup.complete",
"lifespan.startup.failed",
)
if message["type"] == "lifespan.startup.failed":
await receive()

async def wait_shutdown(self) -> None:
async def receive() -> typing.Any:
message = await self.stream_send.receive()
if message is None:
self.task.result()
return message

async def wait_shutdown(self) -> None:
async with self.stream_send:
await self.stream_receive.send({"type": "lifespan.shutdown"})
message = await self.stream_send.receive()
if message is None:
self.task.result()
assert message["type"] == "lifespan.shutdown.complete"
message = await receive()
assert message["type"] in (
"lifespan.shutdown.complete",
"lifespan.shutdown.failed",
)
if message["type"] == "lifespan.shutdown.failed":
await receive()
43 changes: 41 additions & 2 deletions tests/test_applications.py
@@ -1,4 +1,5 @@
import os
import sys

import pytest

Expand All @@ -10,6 +11,11 @@
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles

if sys.version_info >= (3, 7):
from contextlib import asynccontextmanager # pragma: no cover
else:
from contextlib2 import asynccontextmanager # pragma: no cover

app = Starlette()


Expand Down Expand Up @@ -286,7 +292,39 @@ def run_cleanup():
assert cleanup_complete


def test_app_async_lifespan(test_client_factory):
def test_app_async_cm_lifespan(test_client_factory):
startup_complete = False
cleanup_complete = False

@asynccontextmanager
async def lifespan(app):
nonlocal startup_complete, cleanup_complete
startup_complete = True
yield
cleanup_complete = True

app = Starlette(lifespan=lifespan)

assert not startup_complete
assert not cleanup_complete
with test_client_factory(app):
assert startup_complete
assert not cleanup_complete
assert startup_complete
assert cleanup_complete


deprecated_lifespan = pytest.mark.filterwarnings(
r"ignore"
r":(async )?generator function lifespans are deprecated, use an "
r"@contextlib\.asynccontextmanager function instead"
r":DeprecationWarning"
r":starlette.routing"
)


@deprecated_lifespan
def test_app_async_gen_lifespan(test_client_factory):
startup_complete = False
cleanup_complete = False

Expand All @@ -307,7 +345,8 @@ async def lifespan(app):
assert cleanup_complete


def test_app_sync_lifespan(test_client_factory):
@deprecated_lifespan
def test_app_sync_gen_lifespan(test_client_factory):
startup_complete = False
cleanup_complete = False

Expand Down
11 changes: 7 additions & 4 deletions tests/test_testclient.py
Expand Up @@ -12,10 +12,12 @@
from starlette.responses import JSONResponse
from starlette.websockets import WebSocket, WebSocketDisconnect

if sys.version_info >= (3, 7):
from asyncio import current_task as asyncio_current_task # pragma: no cover
else:
asyncio_current_task = asyncio.Task.current_task # pragma: no cover
if sys.version_info >= (3, 7): # pragma: no cover
from asyncio import current_task as asyncio_current_task
from contextlib import asynccontextmanager
else: # pragma: no cover
asyncio_current_task = asyncio.Task.current_task
from contextlib2 import asynccontextmanager

mock_service = Starlette()

Expand Down Expand Up @@ -90,6 +92,7 @@ def get_identity():
shutdown_task = object()
shutdown_loop = None

@asynccontextmanager
async def lifespan_context(app):
nonlocal startup_task, startup_loop, shutdown_task, shutdown_loop

Expand Down