diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index d0089c0c9..e9a36e3b6 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -1,6 +1,7 @@ import asyncio import typing +from starlette.background import BackgroundTask from starlette.requests import Request from starlette.responses import Response, StreamingResponse from starlette.types import ASGIApp, Message, Receive, Scope, Send @@ -45,20 +46,41 @@ async def coro() -> None: task.result() raise RuntimeError("No response returned.") assert message["type"] == "http.response.start" + status = message["status"] + headers = message["headers"] + + first_body_message = await queue.get() + if first_body_message is None: + task.result() + raise RuntimeError("Empty response body returned") + assert first_body_message["type"] == "http.response.body" + response_body_start = first_body_message.get("body", b"") async def body_stream() -> typing.AsyncGenerator[bytes, None]: - while True: + # In non-streaming responses, there should be one message to emit + yield response_body_start + message = first_body_message + while message and message.get("more_body"): message = await queue.get() if message is None: break assert message["type"] == "http.response.body" yield message.get("body", b"") - task.result() - response = StreamingResponse( - status_code=message["status"], content=body_stream() + if task.done(): + # Check for exceptions and raise if present. + # Incomplete tasks may still have background tasks to run. + task.result() + + # Assume non-streaming and start with a regular response + response: typing.Union[Response, StreamingResponse] = Response( + status_code=status, content=response_body_start ) - response.raw_headers = message["headers"] + + if first_body_message.get("more_body"): + response = StreamingResponse(status_code=status, content=body_stream()) + + response.raw_headers = headers return response async def dispatch( diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 048dd9ffb..d488acf81 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1,9 +1,13 @@ +import asyncio + +import aiofiles import pytest from starlette.applications import Starlette +from starlette.background import BackgroundTask from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware -from starlette.responses import PlainTextResponse +from starlette.responses import PlainTextResponse, StreamingResponse from starlette.routing import Route from starlette.testclient import TestClient @@ -143,3 +147,158 @@ def homepage(request): def test_middleware_repr(): middleware = Middleware(CustomMiddleware) assert repr(middleware) == "Middleware(CustomMiddleware)" + + +def test_custom_middleware_streaming(tmp_path): + """ + Ensure that a StreamingResponse completes successfully with BaseHTTPMiddleware + """ + + @app.route("/streaming") + async def some_streaming(_): + async def numbers_stream(): + """ + Should produce something like: + + """ + yield ("") + + return StreamingResponse(numbers_stream()) + + client = TestClient(app) + response = client.get("/streaming") + assert response.headers["Custom-Header"] == "Example" + assert ( + response.text + == "" + ) + + +def test_custom_middleware_streaming_exception_on_start(): + """ + Ensure that BaseHTTPMiddleware handles exceptions on response start + """ + + @app.route("/broken-streaming-on-start") + async def broken_stream_start(request): + async def broken(): + raise ValueError("Oh no!") + yield 0 # pragma: no cover + + return StreamingResponse(broken()) + + client = TestClient(app) + with pytest.raises(ValueError): + # right before body stream starts (only start message emitted) + # this should trigger _first_ message being None + response = client.get("/broken-streaming-on-start") + + +def test_custom_middleware_streaming_exception_midstream(): + """ + Ensure that BaseHTTPMiddleware handles exceptions after streaming has started + """ + + @app.route("/broken-streaming-midstream") + async def broken_stream_midstream(request): + async def broken(): + yield ("