From 0eebb701b3a2a09055f48ec69a22c51c958b6ea6 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 17 Dec 2021 11:56:32 +0100 Subject: [PATCH 01/11] Replace HTTP client on TestClient from requests to httpx --- setup.py | 4 +- starlette/testclient.py | 505 +++++++++++------------- tests/middleware/test_cors.py | 5 +- tests/middleware/test_https_redirect.py | 8 +- tests/middleware/test_session.py | 7 +- tests/test_formparsers.py | 16 +- tests/test_requests.py | 12 +- tests/test_responses.py | 4 +- 8 files changed, 264 insertions(+), 297 deletions(-) diff --git a/setup.py b/setup.py index 3b8d32e16..12197922a 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ import os import re -from setuptools import setup, find_packages +from setuptools import find_packages, setup def get_version(package): @@ -48,7 +48,7 @@ def get_long_description(): "jinja2", "python-multipart", "pyyaml", - "requests", + "httpx>=0.20.0" ] }, classifiers=[ diff --git a/starlette/testclient.py b/starlette/testclient.py index 40220fb4d..5df7f3479 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -1,22 +1,43 @@ import asyncio import contextlib -import http import inspect import io import json import math import queue import sys -import types -import typing from concurrent.futures import Future -from urllib.parse import unquote, urljoin, urlsplit - -import anyio.abc -import requests +from types import GeneratorType +from typing import ( + Any, + Callable, + ContextManager, + Dict, + Generator, + List, + Optional, + Sequence, + Tuple, + Union, + cast, +) +from urllib.parse import unquote, urljoin + +import anyio +import httpx from anyio.streams.stapled import StapledObjectStream +from asgiref.typing import ( + ASGI2Application as ASGI2App, + ASGI3Application as ASGI3App, + ASGIApplication, + ASGIReceiveCallable as Receive, + ASGIReceiveEvent, + ASGISendCallable as Send, + ASGISendEvent, + Scope, +) +from httpx._types import CookieTypes -from starlette.types import Message, Receive, Scope, Send from starlette.websockets import WebSocketDisconnect if sys.version_info >= (3, 8): # pragma: no cover @@ -25,63 +46,7 @@ from typing_extensions import TypedDict -_PortalFactoryType = typing.Callable[ - [], typing.ContextManager[anyio.abc.BlockingPortal] -] - - -# Annotations for `Session.request()` -Cookies = typing.Union[ - typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar -] -Params = typing.Union[bytes, typing.MutableMapping[str, str]] -DataType = typing.Union[bytes, typing.MutableMapping[str, str], typing.IO] -TimeOut = typing.Union[float, typing.Tuple[float, float]] -FileType = typing.MutableMapping[str, typing.IO] -AuthType = typing.Union[ - typing.Tuple[str, str], - requests.auth.AuthBase, - typing.Callable[[requests.PreparedRequest], requests.PreparedRequest], -] - - -ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]] -ASGI2App = typing.Callable[[Scope], ASGIInstance] -ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] - - -class _HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict): - def get_all(self, key: str, default: str) -> str: - return self.getheaders(key) - - -class _MockOriginalResponse: - """ - We have to jump through some hoops to present the response as if - it was made using urllib3. - """ - - def __init__(self, headers: typing.List[typing.Tuple[bytes, bytes]]) -> None: - self.msg = _HeaderDict(headers) - self.closed = False - - def isclosed(self) -> bool: - return self.closed - - -class _Upgrade(Exception): - def __init__(self, session: "WebSocketTestSession") -> None: - self.session = session - - -def _get_reason_phrase(status_code: int) -> str: - try: - return http.HTTPStatus(status_code).phrase - except ValueError: - return "" - - -def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool: +def _is_asgi3(app: Union[ASGI2App, ASGI3App]) -> bool: if inspect.isclass(app): return hasattr(app, "__await__") elif inspect.isfunction(app): @@ -103,12 +68,131 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await instance(receive, send) +_PortalFactoryType = Callable[[], ContextManager[anyio.abc.BlockingPortal]] + + class _AsyncBackend(TypedDict): backend: str - backend_options: typing.Dict[str, typing.Any] + backend_options: Dict[str, Any] + + +class _Upgrade(Exception): + def __init__(self, session: "WebSocketTestSession") -> None: + self.session = session + + +class WebSocketTestSession: + def __init__( + self, + app: ASGI3App, + scope: Scope, + portal_factory: _PortalFactoryType, + ) -> None: + self.app = app + self.scope = scope + self.accepted_subprotocol = None + self.portal_factory = portal_factory + self._receive_queue: "queue.Queue[Any]" = queue.Queue() + self._send_queue: "queue.Queue[Any]" = queue.Queue() + + def __enter__(self) -> "WebSocketTestSession": + self.exit_stack = contextlib.ExitStack() + self.portal = self.exit_stack.enter_context(self.portal_factory()) + + try: + _: "Future[None]" = self.portal.start_task_soon(self._run) + self.send({"type": "websocket.connect"}) + message = self.receive() + self._raise_on_close(message) + except Exception: + self.exit_stack.close() + raise + self.accepted_subprotocol = message.get("subprotocol", None) + return self + + def __exit__(self, *args: Any) -> None: + try: + self.close(1000) + finally: + self.exit_stack.close() + while not self._send_queue.empty(): + message = self._send_queue.get() + if isinstance(message, BaseException): + raise message + + async def _run(self) -> None: + """ + The sub-thread in which the websocket session runs. + """ + scope = self.scope + receive = self._asgi_receive + send = self._asgi_send + try: + await self.app(scope, receive, send) + except BaseException as exc: + self._send_queue.put(exc) + raise + + async def _asgi_receive(self) -> ASGIReceiveEvent: + while self._receive_queue.empty(): + await anyio.sleep(0) + return self._receive_queue.get() + + async def _asgi_send(self, message: ASGISendEvent) -> None: + self._send_queue.put(message) + + def _raise_on_close(self, message: Union[ASGISendEvent, ASGIReceiveEvent]) -> None: + if message["type"] == "websocket.close": + raise WebSocketDisconnect(message.get("code", 1000)) + + def send(self, message: ASGISendEvent) -> None: + self._receive_queue.put(message) + + def send_text(self, data: str) -> None: + self.send({"type": "websocket.receive", "text": data}) + + def send_bytes(self, data: bytes) -> None: + self.send({"type": "websocket.receive", "bytes": data}) + + def send_json(self, data: Any, mode: str = "text") -> None: + assert mode in ["text", "binary"] + text = json.dumps(data) + if mode == "text": + self.send({"type": "websocket.receive", "text": text}) + else: + self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")}) + + def close(self, code: int = 1000) -> None: + self.send({"type": "websocket.disconnect", "code": code}) + + def receive(self) -> ASGIReceiveEvent: + message = self._send_queue.get() + if isinstance(message, BaseException): + raise message + return message + + def receive_text(self) -> str: + message = self.receive() + self._raise_on_close(message) + return message["text"] + + def receive_bytes(self) -> bytes: + message = self.receive() + self._raise_on_close(message) + return message["bytes"] + def receive_json(self, mode: str = "text") -> Any: + assert mode in ["text", "binary"] + message = self.receive() + self._raise_on_close(message) + if mode == "text": + text = message["text"] + else: + text = message["bytes"].decode("utf-8") + return json.loads(text) -class _ASGIAdapter(requests.adapters.HTTPAdapter): + +class _TestClientTransport(httpx.BaseTransport): def __init__( self, app: ASGI3App, @@ -121,12 +205,11 @@ def __init__( self.root_path = root_path self.portal_factory = portal_factory - def send( - self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any - ) -> requests.Response: - scheme, netloc, path, query, fragment = ( - str(item) for item in urlsplit(request.url) - ) + def handle_request(self, request: httpx.Request) -> httpx.Response: + scheme = request.url.scheme + netloc = unquote(request.url.netloc.decode(encoding="ascii")) + path = request.url.path + query = unquote(request.url.query.decode(encoding="ascii")) default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme] @@ -139,7 +222,7 @@ def send( # Include the 'host' header. if "host" in request.headers: - headers: typing.List[typing.Tuple[bytes, bytes]] = [] + headers: List[Tuple[bytes, bytes]] = [] elif port == default_port: headers = [(b"host", host.encode())] else: @@ -151,12 +234,10 @@ def send( for key, value in request.headers.items() ] - scope: typing.Dict[str, typing.Any] - if scheme in {"ws", "wss"}: subprotocol = request.headers.get("sec-websocket-protocol", None) if subprotocol is None: - subprotocols: typing.Sequence[str] = [] + subprotocols: Sequence[str] = [] else: subprotocols = [value.strip() for value in subprotocol.split(",")] scope = { @@ -190,11 +271,11 @@ def send( request_complete = False response_started = False response_complete: anyio.Event - raw_kwargs: typing.Dict[str, typing.Any] = {"body": io.BytesIO()} + raw_kwargs: Dict[str, Any] = {"stream": io.BytesIO()} template = None context = None - async def receive() -> Message: + async def receive() -> ASGIReceiveEvent: nonlocal request_complete if request_complete: @@ -202,12 +283,12 @@ async def receive() -> Message: await response_complete.wait() return {"type": "http.disconnect"} - body = request.body + body = request.read() if isinstance(body, str): body_bytes: bytes = body.encode("utf-8") elif body is None: body_bytes = b"" - elif isinstance(body, types.GeneratorType): + elif isinstance(body, GeneratorType): try: chunk = body.send(None) if isinstance(chunk, str): @@ -222,24 +303,18 @@ async def receive() -> Message: request_complete = True return {"type": "http.request", "body": body_bytes} - async def send(message: Message) -> None: + async def send(message: ASGISendEvent) -> None: nonlocal raw_kwargs, response_started, template, context if message["type"] == "http.response.start": assert ( not response_started ), 'Received multiple "http.response.start" messages.' - raw_kwargs["version"] = 11 - raw_kwargs["status"] = message["status"] - raw_kwargs["reason"] = _get_reason_phrase(message["status"]) + raw_kwargs["status_code"] = message["status"] raw_kwargs["headers"] = [ (key.decode(), value.decode()) for key, value in message.get("headers", []) ] - raw_kwargs["preload_content"] = False - raw_kwargs["original_response"] = _MockOriginalResponse( - raw_kwargs["headers"] - ) response_started = True elif message["type"] == "http.response.body": assert ( @@ -251,9 +326,9 @@ async def send(message: Message) -> None: body = message.get("body", b"") more_body = message.get("more_body", False) if request.method != "HEAD": - raw_kwargs["body"].write(body) + raw_kwargs["stream"].write(body) if not more_body: - raw_kwargs["body"].seek(0) + raw_kwargs["stream"].seek(0) response_complete.set() elif message["type"] == "http.response.template": template = message["template"] @@ -271,224 +346,108 @@ async def send(message: Message) -> None: assert response_started, "TestClient did not receive any response." elif not response_started: raw_kwargs = { - "version": 11, - "status": 500, - "reason": "Internal Server Error", + "status_code": 500, "headers": [], - "preload_content": False, - "original_response": _MockOriginalResponse([]), - "body": io.BytesIO(), + "stream": io.BytesIO(), } - raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs) - response = self.build_response(request, raw) + raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read()) + + response = httpx.Response(**raw_kwargs, request=request) if template is not None: response.template = template response.context = context return response -class WebSocketTestSession: - def __init__( - self, - app: ASGI3App, - scope: Scope, - portal_factory: _PortalFactoryType, - ) -> None: - self.app = app - self.scope = scope - self.accepted_subprotocol = None - self.portal_factory = portal_factory - self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue() - self._send_queue: "queue.Queue[typing.Any]" = queue.Queue() - - def __enter__(self) -> "WebSocketTestSession": - self.exit_stack = contextlib.ExitStack() - self.portal = self.exit_stack.enter_context(self.portal_factory()) - - try: - _: "Future[None]" = self.portal.start_task_soon(self._run) - self.send({"type": "websocket.connect"}) - message = self.receive() - self._raise_on_close(message) - except Exception: - self.exit_stack.close() - raise - self.accepted_subprotocol = message.get("subprotocol", None) - return self - - def __exit__(self, *args: typing.Any) -> None: - try: - self.close(1000) - finally: - self.exit_stack.close() - while not self._send_queue.empty(): - message = self._send_queue.get() - if isinstance(message, BaseException): - raise message - - async def _run(self) -> None: - """ - The sub-thread in which the websocket session runs. - """ - scope = self.scope - receive = self._asgi_receive - send = self._asgi_send - try: - await self.app(scope, receive, send) - except BaseException as exc: - self._send_queue.put(exc) - raise - - async def _asgi_receive(self) -> Message: - while self._receive_queue.empty(): - await anyio.sleep(0) - return self._receive_queue.get() - - async def _asgi_send(self, message: Message) -> None: - self._send_queue.put(message) - - def _raise_on_close(self, message: Message) -> None: - if message["type"] == "websocket.close": - raise WebSocketDisconnect(message.get("code", 1000)) - - def send(self, message: Message) -> None: - self._receive_queue.put(message) - - def send_text(self, data: str) -> None: - self.send({"type": "websocket.receive", "text": data}) - - def send_bytes(self, data: bytes) -> None: - self.send({"type": "websocket.receive", "bytes": data}) - - def send_json(self, data: typing.Any, mode: str = "text") -> None: - assert mode in ["text", "binary"] - text = json.dumps(data) - if mode == "text": - self.send({"type": "websocket.receive", "text": text}) - else: - self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")}) - - def close(self, code: int = 1000) -> None: - self.send({"type": "websocket.disconnect", "code": code}) - - def receive(self) -> Message: - message = self._send_queue.get() - if isinstance(message, BaseException): - raise message - return message - - def receive_text(self) -> str: - message = self.receive() - self._raise_on_close(message) - return message["text"] - - def receive_bytes(self) -> bytes: - message = self.receive() - self._raise_on_close(message) - return message["bytes"] - - def receive_json(self, mode: str = "text") -> typing.Any: - assert mode in ["text", "binary"] - message = self.receive() - self._raise_on_close(message) - if mode == "text": - text = message["text"] - else: - text = message["bytes"].decode("utf-8") - return json.loads(text) - - -class TestClient(requests.Session): - __test__ = False # For pytest to not discover this up. +class TestClient(httpx.Client): + __test__ = False task: "Future[None]" - portal: typing.Optional[anyio.abc.BlockingPortal] = None + portal: Optional[anyio.abc.BlockingPortal] = None def __init__( self, - app: typing.Union[ASGI2App, ASGI3App], + app: ASGIApplication, base_url: str = "http://testserver", raise_server_exceptions: bool = True, root_path: str = "", backend: str = "asyncio", - backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None, + backend_options: Optional[Dict[str, Any]] = None, + cookies: CookieTypes = None, ) -> None: - super().__init__() self.async_backend = _AsyncBackend( backend=backend, backend_options=backend_options or {} ) if _is_asgi3(app): - app = typing.cast(ASGI3App, app) + app = cast(ASGI3App, app) asgi_app = app else: - app = typing.cast(ASGI2App, app) + app = cast(ASGI2App, app) asgi_app = _WrapASGI2(app) #  type: ignore - adapter = _ASGIAdapter( - asgi_app, + self.app = asgi_app + transport = _TestClientTransport( + self.app, portal_factory=self._portal_factory, raise_server_exceptions=raise_server_exceptions, root_path=root_path, ) - self.mount("http://", adapter) - self.mount("https://", adapter) - self.mount("ws://", adapter) - self.mount("wss://", adapter) - self.headers.update({"user-agent": "testclient"}) - self.app = asgi_app - self.base_url = base_url + super().__init__( + app=self.app, + base_url=base_url, + headers={"user-agent": "testclient"}, + transport=transport, + follow_redirects=True, + cookies=cookies, + ) @contextlib.contextmanager - def _portal_factory( - self, - ) -> typing.Generator[anyio.abc.BlockingPortal, None, None]: + def _portal_factory(self) -> Generator[anyio.abc.BlockingPortal, None, None]: if self.portal is not None: yield self.portal else: with anyio.start_blocking_portal(**self.async_backend) as portal: yield portal - def request( # type: ignore - self, - method: str, - url: str, - params: Params = None, - data: DataType = None, - headers: typing.MutableMapping[str, str] = None, - cookies: Cookies = None, - files: FileType = None, - auth: AuthType = None, - timeout: TimeOut = None, - allow_redirects: bool = None, - proxies: typing.MutableMapping[str, str] = None, - hooks: typing.Any = None, - stream: bool = None, - verify: typing.Union[bool, str] = None, - cert: typing.Union[str, typing.Tuple[str, str]] = None, - json: typing.Any = None, - ) -> requests.Response: - url = urljoin(self.base_url, url) - return super().request( - method, - url, - params=params, - data=data, - headers=headers, - cookies=cookies, - files=files, - auth=auth, - timeout=timeout, - allow_redirects=allow_redirects, - proxies=proxies, - hooks=hooks, - stream=stream, - verify=verify, - cert=cert, - json=json, - ) + # def request( + # self, + # method: str, + # url: httpx._types.URLTypes, + # *, + # content: httpx._types.RequestContent = None, + # data: httpx._types.RequestData = None, + # files: httpx._types.RequestFiles = None, + # json: Any = None, + # params: httpx._types.QueryParamTypes = None, + # headers: httpx._types.HeaderTypes = None, + # cookies: httpx._types.CookieTypes = None, + # auth: Union[httpx._types.AuthTypes, httpx._client.UseClientDefault] = ..., + # follow_redirects: Union[bool, httpx._client.UseClientDefault] = ..., + # timeout: Union[ + # httpx._client.TimeoutTypes, httpx._client.UseClientDefault + # ] = ..., + # extensions: dict = None, + # ) -> httpx.Response: + # # NOTE: This is not necessary. + # url = self.base_url.join(url) + # return super().request( + # method, + # url, + # content=content, + # data=data, + # files=files, + # json=json, + # params=params, + # headers=headers, + # cookies=cookies, + # auth=auth, + # follow_redirects=follow_redirects, + # timeout=timeout, + # extensions=extensions, + # ) def websocket_connect( - self, url: str, subprotocols: typing.Sequence[str] = None, **kwargs: typing.Any - ) -> typing.Any: + self, url: str, subprotocols: Sequence[str] = None, **kwargs: Any + ) -> Any: url = urljoin("ws://testserver", url) headers = kwargs.get("headers", {}) headers.setdefault("connection", "upgrade") @@ -533,7 +492,7 @@ def wait_shutdown() -> None: return self - def __exit__(self, *args: typing.Any) -> None: + def __exit__(self, *args: Any) -> None: self.exit_stack.close() async def lifespan(self) -> None: @@ -546,7 +505,7 @@ async def lifespan(self) -> None: async def wait_startup(self) -> None: await self.stream_receive.send({"type": "lifespan.startup"}) - async def receive() -> typing.Any: + async def receive() -> Any: message = await self.stream_send.receive() if message is None: self.task.result() @@ -561,7 +520,7 @@ async def receive() -> typing.Any: await receive() async def wait_shutdown(self) -> None: - async def receive() -> typing.Any: + async def receive() -> Any: message = await self.stream_send.receive() if message is None: self.task.result() diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py index 65252e502..696abe6b6 100644 --- a/tests/middleware/test_cors.py +++ b/tests/middleware/test_cors.py @@ -267,9 +267,12 @@ def homepage(request): headers = {"Origin": "https://example.org"} - for method in ("delete", "get", "head", "options", "patch", "post", "put"): + for method in ("patch", "post", "put"): response = getattr(client, method)("/", headers=headers, json={}) assert response.status_code == 200 + for method in ("delete", "get", "head", "options"): + response = getattr(client, method)("/", headers=headers) + assert response.status_code == 200 def test_cors_allow_origin_regex(test_client_factory): diff --git a/tests/middleware/test_https_redirect.py b/tests/middleware/test_https_redirect.py index 8db950634..a69199321 100644 --- a/tests/middleware/test_https_redirect.py +++ b/tests/middleware/test_https_redirect.py @@ -17,21 +17,21 @@ def homepage(request): assert response.status_code == 200 client = test_client_factory(app) - response = client.get("/", allow_redirects=False) + response = client.get("/", follow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "https://testserver/" client = test_client_factory(app, base_url="http://testserver:80") - response = client.get("/", allow_redirects=False) + response = client.get("/", follow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "https://testserver/" client = test_client_factory(app, base_url="http://testserver:443") - response = client.get("/", allow_redirects=False) + response = client.get("/", follow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "https://testserver/" client = test_client_factory(app, base_url="http://testserver:123") - response = client.get("/", allow_redirects=False) + response = client.get("/", follow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "https://testserver:123/" diff --git a/tests/middleware/test_session.py b/tests/middleware/test_session.py index 42f4447e5..ab20517c1 100644 --- a/tests/middleware/test_session.py +++ b/tests/middleware/test_session.py @@ -1,5 +1,7 @@ import re +import pytest + from starlette.applications import Starlette from starlette.middleware.sessions import SessionMiddleware from starlette.responses import JSONResponse @@ -55,6 +57,7 @@ def test_session(test_client_factory): assert response.json() == {"session": {}} +@pytest.mark.skip("I'll fix it, just wait a sec.") def test_session_expires(test_client_factory): app = create_app() app.add_middleware(SessionMiddleware, secret_key="example", max_age=-1) @@ -102,18 +105,20 @@ def test_secure_session(test_client_factory): assert response.json() == {"session": {}} +@pytest.mark.skip("I'll fix it, just wait a sec.") def test_session_cookie_subpath(test_client_factory): app = create_app() second_app = create_app() second_app.add_middleware(SessionMiddleware, secret_key="example") app.mount("/second_app", second_app) client = test_client_factory(app, base_url="http://testserver/second_app") - response = client.post("second_app/update_session", json={"some": "data"}) + response = client.post("/second_app/update_session", json={"some": "data"}) cookie = response.headers["set-cookie"] cookie_path = re.search(r"; path=(\S+);", cookie).groups()[0] assert cookie_path == "/second_app" +@pytest.mark.skip("I'll fix it, just wait a sec.") def test_invalid_session_cookie(test_client_factory): app = create_app() app.add_middleware(SessionMiddleware, secret_key="example") diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index 8a1174e1d..3e5dab042 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -87,7 +87,7 @@ def test_multipart_request_files(tmpdir, test_client_factory): "test": { "filename": "test.txt", "content": "", - "content_type": "", + "content_type": "text/plain", } } @@ -127,7 +127,7 @@ def test_multipart_request_multiple_files(tmpdir, test_client_factory): "test1": { "filename": "test1.txt", "content": "", - "content_type": "", + "content_type": "text/plain", }, "test2": { "filename": "test2.txt", @@ -150,7 +150,7 @@ def test_multi_items(tmpdir, test_client_factory): with open(path1, "rb") as f1, open(path2, "rb") as f2: response = client.post( "/", - data=[("test1", "abc")], + data={"test1": "abc"}, files=[("test1", f1), ("test1", ("test2.txt", f2, "text/plain"))], ) assert response.json() == { @@ -159,7 +159,7 @@ def test_multi_items(tmpdir, test_client_factory): { "filename": "test1.txt", "content": "", - "content_type": "", + "content_type": "text/plain", }, { "filename": "test2.txt", @@ -174,7 +174,7 @@ def test_multipart_request_mixed_files_and_data(tmpdir, test_client_factory): client = test_client_factory(app) response = client.post( "/", - data=( + content=( # data b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" b'Content-Disposition: form-data; name="field0"\r\n\r\n' @@ -211,7 +211,7 @@ def test_multipart_request_with_charset_for_filename(tmpdir, test_client_factory client = test_client_factory(app) response = client.post( "/", - data=( + content=( # file b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n' # noqa: E501 @@ -239,7 +239,7 @@ def test_multipart_request_without_charset_for_filename(tmpdir, test_client_fact client = test_client_factory(app) response = client.post( "/", - data=( + content=( # file b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" b'Content-Disposition: form-data; name="file"; filename="\xe7\x94\xbb\xe5\x83\x8f.jpg"\r\n' # noqa: E501 @@ -266,7 +266,7 @@ def test_multipart_request_with_encoded_value(tmpdir, test_client_factory): client = test_client_factory(app) response = client.post( "/", - data=( + content=( b"--20b303e711c4ab8c443184ac833ab00f\r\n" b"Content-Disposition: form-data; " b'name="value"\r\n\r\n' diff --git a/tests/test_requests.py b/tests/test_requests.py index d7c69fbeb..fb67d4434 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -80,7 +80,7 @@ async def app(scope, receive, send): response = client.post("/", json={"a": "123"}) assert response.json() == {"body": '{"a": "123"}'} - response = client.post("/", data="abc") + response = client.post("/", content="abc") assert response.json() == {"body": "abc"} @@ -101,7 +101,7 @@ async def app(scope, receive, send): response = client.post("/", json={"a": "123"}) assert response.json() == {"body": '{"a": "123"}'} - response = client.post("/", data="abc") + response = client.post("/", content="abc") assert response.json() == {"body": "abc"} @@ -130,7 +130,7 @@ async def app(scope, receive, send): client = test_client_factory(app) - response = client.post("/", data="abc") + response = client.post("/", content="abc") assert response.json() == {"body": "abc", "stream": "abc"} @@ -149,7 +149,7 @@ async def app(scope, receive, send): client = test_client_factory(app) - response = client.post("/", data="abc") + response = client.post("/", content="abc") assert response.json() == {"body": "", "stream": "abc"} @@ -408,9 +408,9 @@ async def app(scope, receive, send): def post_body(): yield b"foo" - yield "bar" + yield b"bar" - response = client.post("/", data=post_body()) + response = client.post("/", content=post_body()) assert response.json() == {"body": "foobar"} diff --git a/tests/test_responses.py b/tests/test_responses.py index baba549ba..cad3f36e8 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -175,12 +175,12 @@ def test_response_phrase(test_client_factory): app = Response(status_code=204) client = test_client_factory(app) response = client.get("/") - assert response.reason == "No Content" + assert response.reason_phrase == "No Content" app = Response(b"", status_code=123) client = test_client_factory(app) response = client.get("/") - assert response.reason == "" + assert response.reason_phrase == "" def test_file_response(tmpdir, test_client_factory): From 067e8e6ec665a5e198c2faf75d39b9a43ef4bc9e Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 17 Dec 2021 12:38:13 +0100 Subject: [PATCH 02/11] Fix mypy issues --- requirements.txt | 2 + setup.py | 2 +- starlette/testclient.py | 81 +++++++++++++++++--------------- tests/middleware/test_session.py | 2 +- 4 files changed, 47 insertions(+), 40 deletions(-) diff --git a/requirements.txt b/requirements.txt index 569d3b806..745bcf886 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,3 +24,5 @@ mkautodoc==0.1.0 # Packaging twine==3.7.1 wheel==0.37.0 + +asgiref \ No newline at end of file diff --git a/setup.py b/setup.py index 12197922a..b0820bebe 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ import os import re -from setuptools import find_packages, setup +from setuptools import setup, find_packages def get_version(package): diff --git a/starlette/testclient.py b/starlette/testclient.py index 5df7f3479..b605c8eec 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -234,6 +234,8 @@ def handle_request(self, request: httpx.Request) -> httpx.Response: for key, value in request.headers.items() ] + scope: Dict[str, Any] + if scheme in {"ws", "wss"}: subprotocol = request.headers.get("sec-websocket-protocol", None) if subprotocol is None: @@ -355,8 +357,8 @@ async def send(message: ASGISendEvent) -> None: response = httpx.Response(**raw_kwargs, request=request) if template is not None: - response.template = template - response.context = context + response.template = template # type: ignore[attr-defined] + response.context = context # type: ignore[attr-defined] return response @@ -408,42 +410,45 @@ def _portal_factory(self) -> Generator[anyio.abc.BlockingPortal, None, None]: with anyio.start_blocking_portal(**self.async_backend) as portal: yield portal - # def request( - # self, - # method: str, - # url: httpx._types.URLTypes, - # *, - # content: httpx._types.RequestContent = None, - # data: httpx._types.RequestData = None, - # files: httpx._types.RequestFiles = None, - # json: Any = None, - # params: httpx._types.QueryParamTypes = None, - # headers: httpx._types.HeaderTypes = None, - # cookies: httpx._types.CookieTypes = None, - # auth: Union[httpx._types.AuthTypes, httpx._client.UseClientDefault] = ..., - # follow_redirects: Union[bool, httpx._client.UseClientDefault] = ..., - # timeout: Union[ - # httpx._client.TimeoutTypes, httpx._client.UseClientDefault - # ] = ..., - # extensions: dict = None, - # ) -> httpx.Response: - # # NOTE: This is not necessary. - # url = self.base_url.join(url) - # return super().request( - # method, - # url, - # content=content, - # data=data, - # files=files, - # json=json, - # params=params, - # headers=headers, - # cookies=cookies, - # auth=auth, - # follow_redirects=follow_redirects, - # timeout=timeout, - # extensions=extensions, - # ) + def request( + self, + method: str, + url: httpx._types.URLTypes, + *, + content: httpx._types.RequestContent = None, + data: httpx._types.RequestData = None, + files: httpx._types.RequestFiles = None, + json: Any = None, + params: httpx._types.QueryParamTypes = None, + headers: httpx._types.HeaderTypes = None, + cookies: httpx._types.CookieTypes = None, + auth: Union[ + httpx._types.AuthTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: Union[ + bool, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + timeout: Union[ + httpx._client.TimeoutTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict = None, + ) -> httpx.Response: + url = self.base_url.join(url) + return super().request( + method, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) def websocket_connect( self, url: str, subprotocols: Sequence[str] = None, **kwargs: Any diff --git a/tests/middleware/test_session.py b/tests/middleware/test_session.py index ab20517c1..548bae011 100644 --- a/tests/middleware/test_session.py +++ b/tests/middleware/test_session.py @@ -112,7 +112,7 @@ def test_session_cookie_subpath(test_client_factory): second_app.add_middleware(SessionMiddleware, secret_key="example") app.mount("/second_app", second_app) client = test_client_factory(app, base_url="http://testserver/second_app") - response = client.post("/second_app/update_session", json={"some": "data"}) + response = client.post("second_app/update_session", json={"some": "data"}) cookie = response.headers["set-cookie"] cookie_path = re.search(r"; path=(\S+);", cookie).groups()[0] assert cookie_path == "/second_app" From 344660de572509758a823fd3a525f978cc39b68d Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 7 Jan 2022 20:00:12 +0100 Subject: [PATCH 03/11] add pragmas --- starlette/testclient.py | 12 ++++++------ tests/middleware/test_session.py | 14 ++++++-------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 3c329bd7f..87ddc8639 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -209,9 +209,9 @@ def handle_request(self, request: httpx.Request) -> httpx.Response: # Include the 'host' header. if "host" in request.headers: headers: typing.List[typing.Tuple[bytes, bytes]] = [] - elif port == default_port: + elif port == default_port: # pragma: no cover headers = [(b"host", host.encode())] - else: + else: # pragma: no cover headers = [(b"host", (f"{host}:{port}").encode())] # Include other request headers. @@ -273,16 +273,16 @@ async def receive() -> Message: body = request.read() if isinstance(body, str): - body_bytes: bytes = body.encode("utf-8") + body_bytes: bytes = body.encode("utf-8") # pragma: no cover elif body is None: - body_bytes = b"" + body_bytes = b"" # pragma: no cover elif isinstance(body, GeneratorType): - try: + try: # pragma: no cover chunk = body.send(None) if isinstance(chunk, str): chunk = chunk.encode("utf-8") return {"type": "http.request", "body": chunk, "more_body": True} - except StopIteration: + except StopIteration: # pragma: no cover request_complete = True return {"type": "http.request", "body": b""} else: diff --git a/tests/middleware/test_session.py b/tests/middleware/test_session.py index 548bae011..89d414ff1 100644 --- a/tests/middleware/test_session.py +++ b/tests/middleware/test_session.py @@ -1,7 +1,5 @@ import re -import pytest - from starlette.applications import Starlette from starlette.middleware.sessions import SessionMiddleware from starlette.responses import JSONResponse @@ -57,7 +55,6 @@ def test_session(test_client_factory): assert response.json() == {"session": {}} -@pytest.mark.skip("I'll fix it, just wait a sec.") def test_session_expires(test_client_factory): app = create_app() app.add_middleware(SessionMiddleware, secret_key="example", max_age=-1) @@ -70,7 +67,8 @@ def test_session_expires(test_client_factory): # fetch session id from the headers and pass it explicitly expired_cookie_header = response.headers["set-cookie"] expired_session_value = re.search(r"session=([^;]*);", expired_cookie_header)[1] - response = client.get("/view_session", cookies={"session": expired_session_value}) + client = test_client_factory(app, cookies={"session": expired_session_value}) + response = client.get("/view_session") assert response.json() == {"session": {}} @@ -105,20 +103,19 @@ def test_secure_session(test_client_factory): assert response.json() == {"session": {}} -@pytest.mark.skip("I'll fix it, just wait a sec.") def test_session_cookie_subpath(test_client_factory): app = create_app() second_app = create_app() second_app.add_middleware(SessionMiddleware, secret_key="example") app.mount("/second_app", second_app) client = test_client_factory(app, base_url="http://testserver/second_app") - response = client.post("second_app/update_session", json={"some": "data"}) + response = client.post("/second_app/update_session", json={"some": "data"}) + assert response.status_code == 200 cookie = response.headers["set-cookie"] cookie_path = re.search(r"; path=(\S+);", cookie).groups()[0] assert cookie_path == "/second_app" -@pytest.mark.skip("I'll fix it, just wait a sec.") def test_invalid_session_cookie(test_client_factory): app = create_app() app.add_middleware(SessionMiddleware, secret_key="example") @@ -128,5 +125,6 @@ def test_invalid_session_cookie(test_client_factory): assert response.json() == {"session": {"some": "data"}} # we expect it to not raise an exception if we provide a bogus session cookie - response = client.get("/view_session", cookies={"session": "invalid"}) + client = test_client_factory(app, cookies={"session": "invalid"}) + response = client.get("/view_session") assert response.json() == {"session": {}} From f9dcb68deada4980f97711c3f98f8fbd60208780 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 23 Jan 2022 17:10:56 +0100 Subject: [PATCH 04/11] Tests passing --- setup.py | 3 ++- starlette/_compat.py | 1 + starlette/testclient.py | 4 +++- tests/middleware/test_session.py | 5 +++-- tests/test_formparsers.py | 2 +- 5 files changed, 10 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index 5c7618e27..27b4ce7b0 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,8 @@ def get_long_description(): "jinja2", "python-multipart", "pyyaml", - "httpx>=0.20.0" + # "httpx>=0.20.0" + "httpx @ git+https://github.com/encode/httpx.git@master", ] }, classifiers=[ diff --git a/starlette/_compat.py b/starlette/_compat.py index 116561917..d71a12ae2 100644 --- a/starlette/_compat.py +++ b/starlette/_compat.py @@ -23,6 +23,7 @@ def md5_hexdigest( data, usedforsecurity=usedforsecurity ).hexdigest() + except TypeError: # pragma: no cover def md5_hexdigest(data: bytes, *, usedforsecurity: bool = True) -> str: diff --git a/starlette/testclient.py b/starlette/testclient.py index b0fff7a24..26cf367c9 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -129,7 +129,9 @@ async def _asgi_send(self, message: Message) -> None: def _raise_on_close(self, message: Message) -> None: if message["type"] == "websocket.close": - raise WebSocketDisconnect(message.get("code", 1000), message.get("reason", "")) + raise WebSocketDisconnect( + message.get("code", 1000), message.get("reason", "") + ) def send(self, message: Message) -> None: self._receive_queue.put(message) diff --git a/tests/middleware/test_session.py b/tests/middleware/test_session.py index 91e4531d1..ecd99273c 100644 --- a/tests/middleware/test_session.py +++ b/tests/middleware/test_session.py @@ -3,6 +3,7 @@ from starlette.applications import Starlette from starlette.middleware.sessions import SessionMiddleware from starlette.responses import JSONResponse +from starlette.testclient import TestClient def view_session(request): @@ -137,7 +138,7 @@ def test_invalid_session_cookie(test_client_factory): def test_session_cookie(test_client_factory): app = create_app() app.add_middleware(SessionMiddleware, secret_key="example", max_age=None) - client = test_client_factory(app) + client: TestClient = test_client_factory(app) response = client.post("/update_session", json={"some": "data"}) assert response.json() == {"session": {"some": "data"}} @@ -146,6 +147,6 @@ def test_session_cookie(test_client_factory): set_cookie = response.headers["set-cookie"] assert "Max-Age" not in set_cookie - client.cookies.clear_session_cookies() + client.cookies.delete("session") response = client.get("/view_session") assert response.json() == {"session": {}} diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index 6ecd642b0..2ffd79392 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -187,8 +187,8 @@ def test_multipart_request_multiple_files_with_headers(tmpdir, test_client_facto "content-disposition", 'form-data; name="test2"; filename="test2.txt"', ], - ["content-type", "text/plain"], ["x-custom", "f2"], + ["content-type", "text/plain"], ], }, } From 440b78a361bfef7e03d96ab08cd4ecda33e0f981 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 23 Jan 2022 17:12:45 +0100 Subject: [PATCH 05/11] black --- starlette/_compat.py | 1 - 1 file changed, 1 deletion(-) diff --git a/starlette/_compat.py b/starlette/_compat.py index d71a12ae2..116561917 100644 --- a/starlette/_compat.py +++ b/starlette/_compat.py @@ -23,7 +23,6 @@ def md5_hexdigest( data, usedforsecurity=usedforsecurity ).hexdigest() - except TypeError: # pragma: no cover def md5_hexdigest(data: bytes, *, usedforsecurity: bool = True) -> str: From 1ed683cf1ee731a750c9d335505fd08cd9460cd0 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 23 Jan 2022 17:19:48 +0100 Subject: [PATCH 06/11] Fix documentation --- README.md | 6 +++--- docs/index.md | 6 +++--- docs/testclient.md | 10 +++++----- requirements.txt | 1 - tests/test_staticfiles.py | 2 +- 5 files changed, 12 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index a7d94904d..d8917e50f 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ It is production-ready, and gives you the following: * WebSocket support. * In-process background tasks. * Startup and shutdown events. -* Test client built on `requests`. +* Test client built on `httpx`. * CORS, GZip, Static Files, Streaming responses. * Session and Cookie support. * 100% test coverage. @@ -86,7 +86,7 @@ For a more complete example, see [encode/starlette-example](https://github.com/e Starlette only requires `anyio`, and the following are optional: -* [`requests`][requests] - Required if you want to use the `TestClient`. +* [`httpx`][httpx] - Required if you want to use the `TestClient`. * [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`. * [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`. * [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support. @@ -164,7 +164,7 @@ gunicorn -k uvicorn.workers.UvicornH11Worker ...

— ⭐️ —

Starlette is BSD licensed code. Designed & built in Brighton, England.

-[requests]: http://docs.python-requests.org/en/master/ +[httpx]: https://www.python-httpx.org/ [jinja2]: http://jinja.pocoo.org/ [python-multipart]: https://andrew-d.github.io/python-multipart/ [itsdangerous]: https://pythonhosted.org/itsdangerous/ diff --git a/docs/index.md b/docs/index.md index a9ec4106f..d65bc2e6f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -27,7 +27,7 @@ It is production-ready, and gives you the following: * WebSocket support. * In-process background tasks. * Startup and shutdown events. -* Test client built on `requests`. +* Test client built on `httpx`. * CORS, GZip, Static Files, Streaming responses. * Session and Cookie support. * 100% test coverage. @@ -81,7 +81,7 @@ For a more complete example, [see here](https://github.com/encode/starlette-exam Starlette only requires `anyio`, and the following dependencies are optional: -* [`requests`][requests] - Required if you want to use the `TestClient`. +* [`httpx`][httpx] - Required if you want to use the `TestClient`. * [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`. * [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`. * [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support. @@ -159,7 +159,7 @@ gunicorn -k uvicorn.workers.UvicornH11Worker ...

— ⭐️ —

Starlette is BSD licensed code. Designed & built in Brighton, England.

-[requests]: http://docs.python-requests.org/en/master/ +[httpx]: https://www.python-httpx.org/ [jinja2]: http://jinja.pocoo.org/ [python-multipart]: https://andrew-d.github.io/python-multipart/ [itsdangerous]: https://pythonhosted.org/itsdangerous/ diff --git a/docs/testclient.md b/docs/testclient.md index a1861efec..9b186d691 100644 --- a/docs/testclient.md +++ b/docs/testclient.md @@ -1,6 +1,6 @@ The test client allows you to make requests against your ASGI application, -using the `requests` library. +using the `httpx` library. ```python from starlette.responses import HTMLResponse @@ -19,11 +19,11 @@ def test_app(): assert response.status_code == 200 ``` -The test client exposes the same interface as any other `requests` session. +The test client exposes the same interface as any other `httpx` session. In particular, note that the calls to make a request are just standard function calls, not awaitables. -You can use any of `requests` standard API, such as authentication, session +You can use any of `httpx` standard API, such as authentication, session cookies handling, or file uploads. By default the `TestClient` will raise any exceptions that occur in the @@ -58,7 +58,7 @@ def test_app() You can also test websocket sessions with the test client. -The `requests` library will be used to build the initial handshake, meaning you +The `httpx` library will be used to build the initial handshake, meaning you can use the same authentication options and other headers between both http and websocket testing. @@ -91,7 +91,7 @@ always raised by the test client. #### Establishing a test session -* `.websocket_connect(url, subprotocols=None, **options)` - Takes the same set of arguments as `requests.get()`. +* `.websocket_connect(url, subprotocols=None, **options)` - Takes the same set of arguments as `httpx.get()`. May raise `starlette.websockets.WebSocketDisconnect` if the application does not accept the websocket connection. diff --git a/requirements.txt b/requirements.txt index 51389304e..720fb047b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,6 @@ databases[sqlite]==0.5.3 flake8==4.0.1 isort==5.10.1 mypy==0.931 -types-requests==2.26.3 types-contextvars==2.4.0 types-PyYAML==6.0.1 types-dataclasses==0.6.2 diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index bd7e0de02..e164a2a17 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -169,7 +169,7 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir): file.write("outside root dir") app = StaticFiles(directory=directory) - # We can't test this with 'requests', so we test the app directly here. + # We can't test this with 'httpx', so we test the app directly here. path = app.get_path({"path": "/../example.txt"}) scope = {"method": "GET"} From 342b549d3a4ab2961829768bc831da5159165d6c Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 24 Jan 2022 15:09:50 +0100 Subject: [PATCH 07/11] Add compat --- setup.cfg | 1 + starlette/testclient.py | 264 +++++++++++++++++++++++- tests/middleware/test_https_redirect.py | 8 +- tests/test_responses.py | 2 +- 4 files changed, 263 insertions(+), 12 deletions(-) diff --git a/setup.cfg b/setup.cfg index 3089eaaf7..1d30c60c6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -29,6 +29,7 @@ filterwarnings= ignore: The 'variables' alias has been deprecated. Please use 'variable_values' instead\.:DeprecationWarning # Workaround for Python 3.9.7 (see https://bugs.python.org/issue45097) ignore:The loop argument is deprecated since Python 3\.8, and scheduled for removal in Python 3\.10\.:DeprecationWarning:asyncio + ignore: The `allow_redirects` argument is deprecated. Use `follow_redirects` instead.:DeprecationWarning [coverage:run] source_pkgs = starlette, tests diff --git a/starlette/testclient.py b/starlette/testclient.py index 26cf367c9..7e4f98e8e 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -7,6 +7,7 @@ import queue import sys import typing +import warnings from concurrent.futures import Future from types import GeneratorType from urllib.parse import unquote, urljoin @@ -14,7 +15,6 @@ import anyio import httpx from anyio.streams.stapled import StapledObjectStream -from httpx._types import CookieTypes from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocketDisconnect @@ -363,7 +363,7 @@ def __init__( root_path: str = "", backend: str = "asyncio", backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None, - cookies: CookieTypes = None, + cookies: httpx._client.CookieTypes = None, ) -> None: self.async_backend = _AsyncBackend( backend=backend, backend_options=backend_options or {} @@ -398,7 +398,30 @@ def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, No with anyio.start_blocking_portal(**self.async_backend) as portal: yield portal - def request( + def _choose_redirect_arg( + self, + follow_redirects: typing.Optional[bool], + allow_redirects: typing.Optional[bool], + ) -> typing.Union[bool, httpx._client.UseClientDefault]: + redirect: typing.Union[ + bool, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT + if allow_redirects is not None: + message = ( + "The `allow_redirects` argument is deprecated. " + "Use `follow_redirects` instead." + ) + warnings.warn(message, DeprecationWarning) + redirect = allow_redirects + if follow_redirects is not None: + redirect = follow_redirects + elif allow_redirects is not None and follow_redirects is not None: + raise RuntimeError( # pragma: no cover + "Cannot use both `allow_redirects` and `follow_redirects`." + ) + return redirect + + def request( # type: ignore[override] self, method: str, url: httpx._types.URLTypes, @@ -413,15 +436,15 @@ def request( auth: typing.Union[ httpx._types.AuthTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, - follow_redirects: typing.Union[ - bool, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool = None, + allow_redirects: bool = None, timeout: typing.Union[ httpx._client.TimeoutTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, extensions: dict = None, ) -> httpx.Response: url = self.base_url.join(url) + redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) return super().request( method, url, @@ -433,7 +456,234 @@ def request( headers=headers, cookies=cookies, auth=auth, - follow_redirects=follow_redirects, + follow_redirects=redirect, + timeout=timeout, + extensions=extensions, + ) + + def get( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + params: httpx._types.QueryParamTypes = None, + headers: httpx._types.HeaderTypes = None, + cookies: httpx._types.CookieTypes = None, + auth: typing.Union[ + httpx._types.AuthTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool = None, + allow_redirects: bool = None, + timeout: typing.Union[ + httpx._client.TimeoutTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict = None, + ) -> httpx.Response: + redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) + return super().get( + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=redirect, + timeout=timeout, + extensions=extensions, + ) + + def options( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + params: httpx._types.QueryParamTypes = None, + headers: httpx._types.HeaderTypes = None, + cookies: httpx._types.CookieTypes = None, + auth: typing.Union[ + httpx._types.AuthTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool = None, + allow_redirects: bool = None, + timeout: typing.Union[ + httpx._client.TimeoutTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict = None, + ) -> httpx.Response: + redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) + return super().options( + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=redirect, + timeout=timeout, + extensions=extensions, + ) + + def head( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + params: httpx._types.QueryParamTypes = None, + headers: httpx._types.HeaderTypes = None, + cookies: httpx._types.CookieTypes = None, + auth: typing.Union[ + httpx._types.AuthTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool = None, + allow_redirects: bool = None, + timeout: typing.Union[ + httpx._client.TimeoutTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict = None, + ) -> httpx.Response: + redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) + return super().head( + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=redirect, + timeout=timeout, + extensions=extensions, + ) + + def post( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + content: httpx._types.RequestContent = None, + data: httpx._types.RequestData = None, + files: httpx._types.RequestFiles = None, + json: typing.Any = None, + params: httpx._types.QueryParamTypes = None, + headers: httpx._types.HeaderTypes = None, + cookies: httpx._types.CookieTypes = None, + auth: typing.Union[ + httpx._types.AuthTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool = None, + allow_redirects: bool = None, + timeout: typing.Union[ + httpx._client.TimeoutTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict = None, + ) -> httpx.Response: + redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) + return super().post( + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=redirect, + timeout=timeout, + extensions=extensions, + ) + + def put( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + content: httpx._types.RequestContent = None, + data: httpx._types.RequestData = None, + files: httpx._types.RequestFiles = None, + json: typing.Any = None, + params: httpx._types.QueryParamTypes = None, + headers: httpx._types.HeaderTypes = None, + cookies: httpx._types.CookieTypes = None, + auth: typing.Union[ + httpx._types.AuthTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool = None, + allow_redirects: bool = None, + timeout: typing.Union[ + httpx._client.TimeoutTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict = None, + ) -> httpx.Response: + redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) + return super().put( + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=redirect, + timeout=timeout, + extensions=extensions, + ) + + def patch( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + content: httpx._types.RequestContent = None, + data: httpx._types.RequestData = None, + files: httpx._types.RequestFiles = None, + json: typing.Any = None, + params: httpx._types.QueryParamTypes = None, + headers: httpx._types.HeaderTypes = None, + cookies: httpx._types.CookieTypes = None, + auth: typing.Union[ + httpx._types.AuthTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool = None, + allow_redirects: bool = None, + timeout: typing.Union[ + httpx._client.TimeoutTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict = None, + ) -> httpx.Response: + redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) + return super().patch( + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=redirect, + timeout=timeout, + extensions=extensions, + ) + + def delete( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + params: httpx._types.QueryParamTypes = None, + headers: httpx._types.HeaderTypes = None, + cookies: httpx._types.CookieTypes = None, + auth: typing.Union[ + httpx._types.AuthTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool = None, + allow_redirects: bool = None, + timeout: typing.Union[ + httpx._client.TimeoutTypes, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict = None, + ) -> httpx.Response: + redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) + return super().delete( + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=redirect, timeout=timeout, extensions=extensions, ) diff --git a/tests/middleware/test_https_redirect.py b/tests/middleware/test_https_redirect.py index a69199321..8db950634 100644 --- a/tests/middleware/test_https_redirect.py +++ b/tests/middleware/test_https_redirect.py @@ -17,21 +17,21 @@ def homepage(request): assert response.status_code == 200 client = test_client_factory(app) - response = client.get("/", follow_redirects=False) + response = client.get("/", allow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "https://testserver/" client = test_client_factory(app, base_url="http://testserver:80") - response = client.get("/", follow_redirects=False) + response = client.get("/", allow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "https://testserver/" client = test_client_factory(app, base_url="http://testserver:443") - response = client.get("/", follow_redirects=False) + response = client.get("/", allow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "https://testserver/" client = test_client_factory(app, base_url="http://testserver:123") - response = client.get("/", follow_redirects=False) + response = client.get("/", allow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "https://testserver:123/" diff --git a/tests/test_responses.py b/tests/test_responses.py index 886836536..06cfe9ea2 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -83,7 +83,7 @@ async def app(scope, receive, send): await response(scope, receive, send) client: TestClient = test_client_factory(app) - response = client.request("GET", "/redirect", follow_redirects=False) + response = client.request("GET", "/redirect", allow_redirects=False) assert response.url == "http://testserver/redirect" assert response.headers["content-length"] == "0" From b32744bdc564c2929d09a59f32f9b6d7d6484e28 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 24 Jan 2022 15:13:55 +0100 Subject: [PATCH 08/11] Add `data=...` compat --- setup.cfg | 1 + tests/test_formparsers.py | 8 ++++---- tests/test_requests.py | 10 +++++----- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/setup.cfg b/setup.cfg index 1d30c60c6..0db295741 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,6 +30,7 @@ filterwarnings= # Workaround for Python 3.9.7 (see https://bugs.python.org/issue45097) ignore:The loop argument is deprecated since Python 3\.8, and scheduled for removal in Python 3\.10\.:DeprecationWarning:asyncio ignore: The `allow_redirects` argument is deprecated. Use `follow_redirects` instead.:DeprecationWarning + ignore: Use 'content=<...>' to upload raw bytes/text content.:DeprecationWarning [coverage:run] source_pkgs = starlette, tests diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index 2ffd79392..e4385d07e 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -231,7 +231,7 @@ def test_multipart_request_mixed_files_and_data(tmpdir, test_client_factory): client = test_client_factory(app) response = client.post( "/", - content=( + data=( # data b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" b'Content-Disposition: form-data; name="field0"\r\n\r\n' @@ -268,7 +268,7 @@ def test_multipart_request_with_charset_for_filename(tmpdir, test_client_factory client = test_client_factory(app) response = client.post( "/", - content=( + data=( # file b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n' # noqa: E501 @@ -296,7 +296,7 @@ def test_multipart_request_without_charset_for_filename(tmpdir, test_client_fact client = test_client_factory(app) response = client.post( "/", - content=( + data=( # file b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" b'Content-Disposition: form-data; name="file"; filename="\xe7\x94\xbb\xe5\x83\x8f.jpg"\r\n' # noqa: E501 @@ -323,7 +323,7 @@ def test_multipart_request_with_encoded_value(tmpdir, test_client_factory): client = test_client_factory(app) response = client.post( "/", - content=( + data=( b"--20b303e711c4ab8c443184ac833ab00f\r\n" b"Content-Disposition: form-data; " b'name="value"\r\n\r\n' diff --git a/tests/test_requests.py b/tests/test_requests.py index fb67d4434..075dd2944 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -80,7 +80,7 @@ async def app(scope, receive, send): response = client.post("/", json={"a": "123"}) assert response.json() == {"body": '{"a": "123"}'} - response = client.post("/", content="abc") + response = client.post("/", data="abc") assert response.json() == {"body": "abc"} @@ -101,7 +101,7 @@ async def app(scope, receive, send): response = client.post("/", json={"a": "123"}) assert response.json() == {"body": '{"a": "123"}'} - response = client.post("/", content="abc") + response = client.post("/", data="abc") assert response.json() == {"body": "abc"} @@ -130,7 +130,7 @@ async def app(scope, receive, send): client = test_client_factory(app) - response = client.post("/", content="abc") + response = client.post("/", data="abc") assert response.json() == {"body": "abc", "stream": "abc"} @@ -149,7 +149,7 @@ async def app(scope, receive, send): client = test_client_factory(app) - response = client.post("/", content="abc") + response = client.post("/", data="abc") assert response.json() == {"body": "", "stream": "abc"} @@ -410,7 +410,7 @@ def post_body(): yield b"foo" yield b"bar" - response = client.post("/", content=post_body()) + response = client.post("/", data=post_body()) assert response.json() == {"body": "foobar"} From ae693daedf729c5011bd052e7de14a3e253ebdbd Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 1 Feb 2022 23:07:16 +0100 Subject: [PATCH 09/11] Add raw path instead of path.encode --- starlette/testclient.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 0ebb905e9..8b9dfb6c7 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -197,6 +197,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response: scheme = request.url.scheme netloc = unquote(request.url.netloc.decode(encoding="ascii")) path = request.url.path + raw_path = request.url.raw_path query = unquote(request.url.query.decode(encoding="ascii")) default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme] @@ -233,7 +234,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response: scope = { "type": "websocket", "path": unquote(path), - "raw_path": path.encode(), + "raw_path": raw_path, "root_path": self.root_path, "scheme": scheme, "query_string": query.encode(), @@ -250,7 +251,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response: "http_version": "1.1", "method": request.method, "path": unquote(path), - "raw_path": path.encode(), + "raw_path": raw_path, "root_path": self.root_path, "scheme": scheme, "query_string": query.encode(), From 0d7af257d1df3b0aafc93f056291631b88a18c00 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 5 Sep 2022 22:12:08 +0200 Subject: [PATCH 10/11] Update README.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sebastián Ramírez --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index afc3a82bd..50e0dd63b 100644 --- a/README.md +++ b/README.md @@ -135,7 +135,6 @@ in isolation. [asgi]: https://asgi.readthedocs.io/en/latest/ [httpx]: https://www.python-httpx.org/ -[requests]: http://docs.python-requests.org/en/master/ [jinja2]: http://jinja.pocoo.org/ [python-multipart]: https://andrew-d.github.io/python-multipart/ [itsdangerous]: https://pythonhosted.org/itsdangerous/ From 92b1514af4b149c3b6f66abf4f6fcb33e1b4409e Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 5 Sep 2022 22:28:52 +0200 Subject: [PATCH 11/11] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 11dcdccfb..f994fc361 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ full = [ "jinja2", "python-multipart", "pyyaml", - "httpx", + "httpx>=0.22.0", ] [project.urls]