Skip to content

Commit

Permalink
Read upload files using read(CHUNK_SIZE) rather than iter(). (#1948)
Browse files Browse the repository at this point in the history
* Cap upload chunk sizes

* Use '.read' for file streaming, where possible

* Direct iteration should not apply chunk sizes
  • Loading branch information
tomchristie committed Nov 22, 2021
1 parent e232226 commit 6f5865f
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 5 deletions.
28 changes: 24 additions & 4 deletions httpx/_content.py
Expand Up @@ -38,6 +38,8 @@ async def __aiter__(self) -> AsyncIterator[bytes]:


class IteratorByteStream(SyncByteStream):
CHUNK_SIZE = 65_536

def __init__(self, stream: Iterable[bytes]):
self._stream = stream
self._is_stream_consumed = False
Expand All @@ -48,11 +50,21 @@ def __iter__(self) -> Iterator[bytes]:
raise StreamConsumed()

self._is_stream_consumed = True
for part in self._stream:
yield part
if hasattr(self._stream, "read"):
# File-like interfaces should use 'read' directly.
chunk = self._stream.read(self.CHUNK_SIZE) # type: ignore
while chunk:
yield chunk
chunk = self._stream.read(self.CHUNK_SIZE) # type: ignore
else:
# Otherwise iterate.
for part in self._stream:
yield part


class AsyncIteratorByteStream(AsyncByteStream):
CHUNK_SIZE = 65_536

def __init__(self, stream: AsyncIterable[bytes]):
self._stream = stream
self._is_stream_consumed = False
Expand All @@ -63,8 +75,16 @@ async def __aiter__(self) -> AsyncIterator[bytes]:
raise StreamConsumed()

self._is_stream_consumed = True
async for part in self._stream:
yield part
if hasattr(self._stream, "aread"):
# File-like interfaces should use 'aread' directly.
chunk = await self._stream.aread(self.CHUNK_SIZE) # type: ignore
while chunk:
yield chunk
chunk = await self._stream.aread(self.CHUNK_SIZE) # type: ignore
else:
# Otherwise iterate.
async for part in self._stream:
yield part


class UnattachedStream(AsyncByteStream, SyncByteStream):
Expand Down
6 changes: 5 additions & 1 deletion httpx/_multipart.py
Expand Up @@ -71,6 +71,8 @@ class FileField:
A single file field item, within a multipart form field.
"""

CHUNK_SIZE = 64 * 1024

def __init__(self, name: str, value: FileTypes) -> None:
self.name = name

Expand Down Expand Up @@ -142,8 +144,10 @@ def render_data(self) -> typing.Iterator[bytes]:
self.file.seek(0)
self._consumed = True

for chunk in self.file:
chunk = self.file.read(self.CHUNK_SIZE)
while chunk:
yield to_bytes(chunk)
chunk = self.file.read(self.CHUNK_SIZE)

def render(self) -> typing.Iterator[bytes]:
yield self.render_headers()
Expand Down
25 changes: 25 additions & 0 deletions tests/test_content.py
Expand Up @@ -60,6 +60,31 @@ async def test_bytesio_content():
assert content == b"Hello, world!"


@pytest.mark.asyncio
async def test_async_bytesio_content():
    """An async file-like object (one exposing ``aread``) should be wrapped as
    an async-only stream and consumed via chunked reads, yielding the full
    payload back out."""

    class FakeAsyncFile:
        # Minimal async file-like stand-in: supplies 'aread' plus an
        # '__aiter__' that should never be used once 'aread' is detected.
        def __init__(self, payload):
            self._payload = payload
            self._offset = 0

        async def aread(self, chunk_size: int):
            start = self._offset
            self._offset = start + chunk_size
            return self._payload[start : self._offset]

        async def __aiter__(self):
            yield self._payload  # pragma: nocover

    headers, stream = encode_request(content=FakeAsyncFile(b"Hello, world!"))

    # The resulting stream must be async-only — not usable synchronously.
    assert isinstance(stream, typing.AsyncIterable)
    assert not isinstance(stream, typing.Iterable)

    chunks = [part async for part in stream]

    assert headers == {"Transfer-Encoding": "chunked"}
    assert b"".join(chunks) == b"Hello, world!"


@pytest.mark.asyncio
async def test_iterator_content():
def hello_world():
Expand Down

0 comments on commit 6f5865f

Please sign in to comment.