From 0088253b32c7e6541ce85f40429f133c01ca35dc Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 4 Feb 2022 14:48:57 +0000 Subject: [PATCH] Always rewind files on multipart uploads. (#2065) * Test for multipart POST same file twice. * Always rewind files on multipart uploads * Linting --- httpx/_multipart.py | 4 +--- tests/test_multipart.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/httpx/_multipart.py b/httpx/_multipart.py index 34ee631557..3f981c85d0 100644 --- a/httpx/_multipart.py +++ b/httpx/_multipart.py @@ -113,7 +113,6 @@ def __init__(self, name: str, value: FileTypes) -> None: self.filename = filename self.file = fileobj self.headers = headers - self._consumed = False def get_length(self) -> int: headers = self.render_headers() @@ -158,9 +157,8 @@ def render_data(self) -> typing.Iterator[bytes]: yield self._data return - if self._consumed: # pragma: nocover + if hasattr(self.file, "seek"): self.file.seek(0) - self._consumed = True chunk = self.file.read(self.CHUNK_SIZE) while chunk: diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 9980cb5b4e..46ad0e01a1 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -1,6 +1,7 @@ import cgi import io import os +import tempfile import typing from unittest import mock @@ -339,6 +340,25 @@ def data() -> typing.Iterator[bytes]: assert content == b"".join(stream) +def test_multipart_rewinds_files(): + with tempfile.TemporaryFile() as upload: + upload.write(b"Hello, world!") + + transport = httpx.MockTransport(echo_request_content) + client = httpx.Client(transport=transport) + + files = {"file": upload} + response = client.post("http://127.0.0.1:8000/", files=files) + assert response.status_code == 200 + assert b"\r\nHello, world!\r\n" in response.content + + # POSTing the same file instance a second time should have the same content. + files = {"file": upload} + response = client.post("http://127.0.0.1:8000/", files=files) + assert response.status_code == 200 + assert b"\r\nHello, world!\r\n" in response.content + + class TestHeaderParamHTML5Formatting: def test_unicode(self): param = format_form_param("filename", "n\u00e4me")