From 58defbc240cab68e57e52385346eed0a395fd580 Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Wed, 11 Aug 2021 12:15:30 +0200 Subject: [PATCH 1/2] Prevent ExceptionGroup in error views under a BaseHTTPMiddleware --- starlette/middleware/base.py | 4 ++++ tests/middleware/test_base.py | 5 +++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 77ba66925..bf337b8e9 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -34,6 +34,10 @@ async def coro() -> None: try: message = await recv_stream.receive() except anyio.EndOfStream: + # HACK: give anyio a chance to surface any app exception first, + # in order to avoid an `anyio.ExceptionGroup`. + # See #1255. + await anyio.lowlevel.checkpoint() raise RuntimeError("No response returned.") assert message["type"] == "http.response.start" diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 8a8df4ea6..c6bfd49fc 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -25,7 +25,7 @@ def homepage(request): @app.route("/exc") def exc(request): - raise Exception() + raise Exception("Exc") @app.route("/no-response") @@ -52,8 +52,9 @@ def test_custom_middleware(test_client_factory): response = client.get("/") assert response.headers["Custom-Header"] == "Example" - with pytest.raises(Exception): + with pytest.raises(Exception) as ctx: response = client.get("/exc") + assert str(ctx.value) == "Exc" with pytest.raises(RuntimeError): response = client.get("/no-response") From 18a720931fc2528042ab8e0138fd1b0228c62b73 Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Fri, 22 Oct 2021 14:30:10 +0200 Subject: [PATCH 2/2] Apply suggestion from @uSpike --- starlette/middleware/base.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index bf337b8e9..423f40777 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -23,21 +23,25 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: return async def call_next(request: Request) -> Response: + app_exc: typing.Optional[Exception] = None send_stream, recv_stream = anyio.create_memory_object_stream() async def coro() -> None: + nonlocal app_exc + async with send_stream: - await self.app(scope, request.receive, send_stream.send) + try: + await self.app(scope, request.receive, send_stream.send) + except Exception as exc: + app_exc = exc task_group.start_soon(coro) try: message = await recv_stream.receive() except anyio.EndOfStream: - # HACK: give anyio a chance to surface any app exception first, - # in order to avoid an `anyio.ExceptionGroup`. - # See #1255. - await anyio.lowlevel.checkpoint() + if app_exc is not None: + raise app_exc raise RuntimeError("No response returned.") assert message["type"] == "http.response.start"