From aff548ab18bb5661d28503a98ab68737777a0b18 Mon Sep 17 00:00:00 2001 From: Oleksandr Fedorov <2283679+o-fedorov@users.noreply.github.com> Date: Mon, 31 Jan 2022 12:12:15 +0200 Subject: [PATCH] Prevent BaseHTTPMiddleware from hiding errors of StreamingResponse (#1459) * Prevent BaseHTTPMiddleware from hiding errors of StreamingResponse * Apply notes from PR: * remove `nonlocal app_exc`; * add extra test. * Fix formatting Co-authored-by: Marcelo Trylesinski --- starlette/middleware/base.py | 3 +++ tests/middleware/test_base.py | 26 +++++++++++++++++++++++++- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 423f40777..bfb4a54a4 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -52,6 +52,9 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: assert message["type"] == "http.response.body" yield message.get("body", b"") + if app_exc is not None: + raise app_exc + response = StreamingResponse( status_code=message["status"], content=body_stream() ) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index c6bfd49fc..32468dcd2 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -3,7 +3,7 @@ from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware -from starlette.responses import PlainTextResponse +from starlette.responses import PlainTextResponse, StreamingResponse from starlette.routing import Route @@ -28,6 +28,16 @@ def exc(request): raise Exception("Exc") +@app.route("/exc-stream") +def exc_stream(request): + return StreamingResponse(_generate_faulty_stream()) + + +def _generate_faulty_stream(): + yield b"Ok" + raise Exception("Faulty Stream") + + @app.route("/no-response") class NoResponse: def __init__(self, scope, receive, send): @@ -56,6 +66,10 @@ def test_custom_middleware(test_client_factory): response = client.get("/exc") assert str(ctx.value) == "Exc" + with pytest.raises(Exception) as ctx: + response = client.get("/exc-stream") + assert str(ctx.value) == "Faulty Stream" + with pytest.raises(RuntimeError): response = client.get("/no-response") @@ -158,3 +172,13 @@ async def dispatch(self, request, call_next): client = test_client_factory(app) response = client.get("/does_not_exist") assert response.text == "Custom" + + +def test_exception_on_mounted_apps(test_client_factory): + sub_app = Starlette(routes=[Route("/", exc)]) + app.mount("/sub", sub_app) + + client = test_client_factory(app) + with pytest.raises(Exception) as ctx: + client.get("/sub/") + assert str(ctx.value) == "Exc"