diff --git a/starlette/routing.py b/starlette/routing.py index cef1ef484..792b6abaa 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -541,29 +541,47 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None: Handle ASGI lifespan messages, which allows us to manage application startup and shutdown events. """ - first = True app = scope.get("app") await receive() + + async def as_aiter(gen: typing.Generator) -> typing.AsyncGenerator: + for val in gen: + yield val + + lifespan_gen = ( + self.lifespan_context(app) + if inspect.isasyncgenfunction(self.lifespan_context) + else as_aiter(self.lifespan_context(app)) # type: ignore + ) + + await self._lifespan_iter(lifespan_gen, receive, send) + + async def _lifespan_iter( + self, gen: typing.AsyncGenerator, receive: Receive, send: Send + ) -> None: try: - if inspect.isasyncgenfunction(self.lifespan_context): - async for item in self.lifespan_context(app): - assert first, "Lifespan context yielded multiple times." - first = False - await send({"type": "lifespan.startup.complete"}) - await receive() - else: - for item in self.lifespan_context(app): # type: ignore - assert first, "Lifespan context yielded multiple times." - first = False - await send({"type": "lifespan.startup.complete"}) - await receive() - except BaseException: - if first: - exc_text = traceback.format_exc() - await send({"type": "lifespan.startup.failed", "message": exc_text}) + await gen.asend(None) + except StopAsyncIteration: + raise RuntimeError("Lifespan context never yielded.") + except Exception: + exc_text = traceback.format_exc() + await send({"type": "lifespan.startup.failed", "message": exc_text}) raise else: + await send({"type": "lifespan.startup.complete"}) + await receive() + + try: + await gen.asend(None) + except StopAsyncIteration: await send({"type": "lifespan.shutdown.complete"}) + except Exception: + exc_text = traceback.format_exc() + await send({"type": "lifespan.shutdown.failed", "message": exc_text}) + raise + else: + await gen.aclose() + raise RuntimeError("Lifespan context yielded multiple times.") async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: """ diff --git a/starlette/testclient.py b/starlette/testclient.py index c1c0fe165..54ce5fe5a 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -528,4 +528,11 @@ async def wait_shutdown(self) -> None: message = await self.stream_send.receive() if message is None: self.task.result() - assert message["type"] == "lifespan.shutdown.complete" + assert message["type"] in ( + "lifespan.shutdown.complete", + "lifespan.shutdown.failed", + ) + if message["type"] == "lifespan.shutdown.failed": + message = await self.stream_send.receive() + if message is None: + self.task.result() diff --git a/tests/test_routing.py b/tests/test_routing.py index 1d8eb8d95..d418ea389 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -528,6 +528,31 @@ async def run_shutdown(): assert shutdown_complete +def test_lifespan_context_async(): + startup_complete = False + shutdown_complete = False + + async def hello_world(request): + return PlainTextResponse("hello, world") + + async def lifespan(app): + nonlocal startup_complete, shutdown_complete + startup_complete = True + yield + shutdown_complete = True + + app = Router(lifespan=lifespan, routes=[Route("/", hello_world)]) + + assert not startup_complete + assert not shutdown_complete + with TestClient(app) as client: + assert startup_complete + assert not shutdown_complete + client.get("/") + assert startup_complete + assert shutdown_complete + + def test_lifespan_sync(): startup_complete = False shutdown_complete = False @@ -559,6 +584,31 @@ def run_shutdown(): assert shutdown_complete +def test_lifespan_context_sync(): + startup_complete = False + shutdown_complete = False + + def hello_world(request): + return PlainTextResponse("hello, world") + + def lifespan(app): + nonlocal startup_complete, shutdown_complete + startup_complete = True + yield + shutdown_complete = True + + app = Router(lifespan=lifespan, routes=[Route("/", hello_world)]) + + assert not startup_complete + assert not shutdown_complete + with TestClient(app) as client: + assert startup_complete + assert not shutdown_complete + client.get("/") + assert startup_complete + assert shutdown_complete + + def test_raise_on_startup(): def run_startup(): raise RuntimeError() @@ -585,11 +635,61 @@ def test_raise_on_shutdown(): def run_shutdown(): raise RuntimeError() - app = Router(on_shutdown=[run_shutdown]) + router = Router(on_shutdown=[run_shutdown]) + async def app(scope, receive, send): + async def _send(message): + nonlocal shutdown_failed + if message["type"] == "lifespan.shutdown.failed": + shutdown_failed = True + return await send(message) + + await router(scope, receive, _send) + + shutdown_failed = False with pytest.raises(RuntimeError): with TestClient(app): pass # pragma: nocover + assert shutdown_failed + + +def test_lifespan_yield_too_many(): + startup_complete = False + shutdown_complete = False + + async def hello_world(request): + return PlainTextResponse("hello, world") + + async def yield_too_many(app): + nonlocal startup_complete, shutdown_complete + startup_complete = True + yield + yield + shutdown_complete = True # pragma: nocover + + app = Router(lifespan=yield_too_many, routes=[Route("/", hello_world)]) + + assert not startup_complete + assert not shutdown_complete + with pytest.raises(RuntimeError, match="yielded multiple times"): + with TestClient(app) as client: + assert startup_complete + assert not shutdown_complete + client.get("/") + + assert startup_complete + assert not shutdown_complete + + +def test_lifespan_yield_too_few(): + async def yield_too_few(app): + for _ in []: + yield _ # pragma: nocover + + app = Router(lifespan=yield_too_few) + with pytest.raises(RuntimeError, match="never yielded"): + with TestClient(app): + pass # pragma: nocover class AsyncEndpointClassMethod: