diff --git a/README.md b/README.md index 44bd55c77e..50e0dd63ba 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ It is production-ready, and gives you the following: * WebSocket support. * In-process background tasks. * Startup and shutdown events. -* Test client built on `requests`. +* Test client built on `httpx`. * CORS, GZip, Static Files, Streaming responses. * Session and Cookie support. * 100% test coverage. @@ -87,7 +87,7 @@ For a more complete example, see [encode/starlette-example](https://github.com/e Starlette only requires `anyio`, and the following are optional: -* [`requests`][requests] - Required if you want to use the `TestClient`. +* [`httpx`][httpx] - Required if you want to use the `TestClient`. * [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`. * [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`. * [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support. @@ -134,7 +134,7 @@ in isolation.

Starlette is BSD licensed code.
Designed & crafted with care.

— ⭐️ —

[asgi]: https://asgi.readthedocs.io/en/latest/ -[requests]: http://docs.python-requests.org/en/master/ +[httpx]: https://www.python-httpx.org/ [jinja2]: http://jinja.pocoo.org/ [python-multipart]: https://andrew-d.github.io/python-multipart/ [itsdangerous]: https://pythonhosted.org/itsdangerous/ diff --git a/docs/index.md b/docs/index.md index 5a50140213..9f19778751 100644 --- a/docs/index.md +++ b/docs/index.md @@ -27,7 +27,7 @@ It is production-ready, and gives you the following: * WebSocket support. * In-process background tasks. * Startup and shutdown events. -* Test client built on `requests`. +* Test client built on `httpx`. * CORS, GZip, Static Files, Streaming responses. * Session and Cookie support. * 100% test coverage. @@ -83,7 +83,7 @@ For a more complete example, [see here](https://github.com/encode/starlette-exam Starlette only requires `anyio`, and the following dependencies are optional: -* [`requests`][requests] - Required if you want to use the `TestClient`. +* [`httpx`][httpx] - Required if you want to use the `TestClient`. * [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`. * [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`. * [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support. @@ -130,7 +130,7 @@ in isolation.

Starlette is BSD licensed code.
Designed & crafted with care.

— ⭐️ —

[asgi]: https://asgi.readthedocs.io/en/latest/ -[requests]: http://docs.python-requests.org/en/master/ +[httpx]: https://www.python-httpx.org/ [jinja2]: http://jinja.pocoo.org/ [python-multipart]: https://andrew-d.github.io/python-multipart/ [itsdangerous]: https://pythonhosted.org/itsdangerous/ diff --git a/docs/testclient.md b/docs/testclient.md index f64c570bb7..053b420055 100644 --- a/docs/testclient.md +++ b/docs/testclient.md @@ -1,6 +1,6 @@ The test client allows you to make requests against your ASGI application, -using the `requests` library. +using the `httpx` library. ```python from starlette.responses import HTMLResponse @@ -19,11 +19,11 @@ def test_app(): assert response.status_code == 200 ``` -The test client exposes the same interface as any other `requests` session. +The test client exposes the same interface as any other `httpx` session. In particular, note that the calls to make a request are just standard function calls, not awaitables. -You can use any of `requests` standard API, such as authentication, session +You can use any of `httpx` standard API, such as authentication, session cookies handling, or file uploads. For example, to set headers on the TestClient you can do: @@ -96,7 +96,7 @@ def test_app() You can also test websocket sessions with the test client. -The `requests` library will be used to build the initial handshake, meaning you +The `httpx` library will be used to build the initial handshake, meaning you can use the same authentication options and other headers between both http and websocket testing. @@ -129,7 +129,7 @@ always raised by the test client. #### Establishing a test session -* `.websocket_connect(url, subprotocols=None, **options)` - Takes the same set of arguments as `requests.get()`. +* `.websocket_connect(url, subprotocols=None, **options)` - Takes the same set of arguments as `httpx.get()`. May raise `starlette.websockets.WebSocketDisconnect` if the application does not accept the websocket connection. diff --git a/pyproject.toml b/pyproject.toml index 7bbce89d93..f994fc3613 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ full = [ "jinja2", "python-multipart", "pyyaml", - "requests", + "httpx>=0.22.0", ] [project.urls] diff --git a/requirements.txt b/requirements.txt index 648f0fa018..0b54fa596f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,6 @@ flake8==3.9.2 isort==5.10.1 mypy==0.971 typing_extensions==4.3.0 -types-requests==2.26.3 types-contextvars==2.4.7 types-PyYAML==6.0.11 types-dataclasses==0.6.6 diff --git a/setup.cfg b/setup.cfg index 93f27e4e08..23cf32cc03 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,6 +32,9 @@ filterwarnings= ignore: starlette\.middleware\.wsgi is deprecated and will be removed in a future release\.*:DeprecationWarning ignore: Async generator 'starlette\.requests\.Request\.stream' was garbage collected before it had been exhausted.*:ResourceWarning ignore: path is deprecated.*:DeprecationWarning:certifi + ignore: Use 'content=<...>' to upload raw bytes/text content.:DeprecationWarning + ignore: The `allow_redirects` argument is deprecated. Use `follow_redirects` instead.:DeprecationWarning + ignore: 'cgi' is deprecated and slated for removal in Python 3\.13:DeprecationWarning [coverage:run] source_pkgs = starlette, tests diff --git a/starlette/testclient.py b/starlette/testclient.py index efe2b493bb..455440ce59 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -1,22 +1,22 @@ import contextlib -import http import inspect import io import json import math import queue import sys -import types import typing +import warnings from concurrent.futures import Future -from urllib.parse import unquote, urljoin, urlsplit +from types import GeneratorType +from urllib.parse import unquote, urljoin -import anyio.abc -import requests +import anyio +import httpx from anyio.streams.stapled import StapledObjectStream from starlette._utils import is_async_callable -from starlette.types import Message, Receive, Scope, Send +from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocketDisconnect if sys.version_info >= (3, 8): # pragma: no cover @@ -24,63 +24,15 @@ else: # pragma: no cover from typing_extensions import TypedDict - _PortalFactoryType = typing.Callable[ [], typing.ContextManager[anyio.abc.BlockingPortal] ] - -# Annotations for `Session.request()` -Cookies = typing.Union[ - typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar -] -Params = typing.Union[bytes, typing.MutableMapping[str, str]] -DataType = typing.Union[bytes, typing.MutableMapping[str, str], typing.IO] -TimeOut = typing.Union[float, typing.Tuple[float, float]] -FileType = typing.MutableMapping[str, typing.IO] -AuthType = typing.Union[ - typing.Tuple[str, str], - requests.auth.AuthBase, - typing.Callable[[requests.PreparedRequest], requests.PreparedRequest], -] - - ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]] ASGI2App = typing.Callable[[Scope], ASGIInstance] ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] -class _HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict): - def get_all(self, key: str, default: str) -> str: - return self.getheaders(key) - - -class _MockOriginalResponse: - """ - We have to jump through some hoops to present the response as if - it was made using urllib3. - """ - - def __init__(self, headers: typing.List[typing.Tuple[bytes, bytes]]) -> None: - self.msg = _HeaderDict(headers) - self.closed = False - - def isclosed(self) -> bool: - return self.closed - - -class _Upgrade(Exception): - def __init__(self, session: "WebSocketTestSession") -> None: - self.session = session - - -def _get_reason_phrase(status_code: int) -> str: - try: - return http.HTTPStatus(status_code).phrase - except ValueError: - return "" - - def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool: if inspect.isclass(app): return hasattr(app, "__await__") @@ -105,7 +57,127 @@ class _AsyncBackend(TypedDict): backend_options: typing.Dict[str, typing.Any] -class _ASGIAdapter(requests.adapters.HTTPAdapter): +class _Upgrade(Exception): + def __init__(self, session: "WebSocketTestSession") -> None: + self.session = session + + +class WebSocketTestSession: + def __init__( + self, + app: ASGI3App, + scope: Scope, + portal_factory: _PortalFactoryType, + ) -> None: + self.app = app + self.scope = scope + self.accepted_subprotocol = None + self.portal_factory = portal_factory + self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue() + self._send_queue: "queue.Queue[typing.Any]" = queue.Queue() + self.extra_headers = None + + def __enter__(self) -> "WebSocketTestSession": + self.exit_stack = contextlib.ExitStack() + self.portal = self.exit_stack.enter_context(self.portal_factory()) + + try: + _: "Future[None]" = self.portal.start_task_soon(self._run) + self.send({"type": "websocket.connect"}) + message = self.receive() + self._raise_on_close(message) + except Exception: + self.exit_stack.close() + raise + self.accepted_subprotocol = message.get("subprotocol", None) + self.extra_headers = message.get("headers", None) + return self + + def __exit__(self, *args: typing.Any) -> None: + try: + self.close(1000) + finally: + self.exit_stack.close() + while not self._send_queue.empty(): + message = self._send_queue.get() + if isinstance(message, BaseException): + raise message + + async def _run(self) -> None: + """ + The sub-thread in which the websocket session runs. + """ + scope = self.scope + receive = self._asgi_receive + send = self._asgi_send + try: + await self.app(scope, receive, send) + except BaseException as exc: + self._send_queue.put(exc) + raise + + async def _asgi_receive(self) -> Message: + while self._receive_queue.empty(): + await anyio.sleep(0) + return self._receive_queue.get() + + async def _asgi_send(self, message: Message) -> None: + self._send_queue.put(message) + + def _raise_on_close(self, message: Message) -> None: + if message["type"] == "websocket.close": + raise WebSocketDisconnect( + message.get("code", 1000), message.get("reason", "") + ) + + def send(self, message: Message) -> None: + self._receive_queue.put(message) + + def send_text(self, data: str) -> None: + self.send({"type": "websocket.receive", "text": data}) + + def send_bytes(self, data: bytes) -> None: + self.send({"type": "websocket.receive", "bytes": data}) + + def send_json(self, data: typing.Any, mode: str = "text") -> None: + assert mode in ["text", "binary"] + text = json.dumps(data) + if mode == "text": + self.send({"type": "websocket.receive", "text": text}) + else: + self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")}) + + def close(self, code: int = 1000) -> None: + self.send({"type": "websocket.disconnect", "code": code}) + + def receive(self) -> Message: + message = self._send_queue.get() + if isinstance(message, BaseException): + raise message + return message + + def receive_text(self) -> str: + message = self.receive() + self._raise_on_close(message) + return message["text"] + + def receive_bytes(self) -> bytes: + message = self.receive() + self._raise_on_close(message) + return message["bytes"] + + def receive_json(self, mode: str = "text") -> typing.Any: + assert mode in ["text", "binary"] + message = self.receive() + self._raise_on_close(message) + if mode == "text": + text = message["text"] + else: + text = message["bytes"].decode("utf-8") + return json.loads(text) + + +class _TestClientTransport(httpx.BaseTransport): def __init__( self, app: ASGI3App, @@ -118,12 +190,12 @@ def __init__( self.root_path = root_path self.portal_factory = portal_factory - def send( - self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any - ) -> requests.Response: - scheme, netloc, path, query, fragment = ( - str(item) for item in urlsplit(request.url) - ) + def handle_request(self, request: httpx.Request) -> httpx.Response: + scheme = request.url.scheme + netloc = unquote(request.url.netloc.decode(encoding="ascii")) + path = request.url.path + raw_path = request.url.raw_path + query = unquote(request.url.query.decode(encoding="ascii")) default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme] @@ -137,9 +209,9 @@ def send( # Include the 'host' header. if "host" in request.headers: headers: typing.List[typing.Tuple[bytes, bytes]] = [] - elif port == default_port: + elif port == default_port: # pragma: no cover headers = [(b"host", host.encode())] - else: + else: # pragma: no cover headers = [(b"host", (f"{host}:{port}").encode())] # Include other request headers. @@ -159,7 +231,7 @@ def send( scope = { "type": "websocket", "path": unquote(path), - "raw_path": path.encode(), + "raw_path": raw_path, "root_path": self.root_path, "scheme": scheme, "query_string": query.encode(), @@ -176,7 +248,7 @@ def send( "http_version": "1.1", "method": request.method, "path": unquote(path), - "raw_path": path.encode(), + "raw_path": raw_path, "root_path": self.root_path, "scheme": scheme, "query_string": query.encode(), @@ -189,7 +261,7 @@ def send( request_complete = False response_started = False response_complete: anyio.Event - raw_kwargs: typing.Dict[str, typing.Any] = {"body": io.BytesIO()} + raw_kwargs: typing.Dict[str, typing.Any] = {"stream": io.BytesIO()} template = None context = None @@ -201,18 +273,18 @@ async def receive() -> Message: await response_complete.wait() return {"type": "http.disconnect"} - body = request.body + body = request.read() if isinstance(body, str): - body_bytes: bytes = body.encode("utf-8") + body_bytes: bytes = body.encode("utf-8") # pragma: no cover elif body is None: - body_bytes = b"" - elif isinstance(body, types.GeneratorType): - try: + body_bytes = b"" # pragma: no cover + elif isinstance(body, GeneratorType): + try: # pragma: no cover chunk = body.send(None) if isinstance(chunk, str): chunk = chunk.encode("utf-8") return {"type": "http.request", "body": chunk, "more_body": True} - except StopIteration: + except StopIteration: # pragma: no cover request_complete = True return {"type": "http.request", "body": b""} else: @@ -228,17 +300,11 @@ async def send(message: Message) -> None: assert ( not response_started ), 'Received multiple "http.response.start" messages.' - raw_kwargs["version"] = 11 - raw_kwargs["status"] = message["status"] - raw_kwargs["reason"] = _get_reason_phrase(message["status"]) + raw_kwargs["status_code"] = message["status"] raw_kwargs["headers"] = [ (key.decode(), value.decode()) for key, value in message.get("headers", []) ] - raw_kwargs["preload_content"] = False - raw_kwargs["original_response"] = _MockOriginalResponse( - raw_kwargs["headers"] - ) response_started = True elif message["type"] == "http.response.body": assert ( @@ -250,9 +316,9 @@ async def send(message: Message) -> None: body = message.get("body", b"") more_body = message.get("more_body", False) if request.method != "HEAD": - raw_kwargs["body"].write(body) + raw_kwargs["stream"].write(body) if not more_body: - raw_kwargs["body"].seek(0) + raw_kwargs["stream"].seek(0) response_complete.set() elif message["type"] == "http.response.template": template = message["template"] @@ -270,153 +336,35 @@ async def send(message: Message) -> None: assert response_started, "TestClient did not receive any response." elif not response_started: raw_kwargs = { - "version": 11, - "status": 500, - "reason": "Internal Server Error", + "status_code": 500, "headers": [], - "preload_content": False, - "original_response": _MockOriginalResponse([]), - "body": io.BytesIO(), + "stream": io.BytesIO(), } - raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs) - response = self.build_response(request, raw) + raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read()) + + response = httpx.Response(**raw_kwargs, request=request) if template is not None: - response.template = template - response.context = context + response.template = template # type: ignore[attr-defined] + response.context = context # type: ignore[attr-defined] return response -class WebSocketTestSession: - def __init__( - self, - app: ASGI3App, - scope: Scope, - portal_factory: _PortalFactoryType, - ) -> None: - self.app = app - self.scope = scope - self.accepted_subprotocol = None - self.extra_headers = None - self.portal_factory = portal_factory - self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue() - self._send_queue: "queue.Queue[typing.Any]" = queue.Queue() - - def __enter__(self) -> "WebSocketTestSession": - self.exit_stack = contextlib.ExitStack() - self.portal = self.exit_stack.enter_context(self.portal_factory()) - - try: - _: "Future[None]" = self.portal.start_task_soon(self._run) - self.send({"type": "websocket.connect"}) - message = self.receive() - self._raise_on_close(message) - except Exception: - self.exit_stack.close() - raise - self.accepted_subprotocol = message.get("subprotocol", None) - self.extra_headers = message.get("headers", None) - return self - - def __exit__(self, *args: typing.Any) -> None: - try: - self.close(1000) - finally: - self.exit_stack.close() - while not self._send_queue.empty(): - message = self._send_queue.get() - if isinstance(message, BaseException): - raise message - - async def _run(self) -> None: - """ - The sub-thread in which the websocket session runs. - """ - scope = self.scope - receive = self._asgi_receive - send = self._asgi_send - try: - await self.app(scope, receive, send) - except BaseException as exc: - self._send_queue.put(exc) - raise - - async def _asgi_receive(self) -> Message: - while self._receive_queue.empty(): - await anyio.sleep(0) - return self._receive_queue.get() - - async def _asgi_send(self, message: Message) -> None: - self._send_queue.put(message) - - def _raise_on_close(self, message: Message) -> None: - if message["type"] == "websocket.close": - raise WebSocketDisconnect( - message.get("code", 1000), message.get("reason", "") - ) - - def send(self, message: Message) -> None: - self._receive_queue.put(message) - - def send_text(self, data: str) -> None: - self.send({"type": "websocket.receive", "text": data}) - - def send_bytes(self, data: bytes) -> None: - self.send({"type": "websocket.receive", "bytes": data}) - - def send_json(self, data: typing.Any, mode: str = "text") -> None: - assert mode in ["text", "binary"] - text = json.dumps(data) - if mode == "text": - self.send({"type": "websocket.receive", "text": text}) - else: - self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")}) - - def close(self, code: int = 1000) -> None: - self.send({"type": "websocket.disconnect", "code": code}) - - def receive(self) -> Message: - message = self._send_queue.get() - if isinstance(message, BaseException): - raise message - return message - - def receive_text(self) -> str: - message = self.receive() - self._raise_on_close(message) - return message["text"] - - def receive_bytes(self) -> bytes: - message = self.receive() - self._raise_on_close(message) - return message["bytes"] - - def receive_json(self, mode: str = "text") -> typing.Any: - assert mode in ["text", "binary"] - message = self.receive() - self._raise_on_close(message) - if mode == "text": - text = message["text"] - else: - text = message["bytes"].decode("utf-8") - return json.loads(text) - - -class TestClient(requests.Session): - __test__ = False # For pytest to not discover this up. +class TestClient(httpx.Client): + __test__ = False task: "Future[None]" portal: typing.Optional[anyio.abc.BlockingPortal] = None def __init__( self, - app: typing.Union[ASGI2App, ASGI3App], + app: ASGIApp, base_url: str = "http://testserver", raise_server_exceptions: bool = True, root_path: str = "", backend: str = "asyncio", backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None, + cookies: httpx._client.CookieTypes = None, ) -> None: - super().__init__() self.async_backend = _AsyncBackend( backend=backend, backend_options=backend_options or {} ) @@ -424,69 +372,320 @@ def __init__( app = typing.cast(ASGI3App, app) asgi_app = app else: - app = typing.cast(ASGI2App, app) - asgi_app = _WrapASGI2(app) #  type: ignore - adapter = _ASGIAdapter( - asgi_app, + app = typing.cast(ASGI2App, app) # type: ignore[assignment] + asgi_app = _WrapASGI2(app) # type: ignore[arg-type] + self.app = asgi_app + transport = _TestClientTransport( + self.app, portal_factory=self._portal_factory, raise_server_exceptions=raise_server_exceptions, root_path=root_path, ) - self.mount("http://", adapter) - self.mount("https://", adapter) - self.mount("ws://", adapter) - self.mount("wss://", adapter) - self.headers.update({"user-agent": "testclient"}) - self.app = asgi_app - self.base_url = base_url + super().__init__( + app=self.app, + base_url=base_url, + headers={"user-agent": "testclient"}, + transport=transport, + follow_redirects=True, + cookies=cookies, + ) @contextlib.contextmanager - def _portal_factory( - self, - ) -> typing.Generator[anyio.abc.BlockingPortal, None, None]: + def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]: if self.portal is not None: yield self.portal else: with anyio.start_blocking_portal(**self.async_backend) as portal: yield portal - def request( # type: ignore + def _choose_redirect_arg( + self, + follow_redirects: typing.Optional[bool], + allow_redirects: typing.Optional[bool], + ) -> typing.Union[bool, httpx._client.UseClientDefault]: + redirect: typing.Union[ + bool, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT + if allow_redirects is not None: + message = ( + "The `allow_redirects` argument is deprecated. " + "Use `follow_redirects` instead." + ) + warnings.warn(message, DeprecationWarning) + redirect = allow_redirects + if follow_redirects is not None: + redirect = follow_redirects + elif allow_redirects is not None and follow_redirects is not None: + raise RuntimeError( # pragma: no cover + "Cannot use both `allow_redirects` and `follow_redirects`." + ) + return redirect + + def request( # type: ignore[override] self, method: str, - url: str, - params: Params = None, - data: DataType = None, - headers: typing.MutableMapping[str, str] = None, - cookies: Cookies = None, - files: FileType = None, - auth: AuthType = None, - timeout: TimeOut = None, - allow_redirects: bool = None, - proxies: typing.MutableMapping[str, str] = None, - hooks: typing.Any = None, - stream: bool = None, - verify: typing.Union[bool, str] = None, - cert: typing.Union[str, typing.Tuple[str, str]] = None, + url: httpx._types.URLTypes, + *, + content: httpx._types.RequestContent = None, + data: httpx._types.RequestData = None, + files: httpx._types.RequestFiles = None, json: typing.Any = None, - ) -> requests.Response: - url = urljoin(self.base_url, url) + params: httpx._types.QueryParamTypes = None, + headers: httpx._types.HeaderTypes = None, + cookies: httpx._types.CookieTypes = None, + auth: typing.Union[ + httpx._types.AuthTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool = None, + allow_redirects: bool = None, + timeout: typing.Union[ + httpx._client.TimeoutTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict = None, + ) -> httpx.Response: + url = self.base_url.join(url) + redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) return super().request( method, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=redirect, + timeout=timeout, + extensions=extensions, + ) + + def get( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + params: httpx._types.QueryParamTypes = None, + headers: httpx._types.HeaderTypes = None, + cookies: httpx._types.CookieTypes = None, + auth: typing.Union[ + httpx._types.AuthTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool = None, + allow_redirects: bool = None, + timeout: typing.Union[ + httpx._client.TimeoutTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict = None, + ) -> httpx.Response: + redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) + return super().get( + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=redirect, + timeout=timeout, + extensions=extensions, + ) + + def options( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + params: httpx._types.QueryParamTypes = None, + headers: httpx._types.HeaderTypes = None, + cookies: httpx._types.CookieTypes = None, + auth: typing.Union[ + httpx._types.AuthTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool = None, + allow_redirects: bool = None, + timeout: typing.Union[ + httpx._client.TimeoutTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict = None, + ) -> httpx.Response: + redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) + return super().options( + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=redirect, + timeout=timeout, + extensions=extensions, + ) + + def head( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + params: httpx._types.QueryParamTypes = None, + headers: httpx._types.HeaderTypes = None, + cookies: httpx._types.CookieTypes = None, + auth: typing.Union[ + httpx._types.AuthTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool = None, + allow_redirects: bool = None, + timeout: typing.Union[ + httpx._client.TimeoutTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict = None, + ) -> httpx.Response: + redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) + return super().head( url, params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=redirect, + timeout=timeout, + extensions=extensions, + ) + + def post( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + content: httpx._types.RequestContent = None, + data: httpx._types.RequestData = None, + files: httpx._types.RequestFiles = None, + json: typing.Any = None, + params: httpx._types.QueryParamTypes = None, + headers: httpx._types.HeaderTypes = None, + cookies: httpx._types.CookieTypes = None, + auth: typing.Union[ + httpx._types.AuthTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool = None, + allow_redirects: bool = None, + timeout: typing.Union[ + httpx._client.TimeoutTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict = None, + ) -> httpx.Response: + redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) + return super().post( + url, + content=content, data=data, + files=files, + json=json, + params=params, headers=headers, cookies=cookies, + auth=auth, + follow_redirects=redirect, + timeout=timeout, + extensions=extensions, + ) + + def put( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + content: httpx._types.RequestContent = None, + data: httpx._types.RequestData = None, + files: httpx._types.RequestFiles = None, + json: typing.Any = None, + params: httpx._types.QueryParamTypes = None, + headers: httpx._types.HeaderTypes = None, + cookies: httpx._types.CookieTypes = None, + auth: typing.Union[ + httpx._types.AuthTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool = None, + allow_redirects: bool = None, + timeout: typing.Union[ + httpx._client.TimeoutTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict = None, + ) -> httpx.Response: + redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) + return super().put( + url, + content=content, + data=data, files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, auth=auth, + follow_redirects=redirect, timeout=timeout, - allow_redirects=allow_redirects, - proxies=proxies, - hooks=hooks, - stream=stream, - verify=verify, - cert=cert, + extensions=extensions, + ) + + def patch( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + content: httpx._types.RequestContent = None, + data: httpx._types.RequestData = None, + files: httpx._types.RequestFiles = None, + json: typing.Any = None, + params: httpx._types.QueryParamTypes = None, + headers: httpx._types.HeaderTypes = None, + cookies: httpx._types.CookieTypes = None, + auth: typing.Union[ + httpx._types.AuthTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool = None, + allow_redirects: bool = None, + timeout: typing.Union[ + httpx._client.TimeoutTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict = None, + ) -> httpx.Response: + redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) + return super().patch( + url, + content=content, + data=data, + files=files, json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=redirect, + timeout=timeout, + extensions=extensions, + ) + + def delete( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + params: httpx._types.QueryParamTypes = None, + headers: httpx._types.HeaderTypes = None, + cookies: httpx._types.CookieTypes = None, + auth: typing.Union[ + httpx._types.AuthTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool = None, + allow_redirects: bool = None, + timeout: typing.Union[ + httpx._client.TimeoutTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict = None, + ) -> httpx.Response: + redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) + return super().delete( + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=redirect, + timeout=timeout, + extensions=extensions, ) def websocket_connect( diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py index 910afd9f84..ca3d4f47b0 100644 --- a/tests/middleware/test_cors.py +++ b/tests/middleware/test_cors.py @@ -279,9 +279,12 @@ def homepage(request): headers = {"Origin": "https://example.org"} - for method in ("delete", "get", "head", "options", "patch", "post", "put"): + for method in ("patch", "post", "put"): response = getattr(client, method)("/", headers=headers, json={}) assert response.status_code == 200 + for method in ("delete", "get", "head", "options"): + response = getattr(client, method)("/", headers=headers) + assert response.status_code == 200 def test_cors_allow_origin_regex(test_client_factory): diff --git a/tests/middleware/test_session.py b/tests/middleware/test_session.py index a044153a66..3f43506c41 100644 --- a/tests/middleware/test_session.py +++ b/tests/middleware/test_session.py @@ -5,6 +5,7 @@ from starlette.middleware.sessions import SessionMiddleware from starlette.responses import JSONResponse from starlette.routing import Mount, Route +from starlette.testclient import TestClient def view_session(request): @@ -74,7 +75,8 @@ def test_session_expires(test_client_factory): expired_session_match = re.search(r"session=([^;]*);", expired_cookie_header) assert expired_session_match is not None expired_session_value = expired_session_match[1] - response = client.get("/view_session", cookies={"session": expired_session_value}) + client = test_client_factory(app, cookies={"session": expired_session_value}) + response = client.get("/view_session") assert response.json() == {"session": {}} @@ -128,7 +130,8 @@ def test_session_cookie_subpath(test_client_factory): ) app = Starlette(routes=[Mount("/second_app", app=second_app)]) client = test_client_factory(app, base_url="http://testserver/second_app") - response = client.post("second_app/update_session", json={"some": "data"}) + response = client.post("/second_app/update_session", json={"some": "data"}) + assert response.status_code == 200 cookie = response.headers["set-cookie"] cookie_path_match = re.search(r"; path=(\S+);", cookie) assert cookie_path_match is not None @@ -150,7 +153,8 @@ def test_invalid_session_cookie(test_client_factory): assert response.json() == {"session": {"some": "data"}} # we expect it to not raise an exception if we provide a bogus session cookie - response = client.get("/view_session", cookies={"session": "invalid"}) + client = test_client_factory(app, cookies={"session": "invalid"}) + response = client.get("/view_session") assert response.json() == {"session": {}} @@ -162,7 +166,7 @@ def test_session_cookie(test_client_factory): ], middleware=[Middleware(SessionMiddleware, secret_key="example", max_age=None)], ) - client = test_client_factory(app) + client: TestClient = test_client_factory(app) response = client.post("/update_session", json={"some": "data"}) assert response.json() == {"session": {"some": "data"}} @@ -171,6 +175,6 @@ def test_session_cookie(test_client_factory): set_cookie = response.headers["set-cookie"] assert "Max-Age" not in set_cookie - client.cookies.clear_session_cookies() + client.cookies.delete("session") response = client.get("/view_session") assert response.json() == {"session": {}} diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index b7f8cad8c8..4792424abc 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -9,7 +9,6 @@ from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Mount -from starlette.testclient import TestClient class ForceMultipartDict(dict): @@ -114,7 +113,7 @@ def test_multipart_request_files(tmpdir, test_client_factory): "test": { "filename": "test.txt", "content": "", - "content_type": "", + "content_type": "text/plain", } } @@ -154,7 +153,7 @@ def test_multipart_request_multiple_files(tmpdir, test_client_factory): "test1": { "filename": "test1.txt", "content": "", - "content_type": "", + "content_type": "text/plain", }, "test2": { "filename": "test2.txt", @@ -193,8 +192,8 @@ def test_multipart_request_multiple_files_with_headers(tmpdir, test_client_facto "content-disposition", 'form-data; name="test2"; filename="test2.txt"', ], - ["content-type", "text/plain"], ["x-custom", "f2"], + ["content-type", "text/plain"], ], }, } @@ -213,7 +212,7 @@ def test_multi_items(tmpdir, test_client_factory): with open(path1, "rb") as f1, open(path2, "rb") as f2: response = client.post( "/", - data=[("test1", "abc")], + data={"test1": "abc"}, files=[("test1", f1), ("test1", ("test2.txt", f2, "text/plain"))], ) assert response.json() == { @@ -222,7 +221,7 @@ def test_multi_items(tmpdir, test_client_factory): { "filename": "test1.txt", "content": "", - "content_type": "", + "content_type": "text/plain", }, { "filename": "test2.txt", @@ -401,9 +400,7 @@ def test_user_safe_decode_ignores_wrong_charset(): (Starlette(routes=[Mount("/", app=app)]), does_not_raise()), ], ) -def test_missing_boundary_parameter( - app, expectation, test_client_factory: typing.Callable[..., TestClient] -) -> None: +def test_missing_boundary_parameter(app, expectation, test_client_factory) -> None: client = test_client_factory(app) with expectation: res = client.post( @@ -428,7 +425,7 @@ def test_missing_boundary_parameter( ], ) def test_missing_name_parameter_on_content_disposition( - app, expectation, test_client_factory: typing.Callable[..., TestClient] + app, expectation, test_client_factory ): client = test_client_factory(app) with expectation: diff --git a/tests/test_requests.py b/tests/test_requests.py index 033df1e6aa..7422ad72a9 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -431,7 +431,7 @@ async def app(scope, receive, send): def post_body(): yield b"foo" - yield "bar" + yield b"bar" response = client.post("/", data=post_body()) assert response.json() == {"body": "foobar"} diff --git a/tests/test_responses.py b/tests/test_responses.py index 2030a73824..608842da2e 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -191,12 +191,12 @@ def test_response_phrase(test_client_factory): app = Response(status_code=204) client = test_client_factory(app) response = client.get("/") - assert response.reason == "No Content" + assert response.reason_phrase == "No Content" app = Response(b"", status_code=123) client = test_client_factory(app) response = client.get("/") - assert response.reason == "" + assert response.reason_phrase == "" def test_file_response(tmpdir, test_client_factory): diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index 84f1f5d46a..142c2a00b5 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -171,7 +171,7 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir): file.write("outside root dir") app = StaticFiles(directory=directory) - # We can't test this with 'requests', so we test the app directly here. + # We can't test this with 'httpx', so we test the app directly here. path = app.get_path({"path": "/../example.txt"}) scope = {"method": "GET"}