Skip to content

Commit

Permalink
Stream response body in ASGITransport
Browse files Browse the repository at this point in the history
  • Loading branch information
jhominal committed Jan 16, 2024
1 parent cf989ae commit dcc23a2
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 8 deletions.
86 changes: 78 additions & 8 deletions httpx/_transports/asgi.py
@@ -1,3 +1,4 @@
import types
import typing

import sniffio
Expand Down Expand Up @@ -33,12 +34,76 @@ def create_event() -> "Event":
return asyncio.Event()


class _AwaitableRunner:
def __init__(self, awaitable: typing.Awaitable[typing.Any]):
self._generator = awaitable.__await__()
self._started = False
self._next_item: typing.Any = None
self._finished = False

@types.coroutine
def __call__(
self, *, until: typing.Optional[typing.Callable[[], bool]] = None
) -> typing.Generator[typing.Any, typing.Any, typing.Any]:
while not self._finished and (until is None or not until()):
send_value, throw_value = None, None
if self._started:
try:
send_value = yield self._next_item
except BaseException as e:
throw_value = e

self._started = True
try:
if throw_value is not None:
self._next_item = self._generator.throw(throw_value)
else:
self._next_item = self._generator.send(send_value)
except StopIteration as e:
self._exception = None
self._finished = True
return e.value
except BaseException:
self._generator.close()
self._finished = True
raise


class ASGIResponseStream(AsyncByteStream):
def __init__(self, body: typing.List[bytes]) -> None:
def __init__(
self,
body: typing.List[bytes],
raise_app_exceptions: bool,
response_complete: "Event",
app_runner: _AwaitableRunner,
) -> None:
self._body = body
self._raise_app_exceptions = raise_app_exceptions
self._response_complete = response_complete
self._app_runner = app_runner

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
yield b"".join(self._body)
try:
while bool(self._body) or not self._response_complete.is_set():
if self._body:
yield b"".join(self._body)
self._body.clear()
await self._app_runner(
until=lambda: bool(self._body) or self._response_complete.is_set()
)
except Exception: # noqa: PIE786
if self._raise_app_exceptions:
raise
finally:
await self.aclose()

async def aclose(self) -> None:
self._response_complete.set()
try:
await self._app_runner()
except Exception: # noqa: PIE786
if self._raise_app_exceptions:
raise


class ASGITransport(AsyncBaseTransport):
Expand Down Expand Up @@ -145,8 +210,10 @@ async def send(message: _Message) -> None:
response_headers = message.get("headers", [])
response_started = True

elif message["type"] == "http.response.body":
assert not response_complete.is_set()
elif (
message["type"] == "http.response.body"
and not response_complete.is_set()
):
body = message.get("body", b"")
more_body = message.get("more_body", False)

Expand All @@ -156,9 +223,11 @@ async def send(message: _Message) -> None:
if not more_body:
response_complete.set()

app_runner = _AwaitableRunner(self.app(scope, receive, send))

try:
await self.app(scope, receive, send)
except Exception: # noqa: PIE-786
await app_runner(until=lambda: response_started)
except Exception: # noqa: PIE786
if self.raise_app_exceptions:
raise

Expand All @@ -168,10 +237,11 @@ async def send(message: _Message) -> None:
if response_headers is None:
response_headers = {}

assert response_complete.is_set()
assert status_code is not None
assert response_headers is not None

stream = ASGIResponseStream(body_parts)
stream = ASGIResponseStream(
body_parts, self.raise_app_exceptions, response_complete, app_runner
)

return Response(status_code, headers=response_headers, stream=stream)
79 changes: 79 additions & 0 deletions tests/test_asgi.py
@@ -1,5 +1,6 @@
import json

import anyio
import pytest

import httpx
Expand Down Expand Up @@ -60,13 +61,24 @@ async def raise_exc(scope, receive, send):
raise RuntimeError()


async def raise_exc_after_response_start(scope, receive, send):
status = 200
output = b"Hello, World!"
headers = [(b"content-type", "text/plain"), (b"content-length", str(len(output)))]

await send({"type": "http.response.start", "status": status, "headers": headers})
await anyio.sleep(0)
raise RuntimeError()


async def raise_exc_after_response(scope, receive, send):
status = 200
output = b"Hello, World!"
headers = [(b"content-type", "text/plain"), (b"content-length", str(len(output)))]

await send({"type": "http.response.start", "status": status, "headers": headers})
await send({"type": "http.response.body", "body": output})
await anyio.sleep(0)
raise RuntimeError()


Expand Down Expand Up @@ -165,6 +177,13 @@ async def test_asgi_exc():
await client.get("http://www.example.org/")


@pytest.mark.anyio
async def test_asgi_exc_after_response_start():
async with httpx.AsyncClient(app=raise_exc_after_response_start) as client:
with pytest.raises(RuntimeError):
await client.get("http://www.example.org/")


@pytest.mark.anyio
async def test_asgi_exc_after_response():
async with httpx.AsyncClient(app=raise_exc_after_response) as client:
Expand Down Expand Up @@ -213,3 +232,63 @@ async def test_asgi_exc_no_raise():
response = await client.get("http://www.example.org/")

assert response.status_code == 500


@pytest.mark.anyio
async def test_asgi_exc_no_raise_after_response_start():
transport = httpx.ASGITransport(
app=raise_exc_after_response_start, raise_app_exceptions=False
)
async with httpx.AsyncClient(transport=transport) as client:
response = await client.get("http://www.example.org/")

assert response.status_code == 200


@pytest.mark.anyio
async def test_asgi_exc_no_raise_after_response():
transport = httpx.ASGITransport(
app=raise_exc_after_response, raise_app_exceptions=False
)
async with httpx.AsyncClient(transport=transport) as client:
response = await client.get("http://www.example.org/")

assert response.status_code == 200


@pytest.mark.anyio
async def test_asgi_stream_returns_before_waiting_for_body():
start_response_body = anyio.Event()

async def send_response_body_after_event(scope, receive, send):
status = 200
headers = [(b"content-type", b"text/plain")]
await send(
{"type": "http.response.start", "status": status, "headers": headers}
)
await start_response_body.wait()
await send({"type": "http.response.body", "body": b"body", "more_body": False})

async with httpx.AsyncClient(app=send_response_body_after_event) as client:
async with client.stream("GET", "http://www.example.org/") as response:
assert response.status_code == 200
start_response_body.set()
await response.aread()
assert response.text == "body"


@pytest.mark.anyio
async def test_asgi_can_be_canceled():
# This test exists to cover transmission of the cancellation exception through
# _AwaitableRunner
app_started = anyio.Event()

async def never_return(scope, receive, send):
app_started.set()
await anyio.sleep_forever()

async with httpx.AsyncClient(app=never_return) as client:
async with anyio.create_task_group() as task_group:
task_group.start_soon(client.get, "http://www.example.org/")
await app_started.wait()
task_group.cancel_scope.cancel()

0 comments on commit dcc23a2

Please sign in to comment.