Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add more type checking in params/form-data #3173

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 8 additions & 2 deletions .github/workflows/test-suite.yml
Expand Up @@ -7,13 +7,18 @@ on:
pull_request:
branches: ["master"]

defaults:
run:
shell: bash

jobs:
tests:
name: "Python ${{ matrix.python-version }}"
runs-on: "ubuntu-latest"
name: "Python ${{ matrix.python-version }} (${{ matrix.os }})"
runs-on: "${{ matrix.os }}-latest"

strategy:
matrix:
os: ["ubuntu", "windows"]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]

steps:
Expand All @@ -32,3 +37,4 @@ jobs:
run: "scripts/test"
- name: "Enforce coverage"
run: "scripts/coverage"
if: "matrix.os == 'ubuntu'"
5 changes: 3 additions & 2 deletions httpx/_types.py
Expand Up @@ -2,6 +2,7 @@
Type definitions for type checking purposes.
"""

import enum
import ssl
from http.cookiejar import CookieJar
from typing import (
Expand Down Expand Up @@ -31,7 +32,7 @@
from ._urls import URL, QueryParams # noqa: F401


PrimitiveData = Optional[Union[str, int, float, bool]]
PrimitiveData = Optional[Union[str, int, float, bool, enum.Enum]]

RawURL = NamedTuple(
"RawURL",
Expand Down Expand Up @@ -91,7 +92,7 @@
ResponseContent = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]]
ResponseExtensions = MutableMapping[str, Any]

RequestData = Mapping[str, Any]
RequestData = Mapping[str, PrimitiveData | Sequence[PrimitiveData]]

FileContent = Union[IO[bytes], bytes, str]
FileTypes = Union[
Expand Down
10 changes: 9 additions & 1 deletion httpx/_utils.py
Expand Up @@ -2,6 +2,7 @@

import codecs
import email.message
import enum
import ipaddress
import mimetypes
import os
Expand Down Expand Up @@ -65,7 +66,14 @@ def primitive_value_to_str(value: PrimitiveData) -> str:
return "false"
elif value is None:
return ""
return str(value)
elif isinstance(value, (int, float)):
return str(value)
elif isinstance(value, str):
return value
elif isinstance(value, enum.Enum):
# StrEnum and IntEnum is handled above
return primitive_value_to_str(value.value)
raise TypeError(f"unsupported data type {type(value)}")


def is_known_encoding(encoding: str) -> bool:
Expand Down
5 changes: 4 additions & 1 deletion tests/client/test_async_client.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import sys
import typing
from datetime import timedelta

Expand All @@ -18,7 +19,9 @@ async def test_get(server):
assert response.http_version == "HTTP/1.1"
assert response.headers
assert repr(response) == "<Response [200 OK]>"
assert response.elapsed > timedelta(seconds=0)
if sys.platform != "win32":
# flaky on windows
assert response.elapsed > timedelta(seconds=0)


@pytest.mark.parametrize(
Expand Down
25 changes: 21 additions & 4 deletions tests/test_content.py
@@ -1,3 +1,4 @@
import enum
import io
import typing

Expand Down Expand Up @@ -182,7 +183,14 @@ async def test_json_content():

@pytest.mark.anyio
async def test_urlencoded_content():
request = httpx.Request(method, url, data={"Hello": "world!"})
class Flag(enum.Enum):
flag = "f"

request = httpx.Request(
method,
url,
data={"Hello": "world!", "foo": Flag.flag, "like": True, "bar": 123},
)
assert isinstance(request.stream, typing.Iterable)
assert isinstance(request.stream, typing.AsyncIterable)

Expand All @@ -191,11 +199,11 @@ async def test_urlencoded_content():

assert request.headers == {
"Host": "www.example.com",
"Content-Length": "14",
"Content-Length": "38",
"Content-Type": "application/x-www-form-urlencoded",
}
assert sync_content == b"Hello=world%21"
assert async_content == b"Hello=world%21"
assert sync_content == b"Hello=world%21&foo=f&like=true&bar=123"
assert async_content == b"Hello=world%21&foo=f&like=true&bar=123"


@pytest.mark.anyio
Expand Down Expand Up @@ -484,3 +492,12 @@ async def hello_world() -> typing.AsyncIterator[bytes]:
def test_response_invalid_argument():
with pytest.raises(TypeError):
httpx.Response(200, content=123) # type: ignore

with pytest.raises(TypeError):
httpx.Request("GET", "", data={"hello": b""})

class AnyObject:
pass

with pytest.raises(TypeError):
httpx.Request("GET", "", data={"hello": AnyObject()})
6 changes: 5 additions & 1 deletion tests/test_multipart.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import io
import sys
import tempfile
import typing

Expand Down Expand Up @@ -212,7 +213,7 @@ def test_multipart_encode(tmp_path: typing.Any) -> None:

url = "https://www.example.com/"
headers = {"Content-Type": "multipart/form-data; boundary=BOUNDARY"}
data = {
data: dict[str, typing.Any] = {
"a": "1",
"b": b"C",
"c": ["11", "22", "33"],
Expand Down Expand Up @@ -371,6 +372,9 @@ def test_multipart_encode_files_raises_exception_with_StringIO_content() -> None
httpx.Request("POST", url, data={}, files=files) # type: ignore


@pytest.mark.skipif(
sys.platform == "win32", reason="TemporaryFile on windows is binary mode"
)
def test_multipart_encode_files_raises_exception_with_text_mode_file() -> None:
url = "https://www.example.com"
with tempfile.TemporaryFile(mode="w") as upload:
Expand Down
4 changes: 4 additions & 0 deletions tests/test_timeouts.py
@@ -1,3 +1,5 @@
import sys

import pytest

import httpx
Expand All @@ -12,6 +14,7 @@ async def test_read_timeout(server):
await client.get(server.url.copy_with(path="/slow_response"))


@pytest.mark.skipif(sys.platform == "win32", reason="broken on windows")
@pytest.mark.anyio
async def test_write_timeout(server):
timeout = httpx.Timeout(None, write=1e-6)
Expand All @@ -33,6 +36,7 @@ async def test_connect_timeout(server):
await client.get("http://10.255.255.1/")


@pytest.mark.skipif(sys.platform == "win32", reason="broken on windows")
@pytest.mark.anyio
async def test_pool_timeout(server):
limits = httpx.Limits(max_connections=1)
Expand Down
3 changes: 3 additions & 0 deletions tests/test_utils.py
Expand Up @@ -2,6 +2,7 @@
import logging
import os
import random
import sys

import certifi
import pytest
Expand Down Expand Up @@ -122,6 +123,7 @@ def test_logging_redirect_chain(server, caplog):
]


@pytest.mark.skipif(sys.platform == "win32", reason="Path separator problem")
def test_logging_ssl(caplog):
caplog.set_level(logging.DEBUG)
with httpx.Client():
Expand All @@ -142,6 +144,7 @@ def test_logging_ssl(caplog):
]


@pytest.mark.skipif(sys.platform == "win32", reason="Path separator problem")
def test_get_ssl_cert_file():
# Two environments is not set.
assert get_ca_bundle_from_env() is None
Expand Down