diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py
index d0089c0c9..e9a36e3b6 100644
--- a/starlette/middleware/base.py
+++ b/starlette/middleware/base.py
@@ -1,6 +1,7 @@
import asyncio
import typing
+from starlette.background import BackgroundTask
from starlette.requests import Request
from starlette.responses import Response, StreamingResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send
@@ -45,20 +46,41 @@ async def coro() -> None:
task.result()
raise RuntimeError("No response returned.")
assert message["type"] == "http.response.start"
+ status = message["status"]
+ headers = message["headers"]
+
+ first_body_message = await queue.get()
+ if first_body_message is None:
+ task.result()
+ raise RuntimeError("Empty response body returned")
+ assert first_body_message["type"] == "http.response.body"
+ response_body_start = first_body_message.get("body", b"")
async def body_stream() -> typing.AsyncGenerator[bytes, None]:
- while True:
+ # In non-streaming responses, there should be one message to emit
+ yield response_body_start
+ message = first_body_message
+ while message and message.get("more_body"):
message = await queue.get()
if message is None:
break
assert message["type"] == "http.response.body"
yield message.get("body", b"")
- task.result()
- response = StreamingResponse(
- status_code=message["status"], content=body_stream()
+ if task.done():
+ # Check for exceptions and raise if present.
+ # Incomplete tasks may still have background tasks to run.
+ task.result()
+
+ # Assume non-streaming and start with a regular response
+ response: typing.Union[Response, StreamingResponse] = Response(
+ status_code=status, content=response_body_start
)
- response.raw_headers = message["headers"]
+
+ if first_body_message.get("more_body"):
+ response = StreamingResponse(status_code=status, content=body_stream())
+
+ response.raw_headers = headers
return response
async def dispatch(
diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py
index 048dd9ffb..d488acf81 100644
--- a/tests/middleware/test_base.py
+++ b/tests/middleware/test_base.py
@@ -1,9 +1,13 @@
+import asyncio
+
+import aiofiles
import pytest
from starlette.applications import Starlette
+from starlette.background import BackgroundTask
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
-from starlette.responses import PlainTextResponse
+from starlette.responses import PlainTextResponse, StreamingResponse
from starlette.routing import Route
from starlette.testclient import TestClient
@@ -143,3 +147,158 @@ def homepage(request):
def test_middleware_repr():
middleware = Middleware(CustomMiddleware)
assert repr(middleware) == "Middleware(CustomMiddleware)"
+
+
+def test_custom_middleware_streaming(tmp_path):
+ """
+ Ensure that a StreamingResponse completes successfully with BaseHTTPMiddleware
+ """
+
+ @app.route("/streaming")
+ async def some_streaming(_):
+ async def numbers_stream():
+ """
+ Should produce something like:
+
+ """
+ yield ("")
+ for number in range(1, 4):
+ yield "- %d
" % number
+ yield ("
")
+
+ return StreamingResponse(numbers_stream())
+
+ client = TestClient(app)
+ response = client.get("/streaming")
+ assert response.headers["Custom-Header"] == "Example"
+ assert (
+ response.text
+ == ""
+ )
+
+
+def test_custom_middleware_streaming_exception_on_start():
+ """
+ Ensure that BaseHTTPMiddleware handles exceptions on response start
+ """
+
+ @app.route("/broken-streaming-on-start")
+ async def broken_stream_start(request):
+ async def broken():
+ raise ValueError("Oh no!")
+ yield 0 # pragma: no cover
+
+ return StreamingResponse(broken())
+
+ client = TestClient(app)
+ with pytest.raises(ValueError):
+ # right before body stream starts (only start message emitted)
+ # this should trigger _first_ message being None
+ response = client.get("/broken-streaming-on-start")
+
+
+def test_custom_middleware_streaming_exception_midstream():
+ """
+ Ensure that BaseHTTPMiddleware handles exceptions after streaming has started
+ """
+
+ @app.route("/broken-streaming-midstream")
+ async def broken_stream_midstream(request):
+ async def broken():
+ yield ("")
+ for number in range(1, 3):
+ yield "- %d
" % number
+ if number >= 2:
+ raise RuntimeError("This is a broken stream")
+
+ return StreamingResponse(broken())
+
+ client = TestClient(app)
+ with pytest.raises(RuntimeError):
+ # after body streaming has started
+ response = client.get("/broken-streaming-midstream")
+
+
+def test_custom_middleware_streaming_background(tmp_path):
+ """
+ Ensure that BaseHTTPMiddleware with a StreamingResponse runs BackgroundTasks after response.
+
+ This test writes to a temporary file
+ """
+
+ @app.route("/background-after-streaming")
+ async def background_after_streaming(request):
+ filepath = request.query_params["filepath"]
+
+ async def background():
+ await asyncio.sleep(1)
+ async with aiofiles.open(filepath, mode="w") as fl: # pragma: no cover
+ await fl.write("background last")
+
+ async def numbers_stream():
+ async with aiofiles.open(filepath, mode="w") as fl:
+ await fl.write("handler first")
+ for number in range(1, 4):
+ yield "%d\n" % number
+
+ return StreamingResponse(
+ numbers_stream(), background=BackgroundTask(background)
+ )
+
+ client = TestClient(app)
+
+ # Set up a file to track whether background has run
+ filepath = tmp_path / "background_test.txt"
+ filepath.write_text("Test Start")
+
+ response = client.get("/background-after-streaming?filepath={}".format(filepath))
+ assert response.headers["Custom-Header"] == "Example"
+ assert response.text == "1\n2\n3\n"
+ with filepath.open() as fl:
+ # background should not have run yet
+ assert fl.read() == "handler first"
+
+
+class Custom404Middleware(BaseHTTPMiddleware):
+ async def dispatch(self, request, call_next):
+ resp = await call_next(request)
+ if resp.status_code == 404:
+ return PlainTextResponse("Oh no!")
+ return resp
+
+
+def test_custom_middleware_pending_tasks(tmp_path):
+ """
+ Ensure that tasks are not pending left due to call_next method
+ """
+ app.add_middleware(Custom404Middleware)
+
+ @app.route("/trivial")
+ async def trivial(_):
+ return PlainTextResponse("Working")
+
+ @app.route("/streaming_task_count")
+ async def some_streaming(_):
+ async def numbers_stream():
+ for number in range(1, 4):
+ yield "%d\n" % number
+
+ return StreamingResponse(numbers_stream())
+
+ client = TestClient(app)
+ task_count = lambda: len(asyncio.Task.all_tasks())
+ # Task_count after issuing requests must not grow
+ assert task_count() == 1
+ response = client.get("/missing")
+ assert task_count() <= 2
+ response = client.get("/missing")
+ assert task_count() <= 2
+ response = client.get("/trivial")
+ assert task_count() <= 2
+ response = client.get("/streaming_task_count")
+ assert response.text == "1\n2\n3\n"
+ assert task_count() <= 2
+ response = client.get("/missing")
+ assert task_count() <= 2
+ response = client.get("/trivial")
+ assert response.text == "Working"