Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix 793: allow custom async iterator #1041

Merged
merged 15 commits into from
Oct 21, 2020
3 changes: 1 addition & 2 deletions starlette/responses.py
@@ -1,6 +1,5 @@
import hashlib
import http.cookies
import inspect
import json
import os
import stat
Expand Down Expand Up @@ -204,7 +203,7 @@ def __init__(
media_type: str = None,
background: BackgroundTask = None,
) -> None:
if inspect.isasyncgen(content):
if isinstance(content, typing.AsyncIterable):
self.body_iterator = content
else:
self.body_iterator = iterate_in_threadpool(content)
Expand Down
36 changes: 36 additions & 0 deletions tests/test_responses.py
Expand Up @@ -269,3 +269,39 @@ def test_head_method():
client = TestClient(app)
response = client.head("/")
assert response.text == ""

def test_sync_custom_streaming_response():
async def app(scope, receive, send):
class CustomAsyncGenerator:
witling marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self):
self._called = 0

def __aiter__(self):
return self

async def __anext__(self):
if self._called == 5:
raise StopAsyncIteration()
self._called += 1
return str(self._called)

response = StreamingResponse(CustomAsyncGenerator(), media_type="text/plain")
await response(scope, receive, send)

client = TestClient(app)
response = client.get("/")
assert response.text == "12345"

def test_sync_custom_streaming_response_no_anext():
async def app(scope, receive, send):
class CustomAsyncGenerator:
async def __aiter__(self):
witling marked this conversation as resolved.
Show resolved Hide resolved
for i in range(5):
yield str(i + 1)

response = StreamingResponse(CustomAsyncGenerator(), media_type="text/plain")
await response(scope, receive, send)

client = TestClient(app)
response = client.get("/")
assert response.text == "12345"
witling marked this conversation as resolved.
Show resolved Hide resolved