diff --git a/httpx/_multipart.py b/httpx/_multipart.py index 51ba556a77..34ee631557 100644 --- a/httpx/_multipart.py +++ b/httpx/_multipart.py @@ -78,23 +78,41 @@ def __init__(self, name: str, value: FileTypes) -> None: fileobj: FileContent + headers: typing.Dict[str, str] = {} + content_type: typing.Optional[str] = None + + # This large tuple based API largely mirror's requests' API + # It would be good to think of better APIs for this that we could include in httpx 2.0 + # since variable length tuples (especially of 4 elements) are quite unwieldly if isinstance(value, tuple): - try: - filename, fileobj, content_type = value # type: ignore - except ValueError: + if len(value) == 2: + # neither the 3rd parameter (content_type) nor the 4th (headers) was included filename, fileobj = value # type: ignore - content_type = guess_content_type(filename) + elif len(value) == 3: + filename, fileobj, content_type = value # type: ignore + else: + # all 4 parameters included + filename, fileobj, content_type, headers = value # type: ignore else: filename = Path(str(getattr(value, "name", "upload"))).name fileobj = value + + if content_type is None: content_type = guess_content_type(filename) + has_content_type_header = any("content-type" in key.lower() for key in headers) + if content_type is not None and not has_content_type_header: + # note that unlike requests, we ignore the content_type + # provided in the 3rd tuple element if it is also included in the headers + # requests does the opposite (it overwrites the header with the 3rd tuple element) + headers["Content-Type"] = content_type + if isinstance(fileobj, (str, io.StringIO)): raise TypeError(f"Expected bytes or bytes-like object got: {type(fileobj)}") self.filename = filename self.file = fileobj - self.content_type = content_type + self.headers = headers self._consumed = False def get_length(self) -> int: @@ -122,9 +140,9 @@ def render_headers(self) -> bytes: if self.filename: filename = format_form_param("filename", self.filename) parts.extend([b"; ", filename]) - if self.content_type is not None: - content_type = self.content_type.encode() - parts.extend([b"\r\nContent-Type: ", content_type]) + for header_name, header_value in self.headers.items(): + key, val = f"\r\n{header_name}: ".encode(), header_value.encode() + parts.extend([key, val]) parts.append(b"\r\n\r\n") self._headers = b"".join(parts) diff --git a/httpx/_types.py b/httpx/_types.py index 8cd85cd933..f7ba4486cc 100644 --- a/httpx/_types.py +++ b/httpx/_types.py @@ -89,6 +89,8 @@ Tuple[Optional[str], FileContent], # (filename, file (or bytes), content_type) Tuple[Optional[str], FileContent, Optional[str]], + # (filename, file (or bytes), content_type, headers) + Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]], ] RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]] diff --git a/tests/test_multipart.py b/tests/test_multipart.py index cd71a246b3..9980cb5b4e 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -94,6 +94,58 @@ def test_multipart_file_tuple(): assert multipart["file"] == [b""] +@pytest.mark.parametrize("content_type", [None, "text/plain"]) +def test_multipart_file_tuple_headers(content_type: typing.Optional[str]): + file_name = "test.txt" + expected_content_type = "text/plain" + headers = {"Expires": "0"} + + files = {"file": (file_name, io.BytesIO(b""), content_type, headers)} + with mock.patch("os.urandom", return_value=os.urandom(16)): + boundary = os.urandom(16).hex() + + headers, stream = encode_request(data={}, files=files) + assert isinstance(stream, typing.Iterable) + + content = ( + f'--{boundary}\r\nContent-Disposition: form-data; name="file"; ' + f'filename="{file_name}"\r\nExpires: 0\r\nContent-Type: ' + f"{expected_content_type}\r\n\r\n\r\n--{boundary}--\r\n" + "".encode("ascii") + ) + assert headers == { + "Content-Type": f"multipart/form-data; boundary={boundary}", + "Content-Length": str(len(content)), + } + assert content == b"".join(stream) + + +def test_multipart_headers_include_content_type() -> None: + """Content-Type from 4th tuple parameter (headers) should override the 3rd parameter (content_type)""" + file_name = "test.txt" + expected_content_type = "image/png" + headers = {"Content-Type": "image/png"} + + files = {"file": (file_name, io.BytesIO(b""), "text_plain", headers)} + with mock.patch("os.urandom", return_value=os.urandom(16)): + boundary = os.urandom(16).hex() + + headers, stream = encode_request(data={}, files=files) + assert isinstance(stream, typing.Iterable) + + content = ( + f'--{boundary}\r\nContent-Disposition: form-data; name="file"; ' + f'filename="{file_name}"\r\nContent-Type: ' + f"{expected_content_type}\r\n\r\n\r\n--{boundary}--\r\n" + "".encode("ascii") + ) + assert headers == { + "Content-Type": f"multipart/form-data; boundary={boundary}", + "Content-Length": str(len(content)), + } + assert content == b"".join(stream) + + def test_multipart_encode(tmp_path: typing.Any) -> None: path = str(tmp_path / "name.txt") with open(path, "wb") as f: