diff --git a/httpx/_client.py b/httpx/_client.py index cec0d63589..ce7b92cc78 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -900,7 +900,7 @@ def send( return response - except Exception as exc: + except BaseException as exc: response.close() raise exc @@ -932,7 +932,7 @@ def _send_handling_auth( request = next_request history.append(response) - except Exception as exc: + except BaseException as exc: response.close() raise exc finally: @@ -971,7 +971,7 @@ def _send_handling_redirects( response.next_request = request return response - except Exception as exc: + except BaseException as exc: response.close() raise exc @@ -1604,7 +1604,7 @@ async def send( return response - except Exception as exc: # pragma: no cover + except BaseException as exc: # pragma: no cover await response.aclose() raise exc @@ -1636,7 +1636,7 @@ async def _send_handling_auth( request = next_request history.append(response) - except Exception as exc: + except BaseException as exc: await response.aclose() raise exc finally: @@ -1676,7 +1676,7 @@ async def _send_handling_redirects( response.next_request = request return response - except Exception as exc: + except BaseException as exc: await response.aclose() raise exc diff --git a/tests/client/test_async_client.py b/tests/client/test_async_client.py index 219d612f79..da2387df42 100644 --- a/tests/client/test_async_client.py +++ b/tests/client/test_async_client.py @@ -324,6 +324,46 @@ async def hello_world(request): assert response.text == "Hello, world!" +@pytest.mark.usefixtures("async_environment") +async def test_cancellation_during_stream(): + """ + If any BaseException is raised during streaming the response, then the + stream should be closed. + + This includes: + + * `asyncio.CancelledError` (A subclass of BaseException from Python 3.8 onwards.) + * `trio.Cancelled` + * `KeyboardInterrupt` + * `SystemExit` + + See https://github.com/encode/httpx/issues/2139 + """ + stream_was_closed = False + + def response_with_cancel_during_stream(request): + class CancelledStream(httpx.AsyncByteStream): + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + yield b"Hello" + raise KeyboardInterrupt() + yield b", world" # pragma: nocover + + async def aclose(self) -> None: + nonlocal stream_was_closed + stream_was_closed = True + + return httpx.Response( + 200, headers={"Content-Length": "12"}, stream=CancelledStream() + ) + + transport = httpx.MockTransport(response_with_cancel_during_stream) + + async with httpx.AsyncClient(transport=transport) as client: + with pytest.raises(KeyboardInterrupt): + await client.get("https://www.example.com") + assert stream_was_closed + + @pytest.mark.usefixtures("async_environment") async def test_server_extensions(server): url = server.url