diff --git a/httpx/_multipart.py b/httpx/_multipart.py index 0329649758..2c08776f4b 100644 --- a/httpx/_multipart.py +++ b/httpx/_multipart.py @@ -122,8 +122,14 @@ def __init__(self, name: str, value: FileTypes) -> None: # requests does the opposite (it overwrites the header with the 3rd tuple element) headers["Content-Type"] = content_type - if isinstance(fileobj, (str, io.StringIO)): - raise TypeError(f"Expected bytes or bytes-like object got: {type(fileobj)}") + if "b" not in getattr(fileobj, "mode", "b"): + raise TypeError( + "Multipart file uploads must be opened in binary mode, not text mode." + ) + if isinstance(fileobj, io.StringIO): + raise TypeError( + "Multipart file uploads require 'io.BytesIO', not 'io.StringIO'." + ) self.filename = filename self.file = fileobj diff --git a/httpx/_types.py b/httpx/_types.py index e015844bbf..8099f7b4dd 100644 --- a/httpx/_types.py +++ b/httpx/_types.py @@ -80,7 +80,7 @@ RequestData = Mapping[str, Any] -FileContent = Union[IO[bytes], bytes] +FileContent = Union[IO[bytes], bytes, str] FileTypes = Union[ # file (or bytes) FileContent, diff --git a/tests/test_multipart.py b/tests/test_multipart.py index dc93d26505..a4e9796bd7 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -339,18 +339,37 @@ def test_multipart_encode_files_allows_bytes_content() -> None: assert content == b"".join(stream) -def test_multipart_encode_files_raises_exception_with_str_content() -> None: - files = {"file": ("test.txt", "", "text/plain")} +def test_multipart_encode_files_allows_str_content() -> None: + files = {"file": ("test.txt", "", "text/plain")} with mock.patch("os.urandom", return_value=os.urandom(16)): + boundary = os.urandom(16).hex() - with pytest.raises(TypeError): - encode_request(data={}, files=files) # type: ignore + headers, stream = encode_request(data={}, files=files) + assert isinstance(stream, typing.Iterable) + + content = ( + '--{0}\r\nContent-Disposition: form-data; name="file"; ' + 'filename="test.txt"\r\n' + "Content-Type: text/plain\r\n\r\n\r\n" + "--{0}--\r\n" + "".format(boundary).encode("ascii") + ) + assert headers == { + "Content-Type": f"multipart/form-data; boundary={boundary}", + "Content-Length": str(len(content)), + } + assert content == b"".join(stream) def test_multipart_encode_files_raises_exception_with_StringIO_content() -> None: files = {"file": ("test.txt", io.StringIO("content"), "text/plain")} - with mock.patch("os.urandom", return_value=os.urandom(16)): + with pytest.raises(TypeError): + encode_request(data={}, files=files) # type: ignore + +def test_multipart_encode_files_raises_exception_with_text_mode_file() -> None: + with tempfile.TemporaryFile(mode="w") as upload: + files = {"file": ("test.txt", upload, "text/plain")} with pytest.raises(TypeError): encode_request(data={}, files=files) # type: ignore