From cf1dd01ab95cb108ef2411abf9c75b9e6b5b3142 Mon Sep 17 00:00:00 2001 From: Jamie Hewland Date: Sat, 19 Jun 2021 13:32:52 +0100 Subject: [PATCH 1/4] Support lifespan.shutdown.failed message --- starlette/routing.py | 51 +++++++++++++++++++++++++++-------------- starlette/testclient.py | 9 +++++++- tests/test_routing.py | 13 ++++++++++- 3 files changed, 54 insertions(+), 19 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index cef1ef484..89dcb0bf0 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -541,29 +541,46 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None: Handle ASGI lifespan messages, which allows us to manage application startup and shutdown events. """ - first = True app = scope.get("app") await receive() + + async def as_aiter(gen: typing.Generator) -> typing.AsyncGenerator: + for val in gen: + yield val + + lifespan_gen = ( + self.lifespan_context(app) + if inspect.isasyncgenfunction(self.lifespan_context) + else as_aiter(self.lifespan_context(app)) # type: ignore + ) + + await self._lifespan_iter(lifespan_gen, receive, send) + + async def _lifespan_iter( + self, gen: typing.AsyncGenerator, receive: Receive, send: Send + ) -> None: try: - if inspect.isasyncgenfunction(self.lifespan_context): - async for item in self.lifespan_context(app): - assert first, "Lifespan context yielded multiple times." - first = False - await send({"type": "lifespan.startup.complete"}) - await receive() - else: - for item in self.lifespan_context(app): # type: ignore - assert first, "Lifespan context yielded multiple times." - first = False - await send({"type": "lifespan.startup.complete"}) - await receive() - except BaseException: - if first: - exc_text = traceback.format_exc() - await send({"type": "lifespan.startup.failed", "message": exc_text}) + await gen.asend(None) + except StopAsyncIteration: + raise RuntimeError("Lifespan context never yielded.") + except Exception: + exc_text = traceback.format_exc() + await send({"type": "lifespan.startup.failed", "message": exc_text}) raise else: + await send({"type": "lifespan.startup.complete"}) + await receive() + + try: + await gen.asend(None) + except StopAsyncIteration: await send({"type": "lifespan.shutdown.complete"}) + except Exception: + exc_text = traceback.format_exc() + await send({"type": "lifespan.shutdown.failed", "message": exc_text}) + raise + else: + raise RuntimeError("Lifespan context yielded multiple times.") async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: """ diff --git a/starlette/testclient.py b/starlette/testclient.py index c1c0fe165..54ce5fe5a 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -528,4 +528,11 @@ async def wait_shutdown(self) -> None: message = await self.stream_send.receive() if message is None: self.task.result() - assert message["type"] == "lifespan.shutdown.complete" + assert message["type"] in ( + "lifespan.shutdown.complete", + "lifespan.shutdown.failed", + ) + if message["type"] == "lifespan.shutdown.failed": + message = await self.stream_send.receive() + if message is None: + self.task.result() diff --git a/tests/test_routing.py b/tests/test_routing.py index 1d8eb8d95..b89e3b5eb 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -585,11 +585,22 @@ def test_raise_on_shutdown(): def run_shutdown(): raise RuntimeError() - app = Router(on_shutdown=[run_shutdown]) + router = Router(on_shutdown=[run_shutdown]) + async def app(scope, receive, send): + async def _send(message): + nonlocal shutdown_failed + if message["type"] == "lifespan.shutdown.failed": + shutdown_failed = True + return await send(message) + + await router(scope, receive, _send) + + shutdown_failed = False with pytest.raises(RuntimeError): with TestClient(app): pass # pragma: nocover + assert shutdown_failed class AsyncEndpointClassMethod: From d34f54238bfda8357b2a2b85f0759cb556992164 Mon Sep 17 00:00:00 2001 From: Jamie Hewland Date: Sat, 19 Jun 2021 13:38:23 +0100 Subject: [PATCH 2/4] Add tests for lifespan contextmanagers --- tests/test_routing.py | 50 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/tests/test_routing.py b/tests/test_routing.py index b89e3b5eb..1be612aef 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -528,6 +528,31 @@ async def run_shutdown(): assert shutdown_complete +def test_lifespan_context_async(): + startup_complete = False + shutdown_complete = False + + async def hello_world(request): + return PlainTextResponse("hello, world") + + async def lifespan(app): + nonlocal startup_complete, shutdown_complete + startup_complete = True + yield + shutdown_complete = True + + app = Router(lifespan=lifespan, routes=[Route("/", hello_world)]) + + assert not startup_complete + assert not shutdown_complete + with TestClient(app) as client: + assert startup_complete + assert not shutdown_complete + client.get("/") + assert startup_complete + assert shutdown_complete + + def test_lifespan_sync(): startup_complete = False shutdown_complete = False @@ -559,6 +584,31 @@ def run_shutdown(): assert shutdown_complete +def test_lifespan_context_sync(): + startup_complete = False + shutdown_complete = False + + def hello_world(request): + return PlainTextResponse("hello, world") + + def lifespan(app): + nonlocal startup_complete, shutdown_complete + startup_complete = True + yield + shutdown_complete = True + + app = Router(lifespan=lifespan, routes=[Route("/", hello_world)]) + + assert not startup_complete + assert not shutdown_complete + with TestClient(app) as client: + assert startup_complete + assert not shutdown_complete + client.get("/") + assert startup_complete + assert shutdown_complete + + def test_raise_on_startup(): def run_startup(): raise RuntimeError() From 4225727038c647063ca8359bc97ee078ab57c8e7 Mon Sep 17 00:00:00 2001 From: Jamie Hewland Date: Sat, 19 Jun 2021 13:52:31 +0100 Subject: [PATCH 3/4] Add tests for lifespan yielding wrong number of times --- tests/test_routing.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/test_routing.py b/tests/test_routing.py index 1be612aef..d418ea389 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -653,6 +653,45 @@ async def _send(message): assert shutdown_failed +def test_lifespan_yield_too_many(): + startup_complete = False + shutdown_complete = False + + async def hello_world(request): + return PlainTextResponse("hello, world") + + async def yield_too_many(app): + nonlocal startup_complete, shutdown_complete + startup_complete = True + yield + yield + shutdown_complete = True # pragma: nocover + + app = Router(lifespan=yield_too_many, routes=[Route("/", hello_world)]) + + assert not startup_complete + assert not shutdown_complete + with pytest.raises(RuntimeError, match="yielded multiple times"): + with TestClient(app) as client: + assert startup_complete + assert not shutdown_complete + client.get("/") + + assert startup_complete + assert not shutdown_complete + + +def test_lifespan_yield_too_few(): + async def yield_too_few(app): + for _ in []: + yield _ # pragma: nocover + + app = Router(lifespan=yield_too_few) + with pytest.raises(RuntimeError, match="never yielded"): + with TestClient(app): + pass # pragma: nocover + + class AsyncEndpointClassMethod: @classmethod async def async_endpoint(cls, arg, request): From 3648ce63839d8c7aadb0ac2740c9483c9157922e Mon Sep 17 00:00:00 2001 From: Jamie Hewland Date: Sat, 19 Jun 2021 14:03:33 +0100 Subject: [PATCH 4/4] Close AsyncGenerator on too many yields --- starlette/routing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/starlette/routing.py b/starlette/routing.py index 89dcb0bf0..792b6abaa 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -580,6 +580,7 @@ async def _lifespan_iter( await send({"type": "lifespan.shutdown.failed", "message": exc_text}) raise else: + await gen.aclose() raise RuntimeError("Lifespan context yielded multiple times.") async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: