diff --git a/httpx/_client.py b/httpx/_client.py index c57cfb6ea9..3ba5370f68 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -901,7 +901,7 @@ def send( return response - except Exception as exc: + except BaseException as exc: response.close() raise exc @@ -933,7 +933,7 @@ def _send_handling_auth( request = next_request history.append(response) - except Exception as exc: + except BaseException as exc: response.close() raise exc finally: @@ -972,7 +972,7 @@ def _send_handling_redirects( response.next_request = request return response - except Exception as exc: + except BaseException as exc: response.close() raise exc @@ -1605,7 +1605,7 @@ async def send( return response - except Exception as exc: # pragma: no cover + except BaseException as exc: # pragma: no cover await response.aclose() raise exc @@ -1637,7 +1637,7 @@ async def _send_handling_auth( request = next_request history.append(response) - except Exception as exc: + except BaseException as exc: await response.aclose() raise exc finally: @@ -1677,7 +1677,7 @@ async def _send_handling_redirects( response.next_request = request return response - except Exception as exc: + except BaseException as exc: await response.aclose() raise exc diff --git a/tests/client/test_async_client.py b/tests/client/test_async_client.py index 219d612f79..a0ceea22a3 100644 --- a/tests/client/test_async_client.py +++ b/tests/client/test_async_client.py @@ -1,3 +1,4 @@ +import asyncio import typing from datetime import timedelta @@ -331,3 +332,16 @@ async def test_server_extensions(server): response = await client.get(url) assert response.status_code == 200 assert response.extensions["http_version"] == b"HTTP/1.1" + + +@pytest.mark.asyncio +async def test_cancelled_response(server): + async with httpx.AsyncClient() as client: + url = server.url.join("/drip?delay=0&duration=0.1") + response = await asyncio.wait_for(client.get(url), 0.2) + assert response.status_code == 200 + assert response.content == b"*" + async with httpx.AsyncClient() as client: + url = server.url.join("/drip?delay=0&duration=0.5") + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(client.get(url), 0.2) diff --git a/tests/conftest.py b/tests/conftest.py index 970c353547..877afe9159 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ import threading import time import typing +import urllib.parse import pytest import trustme @@ -76,6 +77,8 @@ async def app(scope, receive, send): assert scope["type"] == "http" if scope["path"].startswith("/slow_response"): await slow_response(scope, receive, send) + if scope["path"].startswith("/drip"): + await drip_response(scope, receive, send) elif scope["path"].startswith("/status"): await status_code(scope, receive, send) elif scope["path"].startswith("/echo_body"): @@ -126,6 +129,29 @@ async def slow_response(scope, receive, send): await send({"type": "http.response.body", "body": b"Hello, world!"}) +async def drip_response(scope, receive, send): + """ + Drips data over a duration after an optional initial delay. + eg: https://httpbin.org/drip?delay=0&duration=1 + """ + qs = urllib.parse.parse_qs(scope["query_string"].decode()) + delay = float(qs.get("delay", ["0"])[0]) + duration = float(qs.get("duration", ["1"])[0]) + await sleep(delay) + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"text/plain"]], + } + ) + drip = {"type": "http.response.body", "body": b"*", "more_body": True} + for _ in range(int(duration * 10)): + await send(drip) + await sleep(0.1) + await send({"type": "http.response.body", "body": b""}) + + async def status_code(scope, receive, send): status_code = int(scope["path"].replace("/status/", "")) await send(