diff --git a/httpx/_multipart.py b/httpx/_multipart.py index 2c08776f4b..1d46d96a98 100644 --- a/httpx/_multipart.py +++ b/httpx/_multipart.py @@ -135,19 +135,18 @@ def __init__(self, name: str, value: FileTypes) -> None: self.file = fileobj self.headers = headers - def get_length(self) -> int: + def get_length(self) -> typing.Optional[int]: headers = self.render_headers() if isinstance(self.file, (str, bytes)): return len(headers) + len(to_bytes(self.file)) - # Let's do our best not to read `file` into memory. file_length = peek_filelike_length(self.file) + + # If we can't determine the filesize without reading it into memory, + # then return `None` here, to indicate an unknown file length. if file_length is None: - # As a last resort, read file and cache contents for later. - assert not hasattr(self, "_data") - self._data = to_bytes(self.file.read()) - file_length = len(self._data) + return None return len(headers) + file_length @@ -173,13 +172,11 @@ def render_data(self) -> typing.Iterator[bytes]: yield to_bytes(self.file) return - if hasattr(self, "_data"): - # Already rendered. - yield self._data - return - if hasattr(self.file, "seek"): - self.file.seek(0) + try: + self.file.seek(0) + except io.UnsupportedOperation: + pass chunk = self.file.read(self.CHUNK_SIZE) while chunk: @@ -232,24 +229,34 @@ def iter_chunks(self) -> typing.Iterator[bytes]: yield b"\r\n" yield b"--%s--\r\n" % self.boundary - def iter_chunks_lengths(self) -> typing.Iterator[int]: + def get_content_length(self) -> typing.Optional[int]: + """ + Return the length of the multipart encoded content, or `None` if + any of the files have a length that cannot be determined upfront. + """ boundary_length = len(self.boundary) - # Follow closely what `.iter_chunks()` does. + length = 0 + for field in self.fields: - yield 2 + boundary_length + 2 - yield field.get_length() - yield 2 - yield 2 + boundary_length + 4 + field_length = field.get_length() + if field_length is None: + return None + + length += 2 + boundary_length + 2 # b"--{boundary}\r\n" + length += field_length + length += 2 # b"\r\n" - def get_content_length(self) -> int: - return sum(self.iter_chunks_lengths()) + length += 2 + boundary_length + 4 # b"--{boundary}--\r\n" + return length # Content stream interface. def get_headers(self) -> typing.Dict[str, str]: - content_length = str(self.get_content_length()) + content_length = self.get_content_length() content_type = self.content_type - return {"Content-Length": content_length, "Content-Type": content_type} + if content_length is None: + return {"Transfer-Encoding": "chunked", "Content-Type": content_type} + return {"Content-Length": str(content_length), "Content-Type": content_type} def __iter__(self) -> typing.Iterator[bytes]: for chunk in self.iter_chunks(): diff --git a/tests/test_multipart.py b/tests/test_multipart.py index e9ce928a16..6d281ed7d0 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -380,8 +380,9 @@ def test_multipart_encode_files_raises_exception_with_text_mode_file() -> None: def test_multipart_encode_non_seekable_filelike() -> None: """ - Test that special readable but non-seekable filelike objects are supported, - at the cost of reading them into memory at most once. + Test that special readable but non-seekable filelike objects are supported. + In this case uploads with use 'Transfer-Encoding: chunked', instead of + a 'Content-Length' header. """ class IteratorIO(io.IOBase): @@ -410,7 +411,7 @@ def data() -> typing.Iterator[bytes]: ) assert headers == { "Content-Type": "multipart/form-data; boundary=+++", - "Content-Length": str(len(content)), + "Transfer-Encoding": "chunked", } assert content == b"".join(stream)