Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

allow setting an explicit multipart boundary via headers #2278

Merged
merged 21 commits into from Aug 15, 2022
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 14 additions & 5 deletions httpx/_content.py
Expand Up @@ -15,7 +15,7 @@
from urllib.parse import urlencode

from ._exceptions import StreamClosed, StreamConsumed
from ._multipart import MultipartStream
from ._multipart import MultipartStream, get_multipart_boundary_from_content_type
from ._types import (
AsyncByteStream,
RequestContent,
Expand Down Expand Up @@ -150,11 +150,19 @@ def encode_urlencoded_data(


def encode_multipart_data(
    data: dict,
    files: RequestFiles,
    boundary: Optional[bytes] = None,
    content_type: Optional[str] = None,
) -> Tuple[Dict[str, str], MultipartStream]:
    """
    Build a multipart stream for the given form fields and files.

    `boundary` is handed to `MultipartStream` as-is unless `content_type`
    carries its own `boundary=` parameter, in which case the value parsed
    from the header is used instead.

    Both trailing parameters default to `None`, so the function can be
    called with or without an explicit boundary / content type.

    Returns a `(headers, stream)` pair, where `headers` is whatever the
    stream's `get_headers()` reports.
    """
    # note: we are the only ones calling into this function
    # (not users) so there should never be a situation where
    # both content_type and boundary are set
    if content_type:
        header_boundary = get_multipart_boundary_from_content_type(content_type)
        # A header without a boundary parameter must not wipe out a boundary
        # that was passed in explicitly.
        if header_boundary is not None:
            boundary = header_boundary
    multipart = MultipartStream(data=data, files=files, boundary=boundary)
    new_headers = multipart.get_headers()
    return new_headers, multipart


def encode_text(text: str) -> Tuple[Dict[str, str], ByteStream]:
Expand Down Expand Up @@ -187,6 +195,7 @@ def encode_request(
files: Optional[RequestFiles] = None,
json: Optional[Any] = None,
boundary: Optional[bytes] = None,
content_type: Optional[str] = None,
) -> Tuple[Dict[str, str], Union[SyncByteStream, AsyncByteStream]]:
"""
Handles encoding the given `content`, `data`, `files`, and `json`,
Expand All @@ -207,7 +216,7 @@ def encode_request(
if content is not None:
return encode_content(content)
elif files:
return encode_multipart_data(data or {}, files, boundary)
return encode_multipart_data(data or {}, files, boundary, content_type)
elif data:
return encode_urlencoded_data(data)
elif json is not None:
Expand Down
8 changes: 7 additions & 1 deletion httpx/_models.py
Expand Up @@ -332,7 +332,13 @@ def __init__(
Cookies(cookies).set_cookie_header(self)

if stream is None:
headers, stream = encode_request(content, data, files, json)
headers, stream = encode_request(
content=content,
data=data,
files=files,
json=json,
content_type=self.headers.get("content-type"),
)
jhominal marked this conversation as resolved.
Show resolved Hide resolved
self._prepare(headers)
self.stream = stream
# Load the request body, except for streaming content.
Expand Down
10 changes: 10 additions & 0 deletions httpx/_multipart.py
Expand Up @@ -20,6 +20,16 @@
)


def get_multipart_boundary_from_content_type(
    content_type: str,
) -> typing.Optional[bytes]:
    """
    Extract the `boundary` parameter from a Content-Type header value.

    The header is split on ";" and the first section whose name is
    `boundary` (matched case-insensitively) supplies the value. A value
    wrapped in double quotes has the surrounding quotes removed.

    Returns the boundary encoded as latin-1 bytes, or `None` when the
    header has no ";"-separated parameters or no boundary parameter.
    """
    if ";" not in content_type:
        return None

    prefix = "boundary="
    for section in content_type.split(";"):
        param = section.strip()
        # Parameter names are case-insensitive; the value is kept verbatim.
        if param[: len(prefix)].lower() != prefix:
            continue
        value = param[len(prefix) :]
        # The boundary may be sent as a quoted string: boundary="abc".
        if len(value) >= 2 and value[0] == '"' and value[-1] == '"':
            value = value[1:-1]
        return value.encode("latin-1")
    return None


class DataField:
"""
A single form field item, within a multipart form field.
Expand Down
48 changes: 48 additions & 0 deletions tests/test_multipart.py
Expand Up @@ -42,6 +42,54 @@ def test_multipart(value, output):
assert multipart["file"] == [b"<file content>"]


@pytest.mark.parametrize(
    "header",
    [
        "multipart/form-data; boundary=+++; charset=utf-8",
        "multipart/form-data; charset=utf-8; boundary=+++",
        "multipart/form-data; boundary=+++",
        "multipart/form-data; boundary=+++ ;",
    ],
)
def test_multipart_explicit_boundary(header: str) -> None:
    """
    A boundary given in a user-supplied Content-Type header is honoured.

    The header must be sent unchanged, and the response content must parse
    as multipart data delimited by the `+++` boundary from that header.
    """
    mock_client = httpx.Client(transport=httpx.MockTransport(echo_request_content))

    response = mock_client.post(
        "http://127.0.0.1:8000/",
        files={"file": io.BytesIO(b"<file content>")},
        headers={"content-type": header},
    )
    assert response.status_code == 200

    sent_headers = response.request.headers
    assert sent_headers["Content-Type"] == header

    # The cgi module is a slightly grungy way to check the body, but it is
    # good enough for testing purposes.
    parse_options: dict = {
        "boundary": b"+++",
        "CONTENT-LENGTH": sent_headers["Content-Length"],
    }
    parsed = cgi.parse_multipart(io.BytesIO(response.content), parse_options)

    assert parsed["file"] == [b"<file content>"]


@pytest.mark.parametrize(
    "header",
    [
        "multipart/form-data; charset=utf-8",
        "multipart/form-data; charset=utf-8; ",
    ],
)
def test_multipart_header_without_boundary(header: str) -> None:
    """
    A multipart Content-Type header lacking a boundary is sent unchanged.

    The request must still succeed when the user-supplied header carries no
    `boundary=` parameter.
    """
    mock_client = httpx.Client(transport=httpx.MockTransport(echo_request_content))

    response = mock_client.post(
        "http://127.0.0.1:8000/",
        files={"file": io.BytesIO(b"<file content>")},
        headers={"content-type": header},
    )

    assert response.status_code == 200
    assert response.request.headers["Content-Type"] == header


@pytest.mark.parametrize(("key"), (b"abc", 1, 2.3, None))
def test_multipart_invalid_key(key):
client = httpx.Client(transport=httpx.MockTransport(echo_request_content))
Expand Down