From 33744868727e1e412a42649105557ae5f7f5f592 Mon Sep 17 00:00:00 2001 From: Oleksandr Fedorov Date: Tue, 25 Jan 2022 17:42:30 +0200 Subject: [PATCH 1/3] Prevent BaseHTTPMiddleware from hiding errors of StreamingResponse --- starlette/middleware/base.py | 5 +++++ tests/middleware/test_base.py | 16 +++++++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 423f40777..f7cf1cce1 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -47,11 +47,16 @@ async def coro() -> None: assert message["type"] == "http.response.start" async def body_stream() -> typing.AsyncGenerator[bytes, None]: + nonlocal app_exc + async with recv_stream: async for message in recv_stream: assert message["type"] == "http.response.body" yield message.get("body", b"") + if app_exc is not None: + raise app_exc + response = StreamingResponse( status_code=message["status"], content=body_stream() ) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index c6bfd49fc..8cba030ca 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -3,7 +3,7 @@ from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware -from starlette.responses import PlainTextResponse +from starlette.responses import PlainTextResponse, StreamingResponse from starlette.routing import Route @@ -28,6 +28,16 @@ def exc(request): raise Exception("Exc") +@app.route("/exc-stream") +def exc_stream(request): + return StreamingResponse(_generate_faulty_stream()) + + +def _generate_faulty_stream(): + yield b"Ok" + raise Exception("Faulty Stream") + + @app.route("/no-response") class NoResponse: def __init__(self, scope, receive, send): @@ -56,6 +66,10 @@ def test_custom_middleware(test_client_factory): response = client.get("/exc") assert str(ctx.value) == "Exc" + with pytest.raises(Exception) as ctx: + response = client.get("/exc-stream") + assert str(ctx.value) == "Faulty Stream" + with pytest.raises(RuntimeError): response = client.get("/no-response") From 9109de57c31c3ff76ec71d7de275a12840a12d19 Mon Sep 17 00:00:00 2001 From: Oleksandr Fedorov Date: Mon, 31 Jan 2022 11:03:53 +0200 Subject: [PATCH 2/3] Apply notes from PR: * remove `nonlocal app_exc`; * add extra test. --- starlette/middleware/base.py | 2 -- tests/middleware/test_base.py | 9 +++++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index f7cf1cce1..bfb4a54a4 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -47,8 +47,6 @@ async def coro() -> None: assert message["type"] == "http.response.start" async def body_stream() -> typing.AsyncGenerator[bytes, None]: - nonlocal app_exc - async with recv_stream: async for message in recv_stream: assert message["type"] == "http.response.body" diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 8cba030ca..10f7ea897 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -172,3 +172,12 @@ async def dispatch(self, request, call_next): client = test_client_factory(app) response = client.get("/does_not_exist") assert response.text == "Custom" + +def test_exception_on_mounted_apps(test_client_factory): + sub_app = Starlette(routes=[Route("/", exc)]) + app.mount("/sub", sub_app) + + client = test_client_factory(app) + with pytest.raises(Exception) as ctx: + client.get("/sub/") + assert str(ctx.value) == "Exc" From 09d68387b22d4fbe7d027168dd257e79851ba2c1 Mon Sep 17 00:00:00 2001 From: Oleksandr Fedorov Date: Mon, 31 Jan 2022 11:06:50 +0200 Subject: [PATCH 3/3] Fix formatting --- tests/middleware/test_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 10f7ea897..32468dcd2 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -173,6 +173,7 @@ async def dispatch(self, request, call_next): response = client.get("/does_not_exist") assert response.text == "Custom" + def test_exception_on_mounted_apps(test_client_factory): sub_app = Starlette(routes=[Route("/", exc)]) app.mount("/sub", sub_app)