Skip to content

Commit

Permalink
fix 793: allow custom async iterator (#1041)
Browse files Browse the repository at this point in the history
* fix 793

* custom async generator: implement pr notes

* custom async generator: cleanup dependencies

* update tests

* newline at end of tests

* fix linting

* Update tests/test_responses.py

Co-authored-by: Jamie Hewland <jhewland@gmail.com>

* Update tests/test_responses.py

Co-authored-by: Jamie Hewland <jhewland@gmail.com>

* fix naming for custom generator tests

* comply with pep 492

* Shift streaming tests to be in one place

Co-authored-by: witling <noreply@my.email>
Co-authored-by: Jamie Hewland <jhewland@gmail.com>
  • Loading branch information
3 people committed Oct 21, 2020
1 parent c300bdc commit a9f8821
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
3 changes: 1 addition & 2 deletions starlette/responses.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import hashlib
import http.cookies
import inspect
import json
import os
import stat
Expand Down Expand Up @@ -204,7 +203,7 @@ def __init__(
media_type: str = None,
background: BackgroundTask = None,
) -> None:
if inspect.isasyncgen(content):
if isinstance(content, typing.AsyncIterable):
self.body_iterator = content
else:
self.body_iterator = iterate_in_threadpool(content)
Expand Down
38 changes: 38 additions & 0 deletions tests/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,44 @@ async def numbers_for_cleanup(start=1, stop=5):
assert filled_by_bg_task == "6, 7, 8, 9"


def test_streaming_response_custom_iterator():
async def app(scope, receive, send):
class CustomAsyncIterator:
def __init__(self):
self._called = 0

def __aiter__(self):
return self

async def __anext__(self):
if self._called == 5:
raise StopAsyncIteration()
self._called += 1
return str(self._called)

response = StreamingResponse(CustomAsyncIterator(), media_type="text/plain")
await response(scope, receive, send)

client = TestClient(app)
response = client.get("/")
assert response.text == "12345"


def test_streaming_response_custom_iterable():
async def app(scope, receive, send):
class CustomAsyncIterable:
async def __aiter__(self):
for i in range(5):
yield str(i + 1)

response = StreamingResponse(CustomAsyncIterable(), media_type="text/plain")
await response(scope, receive, send)

client = TestClient(app)
response = client.get("/")
assert response.text == "12345"


def test_sync_streaming_response():
async def app(scope, receive, send):
def numbers(minimum, maximum):
Expand Down

0 comments on commit a9f8821

Please sign in to comment.