Skip to content

Commit

Permalink
allow setting an explicit multipart boundary via headers (#2278)
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Aug 15, 2022
1 parent 2434e65 commit 1526048
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 2 deletions.
2 changes: 1 addition & 1 deletion httpx/_content.py
Expand Up @@ -150,7 +150,7 @@ def encode_urlencoded_data(


def encode_multipart_data(
data: dict, files: RequestFiles, boundary: Optional[bytes] = None
data: dict, files: RequestFiles, boundary: Optional[bytes]
) -> Tuple[Dict[str, str], MultipartStream]:
multipart = MultipartStream(data=data, files=files, boundary=boundary)
headers = multipart.get_headers()
Expand Down
14 changes: 13 additions & 1 deletion httpx/_models.py
Expand Up @@ -27,6 +27,7 @@
StreamConsumed,
request_context,
)
from ._multipart import get_multipart_boundary_from_content_type
from ._status_codes import codes
from ._types import (
AsyncByteStream,
Expand Down Expand Up @@ -332,7 +333,18 @@ def __init__(
Cookies(cookies).set_cookie_header(self)

if stream is None:
headers, stream = encode_request(content, data, files, json)
content_type: typing.Optional[str] = self.headers.get("content-type")
headers, stream = encode_request(
content=content,
data=data,
files=files,
json=json,
boundary=get_multipart_boundary_from_content_type(
content_type=content_type.encode(self.headers.encoding)
if content_type
else None
),
)
self._prepare(headers)
self.stream = stream
# Load the request body, except for streaming content.
Expand Down
14 changes: 14 additions & 0 deletions httpx/_multipart.py
Expand Up @@ -20,6 +20,20 @@
)


def get_multipart_boundary_from_content_type(
content_type: typing.Optional[bytes],
) -> typing.Optional[bytes]:
if not content_type or not content_type.startswith(b"multipart/form-data"):
return None
# parse boundary according to
# https://www.rfc-editor.org/rfc/rfc2046#section-5.1.1
if b";" in content_type:
for section in content_type.split(b";"):
if section.strip().lower().startswith(b"boundary="):
return section.strip()[len(b"boundary=") :].strip(b'"')
return None


class DataField:
"""
A single form field item, within a multipart form field.
Expand Down
52 changes: 52 additions & 0 deletions tests/test_multipart.py
Expand Up @@ -42,6 +42,58 @@ def test_multipart(value, output):
assert multipart["file"] == [b"<file content>"]


@pytest.mark.parametrize(
"header",
[
"multipart/form-data; boundary=+++; charset=utf-8",
"multipart/form-data; charset=utf-8; boundary=+++",
"multipart/form-data; boundary=+++",
"multipart/form-data; boundary=+++ ;",
'multipart/form-data; boundary="+++"; charset=utf-8',
'multipart/form-data; charset=utf-8; boundary="+++"',
'multipart/form-data; boundary="+++"',
'multipart/form-data; boundary="+++" ;',
],
)
def test_multipart_explicit_boundary(header: str) -> None:
client = httpx.Client(transport=httpx.MockTransport(echo_request_content))

files = {"file": io.BytesIO(b"<file content>")}
headers = {"content-type": header}
response = client.post("http://127.0.0.1:8000/", files=files, headers=headers)
assert response.status_code == 200

# We're using the cgi module to verify the behavior here, which is a
# bit grungy, but sufficient just for our testing purposes.
assert response.request.headers["Content-Type"] == header
content_length = response.request.headers["Content-Length"]
pdict: dict = {
"boundary": b"+++",
"CONTENT-LENGTH": content_length,
}
multipart = cgi.parse_multipart(io.BytesIO(response.content), pdict)

assert multipart["file"] == [b"<file content>"]


@pytest.mark.parametrize(
"header",
[
"multipart/form-data; charset=utf-8",
"multipart/form-data; charset=utf-8; ",
],
)
def test_multipart_header_without_boundary(header: str) -> None:
client = httpx.Client(transport=httpx.MockTransport(echo_request_content))

files = {"file": io.BytesIO(b"<file content>")}
headers = {"content-type": header}
response = client.post("http://127.0.0.1:8000/", files=files, headers=headers)

assert response.status_code == 200
assert response.request.headers["Content-Type"] == header


@pytest.mark.parametrize(("key"), (b"abc", 1, 2.3, None))
def test_multipart_invalid_key(key):
client = httpx.Client(transport=httpx.MockTransport(echo_request_content))
Expand Down

0 comments on commit 1526048

Please sign in to comment.