diff --git a/starlette/responses.py b/starlette/responses.py index d6a9462b8..c526de335 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -248,13 +248,14 @@ async def listen_for_disconnect(self, receive: Receive) -> None: break async def stream_response(self, send: Send) -> None: - await send( - { - "type": "http.response.start", - "status": self.status_code, - "headers": self.raw_headers, - } - ) + with anyio.CancelScope(shield=True): + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) async for chunk in self.body_iterator: if not isinstance(chunk, bytes): chunk = chunk.encode(self.charset) diff --git a/tests/test_responses.py b/tests/test_responses.py index a272559eb..af57288eb 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -2,6 +2,7 @@ import anyio import pytest +from anyio.lowlevel import checkpoint from starlette import status from starlette.background import BackgroundTask @@ -14,6 +15,7 @@ StreamingResponse, ) from starlette.testclient import TestClient +from starlette.types import Message def test_text_response(test_client_factory): @@ -391,3 +393,27 @@ def test_streaming_response_known_size(test_client_factory): client: TestClient = test_client_factory(app) response = client.get("/") assert response.headers["content-length"] == "10" + + +@pytest.mark.anyio +async def test_streaming_response_disconnect_should_cancel_after_send_start_returns(): + """ + Test that StreamingResponse cancels after send "http.response.start" returns, + even if there is a checkpoint in send. + """ + send_to_receiver = False + + async def receive() -> Message: + return {"type": "http.disconnect"} + + async def send(msg: Message) -> None: + nonlocal send_to_receiver + assert msg["type"] == "http.response.start" + await checkpoint() # await asyncio.sleep(0) + send_to_receiver = True + + scope = {"type": "http", "method": "GET", "path": "/"} + + response = StreamingResponse(iter("")) + await response(scope, receive, send) + assert send_to_receiver