From 9d745e38be4c31e184fcc33b71b29db3915a182f Mon Sep 17 00:00:00 2001 From: Aaron Chong Date: Sun, 26 Jun 2022 00:28:24 +0800 Subject: [PATCH 1/4] Shield send http.response.start from cancellation `RuntimeError: No response returned.` is raised in BaseHTTPMiddleware if request is disconnected, due to `task_group.cancel_scope.cancel()` in StreamingResponse.__call__..wrap and cancellation check in `await checkpoint()` of MemoryObjectSendStream.send. Let's fix this behaviour change caused by anyio integration in 0.15.0. --- starlette/responses.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/starlette/responses.py b/starlette/responses.py index 98c8caf1b..4f2fa745b 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -240,13 +240,14 @@ async def listen_for_disconnect(self, receive: Receive) -> None: break async def stream_response(self, send: Send) -> None: - await send( - { - "type": "http.response.start", - "status": self.status_code, - "headers": self.raw_headers, - } - ) + with anyio.CancelScope(shield=True): + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) async for chunk in self.body_iterator: if not isinstance(chunk, bytes): chunk = chunk.encode(self.charset) From 260df6afaae988cab021c5fd7cc68a1175071c53 Mon Sep 17 00:00:00 2001 From: Aaron Chong Date: Mon, 27 Jun 2022 20:00:17 +0800 Subject: [PATCH 2/4] Test StreamingResponse send http.response.start --- tests/test_responses.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/test_responses.py b/tests/test_responses.py index a272559eb..b4ec0be0a 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -2,6 +2,7 @@ import anyio import pytest +from anyio.lowlevel import checkpoint from starlette import status from starlette.background import BackgroundTask @@ -14,6 +15,7 @@ StreamingResponse, ) from starlette.testclient import TestClient +from starlette.types import Message def test_text_response(test_client_factory): @@ -391,3 +393,26 @@ def test_streaming_response_known_size(test_client_factory): client: TestClient = test_client_factory(app) response = client.get("/") assert response.headers["content-length"] == "10" + + +@pytest.mark.anyio +async def test_streaming_response_disconnect_should_cancel_after_send_http_response_start_returns(): + """ + Test that StreamingResponse cancels after send "http.response.start" returns, even if there is a checkpoint in send. + """ + send_to_receiver = False + + async def receive() -> Message: + return {"type": "http.disconnect"} + + async def send(msg: Message) -> None: + nonlocal send_to_receiver + assert msg["type"] == "http.response.start" + await checkpoint() # await asyncio.sleep(0) + send_to_receiver = True + + scope = {"type": "http", "method": "GET", "path": "/"} + + response = StreamingResponse("") + await response(scope, receive, send) + assert send_to_receiver From d53845c6240083aef6cea4ad7d1de102fbbe735b Mon Sep 17 00:00:00 2001 From: Aaron Chong Date: Mon, 27 Jun 2022 21:28:28 +0800 Subject: [PATCH 3/4] Fix E501 line too long --- tests/test_responses.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_responses.py b/tests/test_responses.py index b4ec0be0a..6afa5a592 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -396,9 +396,10 @@ def test_streaming_response_known_size(test_client_factory): @pytest.mark.anyio -async def test_streaming_response_disconnect_should_cancel_after_send_http_response_start_returns(): +async def test_streaming_response_disconnect_should_cancel_after_send_start_returns(): """ - Test that StreamingResponse cancels after send "http.response.start" returns, even if there is a checkpoint in send. + Test that StreamingResponse cancels after send "http.response.start" returns, + even if there is a checkpoint in send. """ send_to_receiver = False From 11c011dd4601aaf8acd28dc1c286a49f9e7d0786 Mon Sep 17 00:00:00 2001 From: Aaron Chong Date: Mon, 27 Jun 2022 21:44:58 +0800 Subject: [PATCH 4/4] Fix arg-type --- tests/test_responses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_responses.py b/tests/test_responses.py index 6afa5a592..af57288eb 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -414,6 +414,6 @@ async def send(msg: Message) -> None: scope = {"type": "http", "method": "GET", "path": "/"} - response = StreamingResponse("") + response = StreamingResponse(iter("")) await response(scope, receive, send) assert send_to_receiver