diff --git a/README.md b/README.md
index 44bd55c77..50e0dd63b 100644
--- a/README.md
+++ b/README.md
@@ -30,7 +30,7 @@ It is production-ready, and gives you the following:
* WebSocket support.
* In-process background tasks.
* Startup and shutdown events.
-* Test client built on `requests`.
+* Test client built on `httpx`.
* CORS, GZip, Static Files, Streaming responses.
* Session and Cookie support.
* 100% test coverage.
@@ -87,7 +87,7 @@ For a more complete example, see [encode/starlette-example](https://github.com/e
Starlette only requires `anyio`, and the following are optional:
-* [`requests`][requests] - Required if you want to use the `TestClient`.
+* [`httpx`][httpx] - Required if you want to use the `TestClient`.
* [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`.
* [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`.
* [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support.
@@ -134,7 +134,7 @@ in isolation.
Starlette is BSD licensed code.
Designed & crafted with care.— ⭐️ —
[asgi]: https://asgi.readthedocs.io/en/latest/
-[requests]: http://docs.python-requests.org/en/master/
+[httpx]: https://www.python-httpx.org/
[jinja2]: http://jinja.pocoo.org/
[python-multipart]: https://andrew-d.github.io/python-multipart/
[itsdangerous]: https://pythonhosted.org/itsdangerous/
diff --git a/docs/index.md b/docs/index.md
index 5a5014021..9f1977875 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -27,7 +27,7 @@ It is production-ready, and gives you the following:
* WebSocket support.
* In-process background tasks.
* Startup and shutdown events.
-* Test client built on `requests`.
+* Test client built on `httpx`.
* CORS, GZip, Static Files, Streaming responses.
* Session and Cookie support.
* 100% test coverage.
@@ -83,7 +83,7 @@ For a more complete example, [see here](https://github.com/encode/starlette-exam
Starlette only requires `anyio`, and the following dependencies are optional:
-* [`requests`][requests] - Required if you want to use the `TestClient`.
+* [`httpx`][httpx] - Required if you want to use the `TestClient`.
* [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`.
* [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`.
* [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support.
@@ -130,7 +130,7 @@ in isolation.
Starlette is BSD licensed code.
Designed & crafted with care.— ⭐️ —
[asgi]: https://asgi.readthedocs.io/en/latest/
-[requests]: http://docs.python-requests.org/en/master/
+[httpx]: https://www.python-httpx.org/
[jinja2]: http://jinja.pocoo.org/
[python-multipart]: https://andrew-d.github.io/python-multipart/
[itsdangerous]: https://pythonhosted.org/itsdangerous/
diff --git a/docs/testclient.md b/docs/testclient.md
index f64c570bb..053b42005 100644
--- a/docs/testclient.md
+++ b/docs/testclient.md
@@ -1,6 +1,6 @@
The test client allows you to make requests against your ASGI application,
-using the `requests` library.
+using the `httpx` library.
```python
from starlette.responses import HTMLResponse
@@ -19,11 +19,11 @@ def test_app():
assert response.status_code == 200
```
-The test client exposes the same interface as any other `requests` session.
+The test client exposes the same interface as any other `httpx` session.
In particular, note that the calls to make a request are just standard
function calls, not awaitables.
-You can use any of `requests` standard API, such as authentication, session
+You can use any of `httpx` standard API, such as authentication, session
cookies handling, or file uploads.
For example, to set headers on the TestClient you can do:
@@ -96,7 +96,7 @@ def test_app()
You can also test websocket sessions with the test client.
-The `requests` library will be used to build the initial handshake, meaning you
+The `httpx` library will be used to build the initial handshake, meaning you
can use the same authentication options and other headers between both http and
websocket testing.
@@ -129,7 +129,7 @@ always raised by the test client.
#### Establishing a test session
-* `.websocket_connect(url, subprotocols=None, **options)` - Takes the same set of arguments as `requests.get()`.
+* `.websocket_connect(url, subprotocols=None, **options)` - Takes the same set of arguments as `httpx.get()`.
May raise `starlette.websockets.WebSocketDisconnect` if the application does not accept the websocket connection.
diff --git a/pyproject.toml b/pyproject.toml
index 7bbce89d9..f994fc361 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -37,7 +37,7 @@ full = [
"jinja2",
"python-multipart",
"pyyaml",
- "requests",
+ "httpx>=0.22.0",
]
[project.urls]
diff --git a/requirements.txt b/requirements.txt
index 648f0fa01..0b54fa596 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,7 +10,6 @@ flake8==3.9.2
isort==5.10.1
mypy==0.971
typing_extensions==4.3.0
-types-requests==2.26.3
types-contextvars==2.4.7
types-PyYAML==6.0.11
types-dataclasses==0.6.6
diff --git a/setup.cfg b/setup.cfg
index 93f27e4e0..23cf32cc0 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -32,6 +32,9 @@ filterwarnings=
ignore: starlette\.middleware\.wsgi is deprecated and will be removed in a future release\.*:DeprecationWarning
ignore: Async generator 'starlette\.requests\.Request\.stream' was garbage collected before it had been exhausted.*:ResourceWarning
ignore: path is deprecated.*:DeprecationWarning:certifi
+ ignore: Use 'content=<...>' to upload raw bytes/text content.:DeprecationWarning
+ ignore: The `allow_redirects` argument is deprecated. Use `follow_redirects` instead.:DeprecationWarning
+ ignore: 'cgi' is deprecated and slated for removal in Python 3\.13:DeprecationWarning
[coverage:run]
source_pkgs = starlette, tests
diff --git a/starlette/testclient.py b/starlette/testclient.py
index efe2b493b..455440ce5 100644
--- a/starlette/testclient.py
+++ b/starlette/testclient.py
@@ -1,22 +1,22 @@
import contextlib
-import http
import inspect
import io
import json
import math
import queue
import sys
-import types
import typing
+import warnings
from concurrent.futures import Future
-from urllib.parse import unquote, urljoin, urlsplit
+from types import GeneratorType
+from urllib.parse import unquote, urljoin
-import anyio.abc
-import requests
+import anyio
+import httpx
from anyio.streams.stapled import StapledObjectStream
from starlette._utils import is_async_callable
-from starlette.types import Message, Receive, Scope, Send
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocketDisconnect
if sys.version_info >= (3, 8): # pragma: no cover
@@ -24,63 +24,15 @@
else: # pragma: no cover
from typing_extensions import TypedDict
-
_PortalFactoryType = typing.Callable[
[], typing.ContextManager[anyio.abc.BlockingPortal]
]
-
-# Annotations for `Session.request()`
-Cookies = typing.Union[
- typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar
-]
-Params = typing.Union[bytes, typing.MutableMapping[str, str]]
-DataType = typing.Union[bytes, typing.MutableMapping[str, str], typing.IO]
-TimeOut = typing.Union[float, typing.Tuple[float, float]]
-FileType = typing.MutableMapping[str, typing.IO]
-AuthType = typing.Union[
- typing.Tuple[str, str],
- requests.auth.AuthBase,
- typing.Callable[[requests.PreparedRequest], requests.PreparedRequest],
-]
-
-
ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
ASGI2App = typing.Callable[[Scope], ASGIInstance]
ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
-class _HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
- def get_all(self, key: str, default: str) -> str:
- return self.getheaders(key)
-
-
-class _MockOriginalResponse:
- """
- We have to jump through some hoops to present the response as if
- it was made using urllib3.
- """
-
- def __init__(self, headers: typing.List[typing.Tuple[bytes, bytes]]) -> None:
- self.msg = _HeaderDict(headers)
- self.closed = False
-
- def isclosed(self) -> bool:
- return self.closed
-
-
-class _Upgrade(Exception):
- def __init__(self, session: "WebSocketTestSession") -> None:
- self.session = session
-
-
-def _get_reason_phrase(status_code: int) -> str:
- try:
- return http.HTTPStatus(status_code).phrase
- except ValueError:
- return ""
-
-
def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool:
if inspect.isclass(app):
return hasattr(app, "__await__")
@@ -105,7 +57,127 @@ class _AsyncBackend(TypedDict):
backend_options: typing.Dict[str, typing.Any]
-class _ASGIAdapter(requests.adapters.HTTPAdapter):
+class _Upgrade(Exception):
+ def __init__(self, session: "WebSocketTestSession") -> None:
+ self.session = session
+
+
+class WebSocketTestSession:
+ def __init__(
+ self,
+ app: ASGI3App,
+ scope: Scope,
+ portal_factory: _PortalFactoryType,
+ ) -> None:
+ self.app = app
+ self.scope = scope
+ self.accepted_subprotocol = None
+ self.portal_factory = portal_factory
+ self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
+ self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()
+ self.extra_headers = None
+
+ def __enter__(self) -> "WebSocketTestSession":
+ self.exit_stack = contextlib.ExitStack()
+ self.portal = self.exit_stack.enter_context(self.portal_factory())
+
+ try:
+ _: "Future[None]" = self.portal.start_task_soon(self._run)
+ self.send({"type": "websocket.connect"})
+ message = self.receive()
+ self._raise_on_close(message)
+ except Exception:
+ self.exit_stack.close()
+ raise
+ self.accepted_subprotocol = message.get("subprotocol", None)
+ self.extra_headers = message.get("headers", None)
+ return self
+
+ def __exit__(self, *args: typing.Any) -> None:
+ try:
+ self.close(1000)
+ finally:
+ self.exit_stack.close()
+ while not self._send_queue.empty():
+ message = self._send_queue.get()
+ if isinstance(message, BaseException):
+ raise message
+
+ async def _run(self) -> None:
+ """
+ The sub-thread in which the websocket session runs.
+ """
+ scope = self.scope
+ receive = self._asgi_receive
+ send = self._asgi_send
+ try:
+ await self.app(scope, receive, send)
+ except BaseException as exc:
+ self._send_queue.put(exc)
+ raise
+
+ async def _asgi_receive(self) -> Message:
+ while self._receive_queue.empty():
+ await anyio.sleep(0)
+ return self._receive_queue.get()
+
+ async def _asgi_send(self, message: Message) -> None:
+ self._send_queue.put(message)
+
+ def _raise_on_close(self, message: Message) -> None:
+ if message["type"] == "websocket.close":
+ raise WebSocketDisconnect(
+ message.get("code", 1000), message.get("reason", "")
+ )
+
+ def send(self, message: Message) -> None:
+ self._receive_queue.put(message)
+
+ def send_text(self, data: str) -> None:
+ self.send({"type": "websocket.receive", "text": data})
+
+ def send_bytes(self, data: bytes) -> None:
+ self.send({"type": "websocket.receive", "bytes": data})
+
+ def send_json(self, data: typing.Any, mode: str = "text") -> None:
+ assert mode in ["text", "binary"]
+ text = json.dumps(data)
+ if mode == "text":
+ self.send({"type": "websocket.receive", "text": text})
+ else:
+ self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})
+
+ def close(self, code: int = 1000) -> None:
+ self.send({"type": "websocket.disconnect", "code": code})
+
+ def receive(self) -> Message:
+ message = self._send_queue.get()
+ if isinstance(message, BaseException):
+ raise message
+ return message
+
+ def receive_text(self) -> str:
+ message = self.receive()
+ self._raise_on_close(message)
+ return message["text"]
+
+ def receive_bytes(self) -> bytes:
+ message = self.receive()
+ self._raise_on_close(message)
+ return message["bytes"]
+
+ def receive_json(self, mode: str = "text") -> typing.Any:
+ assert mode in ["text", "binary"]
+ message = self.receive()
+ self._raise_on_close(message)
+ if mode == "text":
+ text = message["text"]
+ else:
+ text = message["bytes"].decode("utf-8")
+ return json.loads(text)
+
+
+class _TestClientTransport(httpx.BaseTransport):
def __init__(
self,
app: ASGI3App,
@@ -118,12 +190,12 @@ def __init__(
self.root_path = root_path
self.portal_factory = portal_factory
- def send(
- self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any
- ) -> requests.Response:
- scheme, netloc, path, query, fragment = (
- str(item) for item in urlsplit(request.url)
- )
+ def handle_request(self, request: httpx.Request) -> httpx.Response:
+ scheme = request.url.scheme
+ netloc = unquote(request.url.netloc.decode(encoding="ascii"))
+ path = request.url.path
+ raw_path = request.url.raw_path
+ query = unquote(request.url.query.decode(encoding="ascii"))
default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
@@ -137,9 +209,9 @@ def send(
# Include the 'host' header.
if "host" in request.headers:
headers: typing.List[typing.Tuple[bytes, bytes]] = []
- elif port == default_port:
+ elif port == default_port: # pragma: no cover
headers = [(b"host", host.encode())]
- else:
+ else: # pragma: no cover
headers = [(b"host", (f"{host}:{port}").encode())]
# Include other request headers.
@@ -159,7 +231,7 @@ def send(
scope = {
"type": "websocket",
"path": unquote(path),
- "raw_path": path.encode(),
+ "raw_path": raw_path,
"root_path": self.root_path,
"scheme": scheme,
"query_string": query.encode(),
@@ -176,7 +248,7 @@ def send(
"http_version": "1.1",
"method": request.method,
"path": unquote(path),
- "raw_path": path.encode(),
+ "raw_path": raw_path,
"root_path": self.root_path,
"scheme": scheme,
"query_string": query.encode(),
@@ -189,7 +261,7 @@ def send(
request_complete = False
response_started = False
response_complete: anyio.Event
- raw_kwargs: typing.Dict[str, typing.Any] = {"body": io.BytesIO()}
+ raw_kwargs: typing.Dict[str, typing.Any] = {"stream": io.BytesIO()}
template = None
context = None
@@ -201,18 +273,18 @@ async def receive() -> Message:
await response_complete.wait()
return {"type": "http.disconnect"}
- body = request.body
+ body = request.read()
if isinstance(body, str):
- body_bytes: bytes = body.encode("utf-8")
+ body_bytes: bytes = body.encode("utf-8") # pragma: no cover
elif body is None:
- body_bytes = b""
- elif isinstance(body, types.GeneratorType):
- try:
+ body_bytes = b"" # pragma: no cover
+ elif isinstance(body, GeneratorType):
+ try: # pragma: no cover
chunk = body.send(None)
if isinstance(chunk, str):
chunk = chunk.encode("utf-8")
return {"type": "http.request", "body": chunk, "more_body": True}
- except StopIteration:
+ except StopIteration: # pragma: no cover
request_complete = True
return {"type": "http.request", "body": b""}
else:
@@ -228,17 +300,11 @@ async def send(message: Message) -> None:
assert (
not response_started
), 'Received multiple "http.response.start" messages.'
- raw_kwargs["version"] = 11
- raw_kwargs["status"] = message["status"]
- raw_kwargs["reason"] = _get_reason_phrase(message["status"])
+ raw_kwargs["status_code"] = message["status"]
raw_kwargs["headers"] = [
(key.decode(), value.decode())
for key, value in message.get("headers", [])
]
- raw_kwargs["preload_content"] = False
- raw_kwargs["original_response"] = _MockOriginalResponse(
- raw_kwargs["headers"]
- )
response_started = True
elif message["type"] == "http.response.body":
assert (
@@ -250,9 +316,9 @@ async def send(message: Message) -> None:
body = message.get("body", b"")
more_body = message.get("more_body", False)
if request.method != "HEAD":
- raw_kwargs["body"].write(body)
+ raw_kwargs["stream"].write(body)
if not more_body:
- raw_kwargs["body"].seek(0)
+ raw_kwargs["stream"].seek(0)
response_complete.set()
elif message["type"] == "http.response.template":
template = message["template"]
@@ -270,153 +336,35 @@ async def send(message: Message) -> None:
assert response_started, "TestClient did not receive any response."
elif not response_started:
raw_kwargs = {
- "version": 11,
- "status": 500,
- "reason": "Internal Server Error",
+ "status_code": 500,
"headers": [],
- "preload_content": False,
- "original_response": _MockOriginalResponse([]),
- "body": io.BytesIO(),
+ "stream": io.BytesIO(),
}
- raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs)
- response = self.build_response(request, raw)
+ raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read())
+
+ response = httpx.Response(**raw_kwargs, request=request)
if template is not None:
- response.template = template
- response.context = context
+ response.template = template # type: ignore[attr-defined]
+ response.context = context # type: ignore[attr-defined]
return response
-class WebSocketTestSession:
- def __init__(
- self,
- app: ASGI3App,
- scope: Scope,
- portal_factory: _PortalFactoryType,
- ) -> None:
- self.app = app
- self.scope = scope
- self.accepted_subprotocol = None
- self.extra_headers = None
- self.portal_factory = portal_factory
- self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
- self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()
-
- def __enter__(self) -> "WebSocketTestSession":
- self.exit_stack = contextlib.ExitStack()
- self.portal = self.exit_stack.enter_context(self.portal_factory())
-
- try:
- _: "Future[None]" = self.portal.start_task_soon(self._run)
- self.send({"type": "websocket.connect"})
- message = self.receive()
- self._raise_on_close(message)
- except Exception:
- self.exit_stack.close()
- raise
- self.accepted_subprotocol = message.get("subprotocol", None)
- self.extra_headers = message.get("headers", None)
- return self
-
- def __exit__(self, *args: typing.Any) -> None:
- try:
- self.close(1000)
- finally:
- self.exit_stack.close()
- while not self._send_queue.empty():
- message = self._send_queue.get()
- if isinstance(message, BaseException):
- raise message
-
- async def _run(self) -> None:
- """
- The sub-thread in which the websocket session runs.
- """
- scope = self.scope
- receive = self._asgi_receive
- send = self._asgi_send
- try:
- await self.app(scope, receive, send)
- except BaseException as exc:
- self._send_queue.put(exc)
- raise
-
- async def _asgi_receive(self) -> Message:
- while self._receive_queue.empty():
- await anyio.sleep(0)
- return self._receive_queue.get()
-
- async def _asgi_send(self, message: Message) -> None:
- self._send_queue.put(message)
-
- def _raise_on_close(self, message: Message) -> None:
- if message["type"] == "websocket.close":
- raise WebSocketDisconnect(
- message.get("code", 1000), message.get("reason", "")
- )
-
- def send(self, message: Message) -> None:
- self._receive_queue.put(message)
-
- def send_text(self, data: str) -> None:
- self.send({"type": "websocket.receive", "text": data})
-
- def send_bytes(self, data: bytes) -> None:
- self.send({"type": "websocket.receive", "bytes": data})
-
- def send_json(self, data: typing.Any, mode: str = "text") -> None:
- assert mode in ["text", "binary"]
- text = json.dumps(data)
- if mode == "text":
- self.send({"type": "websocket.receive", "text": text})
- else:
- self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})
-
- def close(self, code: int = 1000) -> None:
- self.send({"type": "websocket.disconnect", "code": code})
-
- def receive(self) -> Message:
- message = self._send_queue.get()
- if isinstance(message, BaseException):
- raise message
- return message
-
- def receive_text(self) -> str:
- message = self.receive()
- self._raise_on_close(message)
- return message["text"]
-
- def receive_bytes(self) -> bytes:
- message = self.receive()
- self._raise_on_close(message)
- return message["bytes"]
-
- def receive_json(self, mode: str = "text") -> typing.Any:
- assert mode in ["text", "binary"]
- message = self.receive()
- self._raise_on_close(message)
- if mode == "text":
- text = message["text"]
- else:
- text = message["bytes"].decode("utf-8")
- return json.loads(text)
-
-
-class TestClient(requests.Session):
- __test__ = False # For pytest to not discover this up.
+class TestClient(httpx.Client):
+ __test__ = False
task: "Future[None]"
portal: typing.Optional[anyio.abc.BlockingPortal] = None
def __init__(
self,
- app: typing.Union[ASGI2App, ASGI3App],
+ app: ASGIApp,
base_url: str = "http://testserver",
raise_server_exceptions: bool = True,
root_path: str = "",
backend: str = "asyncio",
backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None,
+ cookies: httpx._client.CookieTypes = None,
) -> None:
- super().__init__()
self.async_backend = _AsyncBackend(
backend=backend, backend_options=backend_options or {}
)
@@ -424,69 +372,320 @@ def __init__(
app = typing.cast(ASGI3App, app)
asgi_app = app
else:
- app = typing.cast(ASGI2App, app)
- asgi_app = _WrapASGI2(app) # type: ignore
- adapter = _ASGIAdapter(
- asgi_app,
+ app = typing.cast(ASGI2App, app) # type: ignore[assignment]
+ asgi_app = _WrapASGI2(app) # type: ignore[arg-type]
+ self.app = asgi_app
+ transport = _TestClientTransport(
+ self.app,
portal_factory=self._portal_factory,
raise_server_exceptions=raise_server_exceptions,
root_path=root_path,
)
- self.mount("http://", adapter)
- self.mount("https://", adapter)
- self.mount("ws://", adapter)
- self.mount("wss://", adapter)
- self.headers.update({"user-agent": "testclient"})
- self.app = asgi_app
- self.base_url = base_url
+ super().__init__(
+ app=self.app,
+ base_url=base_url,
+ headers={"user-agent": "testclient"},
+ transport=transport,
+ follow_redirects=True,
+ cookies=cookies,
+ )
@contextlib.contextmanager
- def _portal_factory(
- self,
- ) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
+ def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
if self.portal is not None:
yield self.portal
else:
with anyio.start_blocking_portal(**self.async_backend) as portal:
yield portal
- def request( # type: ignore
+ def _choose_redirect_arg(
+ self,
+ follow_redirects: typing.Optional[bool],
+ allow_redirects: typing.Optional[bool],
+ ) -> typing.Union[bool, httpx._client.UseClientDefault]:
+ redirect: typing.Union[
+ bool, httpx._client.UseClientDefault
+ ] = httpx._client.USE_CLIENT_DEFAULT
+ if allow_redirects is not None:
+ message = (
+ "The `allow_redirects` argument is deprecated. "
+ "Use `follow_redirects` instead."
+ )
+ warnings.warn(message, DeprecationWarning)
+ redirect = allow_redirects
+ if follow_redirects is not None:
+ redirect = follow_redirects
+ elif allow_redirects is not None and follow_redirects is not None:
+ raise RuntimeError( # pragma: no cover
+ "Cannot use both `allow_redirects` and `follow_redirects`."
+ )
+ return redirect
+
+ def request( # type: ignore[override]
self,
method: str,
- url: str,
- params: Params = None,
- data: DataType = None,
- headers: typing.MutableMapping[str, str] = None,
- cookies: Cookies = None,
- files: FileType = None,
- auth: AuthType = None,
- timeout: TimeOut = None,
- allow_redirects: bool = None,
- proxies: typing.MutableMapping[str, str] = None,
- hooks: typing.Any = None,
- stream: bool = None,
- verify: typing.Union[bool, str] = None,
- cert: typing.Union[str, typing.Tuple[str, str]] = None,
+ url: httpx._types.URLTypes,
+ *,
+ content: httpx._types.RequestContent = None,
+ data: httpx._types.RequestData = None,
+ files: httpx._types.RequestFiles = None,
json: typing.Any = None,
- ) -> requests.Response:
- url = urljoin(self.base_url, url)
+ params: httpx._types.QueryParamTypes = None,
+ headers: httpx._types.HeaderTypes = None,
+ cookies: httpx._types.CookieTypes = None,
+ auth: typing.Union[
+ httpx._types.AuthTypes, httpx._client.UseClientDefault
+ ] = httpx._client.USE_CLIENT_DEFAULT,
+ follow_redirects: bool = None,
+ allow_redirects: bool = None,
+ timeout: typing.Union[
+ httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+ ] = httpx._client.USE_CLIENT_DEFAULT,
+ extensions: dict = None,
+ ) -> httpx.Response:
+ url = self.base_url.join(url)
+ redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
return super().request(
method,
+ url,
+ content=content,
+ data=data,
+ files=files,
+ json=json,
+ params=params,
+ headers=headers,
+ cookies=cookies,
+ auth=auth,
+ follow_redirects=redirect,
+ timeout=timeout,
+ extensions=extensions,
+ )
+
+ def get( # type: ignore[override]
+ self,
+ url: httpx._types.URLTypes,
+ *,
+ params: httpx._types.QueryParamTypes = None,
+ headers: httpx._types.HeaderTypes = None,
+ cookies: httpx._types.CookieTypes = None,
+ auth: typing.Union[
+ httpx._types.AuthTypes, httpx._client.UseClientDefault
+ ] = httpx._client.USE_CLIENT_DEFAULT,
+ follow_redirects: bool = None,
+ allow_redirects: bool = None,
+ timeout: typing.Union[
+ httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+ ] = httpx._client.USE_CLIENT_DEFAULT,
+ extensions: dict = None,
+ ) -> httpx.Response:
+ redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
+ return super().get(
+ url,
+ params=params,
+ headers=headers,
+ cookies=cookies,
+ auth=auth,
+ follow_redirects=redirect,
+ timeout=timeout,
+ extensions=extensions,
+ )
+
+ def options( # type: ignore[override]
+ self,
+ url: httpx._types.URLTypes,
+ *,
+ params: httpx._types.QueryParamTypes = None,
+ headers: httpx._types.HeaderTypes = None,
+ cookies: httpx._types.CookieTypes = None,
+ auth: typing.Union[
+ httpx._types.AuthTypes, httpx._client.UseClientDefault
+ ] = httpx._client.USE_CLIENT_DEFAULT,
+ follow_redirects: bool = None,
+ allow_redirects: bool = None,
+ timeout: typing.Union[
+ httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+ ] = httpx._client.USE_CLIENT_DEFAULT,
+ extensions: dict = None,
+ ) -> httpx.Response:
+ redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
+ return super().options(
+ url,
+ params=params,
+ headers=headers,
+ cookies=cookies,
+ auth=auth,
+ follow_redirects=redirect,
+ timeout=timeout,
+ extensions=extensions,
+ )
+
+ def head( # type: ignore[override]
+ self,
+ url: httpx._types.URLTypes,
+ *,
+ params: httpx._types.QueryParamTypes = None,
+ headers: httpx._types.HeaderTypes = None,
+ cookies: httpx._types.CookieTypes = None,
+ auth: typing.Union[
+ httpx._types.AuthTypes, httpx._client.UseClientDefault
+ ] = httpx._client.USE_CLIENT_DEFAULT,
+ follow_redirects: bool = None,
+ allow_redirects: bool = None,
+ timeout: typing.Union[
+ httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+ ] = httpx._client.USE_CLIENT_DEFAULT,
+ extensions: dict = None,
+ ) -> httpx.Response:
+ redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
+ return super().head(
url,
params=params,
+ headers=headers,
+ cookies=cookies,
+ auth=auth,
+ follow_redirects=redirect,
+ timeout=timeout,
+ extensions=extensions,
+ )
+
+ def post( # type: ignore[override]
+ self,
+ url: httpx._types.URLTypes,
+ *,
+ content: httpx._types.RequestContent = None,
+ data: httpx._types.RequestData = None,
+ files: httpx._types.RequestFiles = None,
+ json: typing.Any = None,
+ params: httpx._types.QueryParamTypes = None,
+ headers: httpx._types.HeaderTypes = None,
+ cookies: httpx._types.CookieTypes = None,
+ auth: typing.Union[
+ httpx._types.AuthTypes, httpx._client.UseClientDefault
+ ] = httpx._client.USE_CLIENT_DEFAULT,
+ follow_redirects: bool = None,
+ allow_redirects: bool = None,
+ timeout: typing.Union[
+ httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+ ] = httpx._client.USE_CLIENT_DEFAULT,
+ extensions: dict = None,
+ ) -> httpx.Response:
+ redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
+ return super().post(
+ url,
+ content=content,
data=data,
+ files=files,
+ json=json,
+ params=params,
headers=headers,
cookies=cookies,
+ auth=auth,
+ follow_redirects=redirect,
+ timeout=timeout,
+ extensions=extensions,
+ )
+
+ def put( # type: ignore[override]
+ self,
+ url: httpx._types.URLTypes,
+ *,
+ content: httpx._types.RequestContent = None,
+ data: httpx._types.RequestData = None,
+ files: httpx._types.RequestFiles = None,
+ json: typing.Any = None,
+ params: httpx._types.QueryParamTypes = None,
+ headers: httpx._types.HeaderTypes = None,
+ cookies: httpx._types.CookieTypes = None,
+ auth: typing.Union[
+ httpx._types.AuthTypes, httpx._client.UseClientDefault
+ ] = httpx._client.USE_CLIENT_DEFAULT,
+ follow_redirects: bool = None,
+ allow_redirects: bool = None,
+ timeout: typing.Union[
+ httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+ ] = httpx._client.USE_CLIENT_DEFAULT,
+ extensions: dict = None,
+ ) -> httpx.Response:
+ redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
+ return super().put(
+ url,
+ content=content,
+ data=data,
files=files,
+ json=json,
+ params=params,
+ headers=headers,
+ cookies=cookies,
auth=auth,
+ follow_redirects=redirect,
timeout=timeout,
- allow_redirects=allow_redirects,
- proxies=proxies,
- hooks=hooks,
- stream=stream,
- verify=verify,
- cert=cert,
+ extensions=extensions,
+ )
+
+ def patch( # type: ignore[override]
+ self,
+ url: httpx._types.URLTypes,
+ *,
+ content: httpx._types.RequestContent = None,
+ data: httpx._types.RequestData = None,
+ files: httpx._types.RequestFiles = None,
+ json: typing.Any = None,
+ params: httpx._types.QueryParamTypes = None,
+ headers: httpx._types.HeaderTypes = None,
+ cookies: httpx._types.CookieTypes = None,
+ auth: typing.Union[
+ httpx._types.AuthTypes, httpx._client.UseClientDefault
+ ] = httpx._client.USE_CLIENT_DEFAULT,
+ follow_redirects: bool = None,
+ allow_redirects: bool = None,
+ timeout: typing.Union[
+ httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+ ] = httpx._client.USE_CLIENT_DEFAULT,
+ extensions: dict = None,
+ ) -> httpx.Response:
+ redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
+ return super().patch(
+ url,
+ content=content,
+ data=data,
+ files=files,
json=json,
+ params=params,
+ headers=headers,
+ cookies=cookies,
+ auth=auth,
+ follow_redirects=redirect,
+ timeout=timeout,
+ extensions=extensions,
+ )
+
+ def delete( # type: ignore[override]
+ self,
+ url: httpx._types.URLTypes,
+ *,
+ params: httpx._types.QueryParamTypes = None,
+ headers: httpx._types.HeaderTypes = None,
+ cookies: httpx._types.CookieTypes = None,
+ auth: typing.Union[
+ httpx._types.AuthTypes, httpx._client.UseClientDefault
+ ] = httpx._client.USE_CLIENT_DEFAULT,
+ follow_redirects: bool = None,
+ allow_redirects: bool = None,
+ timeout: typing.Union[
+ httpx._client.TimeoutTypes, httpx._client.UseClientDefault
+ ] = httpx._client.USE_CLIENT_DEFAULT,
+ extensions: dict = None,
+ ) -> httpx.Response:
+ redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
+ return super().delete(
+ url,
+ params=params,
+ headers=headers,
+ cookies=cookies,
+ auth=auth,
+ follow_redirects=redirect,
+ timeout=timeout,
+ extensions=extensions,
)
def websocket_connect(
diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py
index 910afd9f8..ca3d4f47b 100644
--- a/tests/middleware/test_cors.py
+++ b/tests/middleware/test_cors.py
@@ -279,9 +279,12 @@ def homepage(request):
headers = {"Origin": "https://example.org"}
- for method in ("delete", "get", "head", "options", "patch", "post", "put"):
+ for method in ("patch", "post", "put"):
response = getattr(client, method)("/", headers=headers, json={})
assert response.status_code == 200
+ for method in ("delete", "get", "head", "options"):
+ response = getattr(client, method)("/", headers=headers)
+ assert response.status_code == 200
def test_cors_allow_origin_regex(test_client_factory):
diff --git a/tests/middleware/test_session.py b/tests/middleware/test_session.py
index a044153a6..3f43506c4 100644
--- a/tests/middleware/test_session.py
+++ b/tests/middleware/test_session.py
@@ -5,6 +5,7 @@
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import JSONResponse
from starlette.routing import Mount, Route
+from starlette.testclient import TestClient
def view_session(request):
@@ -74,7 +75,8 @@ def test_session_expires(test_client_factory):
expired_session_match = re.search(r"session=([^;]*);", expired_cookie_header)
assert expired_session_match is not None
expired_session_value = expired_session_match[1]
- response = client.get("/view_session", cookies={"session": expired_session_value})
+ client = test_client_factory(app, cookies={"session": expired_session_value})
+ response = client.get("/view_session")
assert response.json() == {"session": {}}
@@ -128,7 +130,8 @@ def test_session_cookie_subpath(test_client_factory):
)
app = Starlette(routes=[Mount("/second_app", app=second_app)])
client = test_client_factory(app, base_url="http://testserver/second_app")
- response = client.post("second_app/update_session", json={"some": "data"})
+ response = client.post("/second_app/update_session", json={"some": "data"})
+ assert response.status_code == 200
cookie = response.headers["set-cookie"]
cookie_path_match = re.search(r"; path=(\S+);", cookie)
assert cookie_path_match is not None
@@ -150,7 +153,8 @@ def test_invalid_session_cookie(test_client_factory):
assert response.json() == {"session": {"some": "data"}}
# we expect it to not raise an exception if we provide a bogus session cookie
- response = client.get("/view_session", cookies={"session": "invalid"})
+ client = test_client_factory(app, cookies={"session": "invalid"})
+ response = client.get("/view_session")
assert response.json() == {"session": {}}
@@ -162,7 +166,7 @@ def test_session_cookie(test_client_factory):
],
middleware=[Middleware(SessionMiddleware, secret_key="example", max_age=None)],
)
- client = test_client_factory(app)
+ client: TestClient = test_client_factory(app)
response = client.post("/update_session", json={"some": "data"})
assert response.json() == {"session": {"some": "data"}}
@@ -171,6 +175,6 @@ def test_session_cookie(test_client_factory):
set_cookie = response.headers["set-cookie"]
assert "Max-Age" not in set_cookie
- client.cookies.clear_session_cookies()
+ client.cookies.delete("session")
response = client.get("/view_session")
assert response.json() == {"session": {}}
diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py
index b7f8cad8c..4792424ab 100644
--- a/tests/test_formparsers.py
+++ b/tests/test_formparsers.py
@@ -9,7 +9,6 @@
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Mount
-from starlette.testclient import TestClient
class ForceMultipartDict(dict):
@@ -114,7 +113,7 @@ def test_multipart_request_files(tmpdir, test_client_factory):
"test": {
"filename": "test.txt",
"content": "",
- "content_type": "",
+ "content_type": "text/plain",
}
}
@@ -154,7 +153,7 @@ def test_multipart_request_multiple_files(tmpdir, test_client_factory):
"test1": {
"filename": "test1.txt",
"content": "",
- "content_type": "",
+ "content_type": "text/plain",
},
"test2": {
"filename": "test2.txt",
@@ -193,8 +192,8 @@ def test_multipart_request_multiple_files_with_headers(tmpdir, test_client_facto
"content-disposition",
'form-data; name="test2"; filename="test2.txt"',
],
- ["content-type", "text/plain"],
["x-custom", "f2"],
+ ["content-type", "text/plain"],
],
},
}
@@ -213,7 +212,7 @@ def test_multi_items(tmpdir, test_client_factory):
with open(path1, "rb") as f1, open(path2, "rb") as f2:
response = client.post(
"/",
- data=[("test1", "abc")],
+ data={"test1": "abc"},
files=[("test1", f1), ("test1", ("test2.txt", f2, "text/plain"))],
)
assert response.json() == {
@@ -222,7 +221,7 @@ def test_multi_items(tmpdir, test_client_factory):
{
"filename": "test1.txt",
"content": "",
- "content_type": "",
+ "content_type": "text/plain",
},
{
"filename": "test2.txt",
@@ -401,9 +400,7 @@ def test_user_safe_decode_ignores_wrong_charset():
(Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
],
)
-def test_missing_boundary_parameter(
- app, expectation, test_client_factory: typing.Callable[..., TestClient]
-) -> None:
+def test_missing_boundary_parameter(app, expectation, test_client_factory) -> None:
client = test_client_factory(app)
with expectation:
res = client.post(
@@ -428,7 +425,7 @@ def test_missing_boundary_parameter(
],
)
def test_missing_name_parameter_on_content_disposition(
- app, expectation, test_client_factory: typing.Callable[..., TestClient]
+ app, expectation, test_client_factory
):
client = test_client_factory(app)
with expectation:
diff --git a/tests/test_requests.py b/tests/test_requests.py
index 033df1e6a..7422ad72a 100644
--- a/tests/test_requests.py
+++ b/tests/test_requests.py
@@ -431,7 +431,7 @@ async def app(scope, receive, send):
def post_body():
yield b"foo"
- yield "bar"
+ yield b"bar"
response = client.post("/", data=post_body())
assert response.json() == {"body": "foobar"}
diff --git a/tests/test_responses.py b/tests/test_responses.py
index 2030a7382..608842da2 100644
--- a/tests/test_responses.py
+++ b/tests/test_responses.py
@@ -191,12 +191,12 @@ def test_response_phrase(test_client_factory):
app = Response(status_code=204)
client = test_client_factory(app)
response = client.get("/")
- assert response.reason == "No Content"
+ assert response.reason_phrase == "No Content"
app = Response(b"", status_code=123)
client = test_client_factory(app)
response = client.get("/")
- assert response.reason == ""
+ assert response.reason_phrase == ""
def test_file_response(tmpdir, test_client_factory):
diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py
index 84f1f5d46..142c2a00b 100644
--- a/tests/test_staticfiles.py
+++ b/tests/test_staticfiles.py
@@ -171,7 +171,7 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir):
file.write("outside root dir")
app = StaticFiles(directory=directory)
- # We can't test this with 'requests', so we test the app directly here.
+ # We can't test this with 'httpx', so we test the app directly here.
path = app.get_path({"path": "/../example.txt"})
scope = {"method": "GET"}