Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Read upload files using read(CHUNK_SIZE) rather than iter(). #1948

Merged
merged 4 commits into from Nov 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
28 changes: 24 additions & 4 deletions httpx/_content.py
Expand Up @@ -38,6 +38,8 @@ async def __aiter__(self) -> AsyncIterator[bytes]:


class IteratorByteStream(SyncByteStream):
CHUNK_SIZE = 65_536

def __init__(self, stream: Iterable[bytes]):
self._stream = stream
self._is_stream_consumed = False
Expand All @@ -48,11 +50,21 @@ def __iter__(self) -> Iterator[bytes]:
raise StreamConsumed()

self._is_stream_consumed = True
for part in self._stream:
yield part
if hasattr(self._stream, "read"):
# File-like interfaces should use 'read' directly.
chunk = self._stream.read(self.CHUNK_SIZE) # type: ignore
while chunk:
yield chunk
chunk = self._stream.read(self.CHUNK_SIZE) # type: ignore
else:
# Otherwise iterate.
for part in self._stream:
yield part


class AsyncIteratorByteStream(AsyncByteStream):
CHUNK_SIZE = 65_536

def __init__(self, stream: AsyncIterable[bytes]):
self._stream = stream
self._is_stream_consumed = False
Expand All @@ -63,8 +75,16 @@ async def __aiter__(self) -> AsyncIterator[bytes]:
raise StreamConsumed()

self._is_stream_consumed = True
async for part in self._stream:
yield part
if hasattr(self._stream, "aread"):
# File-like interfaces should use 'aread' directly.
chunk = await self._stream.aread(self.CHUNK_SIZE) # type: ignore
while chunk:
yield chunk
chunk = await self._stream.aread(self.CHUNK_SIZE) # type: ignore
else:
# Otherwise iterate.
async for part in self._stream:
yield part


class UnattachedStream(AsyncByteStream, SyncByteStream):
Expand Down
6 changes: 5 additions & 1 deletion httpx/_multipart.py
Expand Up @@ -71,6 +71,8 @@ class FileField:
A single file field item, within a multipart form field.
"""

CHUNK_SIZE = 64 * 1024

def __init__(self, name: str, value: FileTypes) -> None:
self.name = name

Expand Down Expand Up @@ -142,8 +144,10 @@ def render_data(self) -> typing.Iterator[bytes]:
self.file.seek(0)
self._consumed = True

for chunk in self.file:
chunk = self.file.read(self.CHUNK_SIZE)
while chunk:
yield to_bytes(chunk)
chunk = self.file.read(self.CHUNK_SIZE)

def render(self) -> typing.Iterator[bytes]:
yield self.render_headers()
Expand Down
25 changes: 25 additions & 0 deletions tests/test_content.py
Expand Up @@ -60,6 +60,31 @@ async def test_bytesio_content():
assert content == b"Hello, world!"


@pytest.mark.asyncio
async def test_async_bytesio_content():
    class AsyncBytesIO:
        """Minimal async file-like object: exposes `aread(size)` over a bytes buffer."""

        def __init__(self, content):
            self._content = content
            self._offset = 0

        async def aread(self, chunk_size: int):
            # Slice out the next chunk and advance the cursor; returns b"" at EOF.
            start = self._offset
            self._offset += chunk_size
            return self._content[start : start + chunk_size]

        async def __aiter__(self):
            yield self._content  # pragma: nocover

    headers, stream = encode_request(content=AsyncBytesIO(b"Hello, world!"))
    # An object with `aread` should be encoded as an async-only byte stream.
    assert isinstance(stream, typing.AsyncIterable)
    assert not isinstance(stream, typing.Iterable)

    parts = [part async for part in stream]
    content = b"".join(parts)

    assert headers == {"Transfer-Encoding": "chunked"}
    assert content == b"Hello, world!"


@pytest.mark.asyncio
async def test_iterator_content():
def hello_world():
Expand Down