diff --git a/httpx/_content.py b/httpx/_content.py index 24a967d506..eb7a7aef17 100644 --- a/httpx/_content.py +++ b/httpx/_content.py @@ -150,7 +150,7 @@ def encode_urlencoded_data( def encode_multipart_data( - data: dict, files: RequestFiles, boundary: Optional[bytes] = None + data: dict, files: RequestFiles, boundary: Optional[bytes] ) -> Tuple[Dict[str, str], MultipartStream]: multipart = MultipartStream(data=data, files=files, boundary=boundary) headers = multipart.get_headers() diff --git a/httpx/_models.py b/httpx/_models.py index 8879532c81..fd1d7fe9a1 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -27,6 +27,7 @@ StreamConsumed, request_context, ) +from ._multipart import get_multipart_boundary_from_content_type from ._status_codes import codes from ._types import ( AsyncByteStream, @@ -332,7 +333,18 @@ def __init__( Cookies(cookies).set_cookie_header(self) if stream is None: - headers, stream = encode_request(content, data, files, json) + content_type: typing.Optional[str] = self.headers.get("content-type") + headers, stream = encode_request( + content=content, + data=data, + files=files, + json=json, + boundary=get_multipart_boundary_from_content_type( + content_type=content_type.encode(self.headers.encoding) + if content_type + else None + ), + ) self._prepare(headers) self.stream = stream # Load the request body, except for streaming content. diff --git a/httpx/_multipart.py b/httpx/_multipart.py index d42f5cb31b..8bd7a17c9b 100644 --- a/httpx/_multipart.py +++ b/httpx/_multipart.py @@ -20,6 +20,20 @@ ) +def get_multipart_boundary_from_content_type( + content_type: typing.Optional[bytes], +) -> typing.Optional[bytes]: + if not content_type or not content_type.startswith(b"multipart/form-data"): + return None + # parse boundary according to + # https://www.rfc-editor.org/rfc/rfc2046#section-5.1.1 + if b";" in content_type: + for section in content_type.split(b";"): + if section.strip().lower().startswith(b"boundary="): + return section.strip()[len(b"boundary=") :].strip(b'"') + return None + + class DataField: """ A single form field item, within a multipart form field. diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 46ad0e01a1..dc93d26505 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -42,6 +42,58 @@ def test_multipart(value, output): assert multipart["file"] == [b""] +@pytest.mark.parametrize( + "header", + [ + "multipart/form-data; boundary=+++; charset=utf-8", + "multipart/form-data; charset=utf-8; boundary=+++", + "multipart/form-data; boundary=+++", + "multipart/form-data; boundary=+++ ;", + 'multipart/form-data; boundary="+++"; charset=utf-8', + 'multipart/form-data; charset=utf-8; boundary="+++"', + 'multipart/form-data; boundary="+++"', + 'multipart/form-data; boundary="+++" ;', + ], +) +def test_multipart_explicit_boundary(header: str) -> None: + client = httpx.Client(transport=httpx.MockTransport(echo_request_content)) + + files = {"file": io.BytesIO(b"")} + headers = {"content-type": header} + response = client.post("http://127.0.0.1:8000/", files=files, headers=headers) + assert response.status_code == 200 + + # We're using the cgi module to verify the behavior here, which is a + # bit grungy, but sufficient just for our testing purposes. + assert response.request.headers["Content-Type"] == header + content_length = response.request.headers["Content-Length"] + pdict: dict = { + "boundary": b"+++", + "CONTENT-LENGTH": content_length, + } + multipart = cgi.parse_multipart(io.BytesIO(response.content), pdict) + + assert multipart["file"] == [b""] + + +@pytest.mark.parametrize( + "header", + [ + "multipart/form-data; charset=utf-8", + "multipart/form-data; charset=utf-8; ", + ], +) +def test_multipart_header_without_boundary(header: str) -> None: + client = httpx.Client(transport=httpx.MockTransport(echo_request_content)) + + files = {"file": io.BytesIO(b"")} + headers = {"content-type": header} + response = client.post("http://127.0.0.1:8000/", files=files, headers=headers) + + assert response.status_code == 200 + assert response.request.headers["Content-Type"] == header + + @pytest.mark.parametrize(("key"), (b"abc", 1, 2.3, None)) def test_multipart_invalid_key(key): client = httpx.Client(transport=httpx.MockTransport(echo_request_content))