Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streaming with BaseHTTPMiddleware: force background to run after response completes #1017

Closed
wants to merge 11 commits into from
32 changes: 27 additions & 5 deletions starlette/middleware/base.py
@@ -1,6 +1,7 @@
import asyncio
import typing

from starlette.background import BackgroundTask
from starlette.requests import Request
from starlette.responses import Response, StreamingResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send
Expand Down Expand Up @@ -45,20 +46,41 @@ async def coro() -> None:
task.result()
raise RuntimeError("No response returned.")
assert message["type"] == "http.response.start"
status = message["status"]
headers = message["headers"]

first_body_message = await queue.get()
if first_body_message is None:
task.result()
raise RuntimeError("Empty response body returned")
assert first_body_message["type"] == "http.response.body"
response_body_start = first_body_message.get("body", b"")

async def body_stream() -> typing.AsyncGenerator[bytes, None]:
while True:
# In non-streaming responses, there should be one message to emit
yield response_body_start
message = first_body_message
while message and message.get("more_body"):
message = await queue.get()
if message is None:
break
assert message["type"] == "http.response.body"
yield message.get("body", b"")
task.result()

response = StreamingResponse(
status_code=message["status"], content=body_stream()
if task.done():
# Check for exceptions and raise if present.
# Incomplete tasks may still have background tasks to run.
task.result()

# Assume non-streaming and start with a regular response
response: typing.Union[Response, StreamingResponse] = Response(
status_code=status, content=response_body_start
)
response.raw_headers = message["headers"]

if first_body_message.get("more_body"):
response = StreamingResponse(status_code=status, content=body_stream())

response.raw_headers = headers
return response

async def dispatch(
Expand Down
161 changes: 160 additions & 1 deletion tests/middleware/test_base.py
@@ -1,9 +1,13 @@
import asyncio

import aiofiles
import pytest

from starlette.applications import Starlette
from starlette.background import BackgroundTask
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse
from starlette.responses import PlainTextResponse, StreamingResponse
from starlette.routing import Route
from starlette.testclient import TestClient

Expand Down Expand Up @@ -143,3 +147,158 @@ def homepage(request):
def test_middleware_repr():
middleware = Middleware(CustomMiddleware)
assert repr(middleware) == "Middleware(CustomMiddleware)"


def test_custom_middleware_streaming():
    """
    Ensure that a StreamingResponse completes successfully when the app is
    wrapped in BaseHTTPMiddleware (which re-emits the streamed body).

    The unused ``tmp_path`` fixture has been dropped: this test performs no
    filesystem work.
    """

    @app.route("/streaming")
    async def some_streaming(_):
        async def numbers_stream():
            # Produces:
            # <html><body><ul><li>1</li><li>2</li><li>3</li></ul></body></html>
            yield "<html><body><ul>"
            for number in range(1, 4):
                yield "<li>%d</li>" % number
            yield "</ul></body></html>"

        return StreamingResponse(numbers_stream())

    client = TestClient(app)
    response = client.get("/streaming")
    # Header added by the middleware registered on ``app``; proves the
    # request actually passed through BaseHTTPMiddleware.
    assert response.headers["Custom-Header"] == "Example"
    assert (
        response.text
        == "<html><body><ul><li>1</li><li>2</li><li>3</li></ul></body></html>"
    )


def test_custom_middleware_streaming_exception_on_start():
    """
    Ensure that BaseHTTPMiddleware propagates an exception raised before the
    first body message is emitted (the stream fails immediately on start).
    """

    @app.route("/broken-streaming-on-start")
    async def broken_stream_start(_):
        async def broken():
            raise ValueError("Oh no!")
            yield 0  # pragma: no cover

        return StreamingResponse(broken())

    client = TestClient(app)
    with pytest.raises(ValueError):
        # The stream raises right before body streaming starts (only the
        # start message is emitted), so the middleware must surface the
        # error rather than return an empty response.
        client.get("/broken-streaming-on-start")


def test_custom_middleware_streaming_exception_midstream():
    """
    Ensure that BaseHTTPMiddleware propagates an exception raised after the
    response body has already started streaming.
    """

    @app.route("/broken-streaming-midstream")
    async def broken_stream_midstream(_):
        async def broken():
            yield "<html><body><ul>"
            for number in range(1, 3):
                yield "<li>%d</li>" % number
                if number >= 2:
                    raise RuntimeError("This is a broken stream")

        return StreamingResponse(broken())

    client = TestClient(app)
    with pytest.raises(RuntimeError):
        # Failure happens after body streaming has started.
        client.get("/broken-streaming-midstream")


def test_custom_middleware_streaming_background(tmp_path):
    """
    BaseHTTPMiddleware must defer a StreamingResponse's BackgroundTask until
    after the response body has been delivered.

    A temporary file records which side — handler or background — wrote last.
    """

    @app.route("/background-after-streaming")
    async def background_after_streaming(request):
        target = request.query_params["filepath"]

        async def background():
            # The sleep keeps the background write from landing before the
            # post-response assertion below runs.
            await asyncio.sleep(1)
            async with aiofiles.open(target, mode="w") as handle:  # pragma: no cover
                await handle.write("background last")

        async def numbers_stream():
            async with aiofiles.open(target, mode="w") as handle:
                await handle.write("handler first")
            for number in range(1, 4):
                yield "%d\n" % number

        return StreamingResponse(
            numbers_stream(), background=BackgroundTask(background)
        )

    client = TestClient(app)

    # Seed the marker file so we can tell whether anything overwrote it.
    marker = tmp_path / "background_test.txt"
    marker.write_text("Test Start")

    response = client.get("/background-after-streaming?filepath={}".format(marker))
    assert response.headers["Custom-Header"] == "Example"
    assert response.text == "1\n2\n3\n"
    with marker.open() as handle:
        # Only the handler's write is visible here: the background task has
        # not run yet.
        assert handle.read() == "handler first"


class Custom404Middleware(BaseHTTPMiddleware):
    """Middleware that replaces any 404 response with a plain-text body."""

    async def dispatch(self, request, call_next):
        response = await call_next(request)
        # Pass every non-404 response through untouched.
        if response.status_code != 404:
            return response
        return PlainTextResponse("Oh no!")


def test_custom_middleware_pending_tasks():
    """
    Ensure that call_next does not leave pending asyncio tasks behind after
    each request finishes.

    Fixes: ``asyncio.Task.all_tasks()`` was deprecated in Python 3.7 and
    removed in 3.9 — the module-level ``asyncio.all_tasks()`` is the
    supported replacement. The lambda assignment (PEP 8 / E731) and the
    unused ``tmp_path`` fixture and ``response`` bindings are dropped too.
    """
    app.add_middleware(Custom404Middleware)

    @app.route("/trivial")
    async def trivial(_):
        return PlainTextResponse("Working")

    @app.route("/streaming_task_count")
    async def some_streaming(_):
        async def numbers_stream():
            for number in range(1, 4):
                yield "%d\n" % number

        return StreamingResponse(numbers_stream())

    client = TestClient(app)

    def task_count():
        # NOTE(review): all_tasks() needs a loop; no loop is running between
        # requests here, so pass get_event_loop()'s explicitly — confirm this
        # matches the TestClient's loop handling on the targeted Pythons.
        return len(asyncio.all_tasks(asyncio.get_event_loop()))

    # The number of live tasks must not grow as requests are issued.
    assert task_count() == 1
    client.get("/missing")
    assert task_count() <= 2
    client.get("/missing")
    assert task_count() <= 2
    client.get("/trivial")
    assert task_count() <= 2
    response = client.get("/streaming_task_count")
    assert response.text == "1\n2\n3\n"
    assert task_count() <= 2
    client.get("/missing")
    assert task_count() <= 2
    response = client.get("/trivial")
    assert response.text == "Working"