Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streaming with BaseHTTPMiddleware: force background to run after response completes #1017

Closed
wants to merge 11 commits into from
20 changes: 15 additions & 5 deletions starlette/middleware/base.py
Expand Up @@ -47,13 +47,23 @@ async def coro() -> None:
assert message["type"] == "http.response.start"

async def body_stream() -> typing.AsyncGenerator[bytes, None]:
while True:
message = await queue.get()
if message is None:
break
# In non-streaming responses, there should be one message to emit
message = await queue.get()
if message is not None:
assert message["type"] == "http.response.body"
yield message.get("body", b"")
task.result()

while message.get("more_body"):
florimondmanca marked this conversation as resolved.
Show resolved Hide resolved
message = await queue.get()
if message is None:
break
assert message["type"] == "http.response.body"
yield message.get("body", b"")

if task.done():
# Check for exceptions and raise if present.
# Incomplete tasks may still have background tasks to run.
task.result()

response = StreamingResponse(
status_code=message["status"], content=body_stream()
Expand Down
48 changes: 47 additions & 1 deletion tests/middleware/test_base.py
Expand Up @@ -3,7 +3,7 @@
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse
from starlette.responses import PlainTextResponse, StreamingResponse
from starlette.routing import Route
from starlette.testclient import TestClient

Expand Down Expand Up @@ -143,3 +143,49 @@ def homepage(request):
def test_middleware_repr():
middleware = Middleware(CustomMiddleware)
assert repr(middleware) == "Middleware(CustomMiddleware)"


async def numbers_stream(minimum, maximum):
erewok marked this conversation as resolved.
Show resolved Hide resolved
yield ("<html><body><ul>")
for number in range(minimum, maximum + 1):
yield "<li>%d</li>" % number
yield ("</ul></body></html>")


async def somthing_broken(minimum, maximum, error_at=2):
if error_at <= 0:
raise RuntimeError("This is a stream that breaks when it starts")
yield ("<html><body><ul>")
for number in range(minimum, maximum + 1):
yield "<li>%d</li>" % number
if number >= error_at:
raise RuntimeError("This is a broken stream")


@app.route("/streaming")
async def some_streaming(_):
return StreamingResponse(numbers_stream(1, 3))


@app.route("/broken-streaming/{error_at:int}")
async def some_broken_streaming(request):
error_at = request.path_params["error_at"]
return StreamingResponse(somthing_broken(1, 5, error_at=error_at))


def test_custom_middleware_streaming():
client = TestClient(app)
response = client.get("/streaming")
assert response.headers["Custom-Header"] == "Example"
assert (
response.text
== "<html><body><ul><li>1</li><li>2</li><li>3</li></ul></body></html>"
)

with pytest.raises(RuntimeError):
# after body streaming has started
response = client.get("/broken-streaming/2")
with pytest.raises(RuntimeError):
# right before body stream starts (only start message emitted)
# this should trigger _first_ message being None
response = client.get("/broken-streaming/0")