From 6c9bc55af7d9d202c2bc1299c9a65e844640c197 Mon Sep 17 00:00:00 2001 From: Florimond Manca Date: Thu, 28 Oct 2021 18:50:33 +0200 Subject: [PATCH] Prevent anyio.ExceptionGroup in error views under a BaseHTTPMiddleware (#1262) * Prevent ExceptionGroup in error views under a BaseHTTPMiddleware * Apply suggestion from @uSpike Co-authored-by: euri10 --- starlette/middleware/base.py | 10 +++++++++- tests/middleware/test_base.py | 5 +++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 77ba66925..423f40777 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -23,17 +23,25 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: return async def call_next(request: Request) -> Response: + app_exc: typing.Optional[Exception] = None send_stream, recv_stream = anyio.create_memory_object_stream() async def coro() -> None: + nonlocal app_exc + async with send_stream: - await self.app(scope, request.receive, send_stream.send) + try: + await self.app(scope, request.receive, send_stream.send) + except Exception as exc: + app_exc = exc task_group.start_soon(coro) try: message = await recv_stream.receive() except anyio.EndOfStream: + if app_exc is not None: + raise app_exc raise RuntimeError("No response returned.") assert message["type"] == "http.response.start" diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 8a8df4ea6..c6bfd49fc 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -25,7 +25,7 @@ def homepage(request): @app.route("/exc") def exc(request): - raise Exception() + raise Exception("Exc") @app.route("/no-response") @@ -52,8 +52,9 @@ def test_custom_middleware(test_client_factory): response = client.get("/") assert response.headers["Custom-Header"] == "Example" - with pytest.raises(Exception): + with pytest.raises(Exception) as ctx: response = client.get("/exc") + assert str(ctx.value) == "Exc" with pytest.raises(RuntimeError): response = client.get("/no-response")