Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix 793: allow custom async iterator #1041

Merged
merged 15 commits into from
Oct 21, 2020
3 changes: 1 addition & 2 deletions starlette/responses.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import hashlib
import http.cookies
import inspect
import json
import os
import stat
Expand Down Expand Up @@ -204,7 +203,7 @@ def __init__(
media_type: str = None,
background: BackgroundTask = None,
) -> None:
if inspect.isasyncgen(content):
if isinstance(content, typing.AsyncIterable):
self.body_iterator = content
else:
self.body_iterator = iterate_in_threadpool(content)
Expand Down
38 changes: 38 additions & 0 deletions tests/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,44 @@ async def numbers_for_cleanup(start=1, stop=5):
assert filled_by_bg_task == "6, 7, 8, 9"


def test_streaming_response_custom_iterator():
async def app(scope, receive, send):
class CustomAsyncIterator:
def __init__(self):
self._called = 0

def __aiter__(self):
return self

async def __anext__(self):
if self._called == 5:
raise StopAsyncIteration()
self._called += 1
return str(self._called)

response = StreamingResponse(CustomAsyncIterator(), media_type="text/plain")
await response(scope, receive, send)

client = TestClient(app)
response = client.get("/")
assert response.text == "12345"


def test_streaming_response_custom_iterable():
async def app(scope, receive, send):
class CustomAsyncIterable:
async def __aiter__(self):
for i in range(5):
yield str(i + 1)

response = StreamingResponse(CustomAsyncIterable(), media_type="text/plain")
await response(scope, receive, send)

client = TestClient(app)
response = client.get("/")
assert response.text == "12345"


def test_sync_streaming_response():
async def app(scope, receive, send):
def numbers(minimum, maximum):
Expand Down