From a07b468fb627b5161670e2f54e99b4e70a8594a8 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 23 Jun 2021 21:14:39 +0100 Subject: [PATCH] remove monkeypatching TestClient interface --- setup.py | 5 ++++- starlette/testclient.py | 33 +++++++++++++++++++++------------ 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/setup.py b/setup.py index 978a606c76..ac6479746f 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,10 @@ def get_long_description(): packages=find_packages(exclude=["tests*"]), package_data={"starlette": ["py.typed"]}, include_package_data=True, - install_requires=["anyio>=3.0.0,<4"], + install_requires=[ + "anyio>=3.0.0,<4", + "typing_extensions; python_version < '3.8'", + ], extras_require={ "full": [ "graphene; python_version<'3.10'", diff --git a/starlette/testclient.py b/starlette/testclient.py index 7201809e26..1742154f2e 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -1,4 +1,5 @@ import asyncio +import sys import contextlib import http import inspect @@ -18,6 +19,12 @@ from starlette.types import Message, Receive, Scope, Send from starlette.websockets import WebSocketDisconnect + +if sys.version_info >= (3, 8): + from typing import TypedDict +else: + from typing_extensions import TypedDict + # Annotations for `Session.request()` Cookies = typing.Union[ typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar @@ -91,11 +98,16 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await instance(receive, send) +class _AsyncBackend(TypedDict): + backend: str + backend_options: typing.Dict[str, typing.Any] + + class _ASGIAdapter(requests.adapters.HTTPAdapter): def __init__( self, app: ASGI3App, - async_backend: typing.Dict[str, typing.Any], + async_backend: _AsyncBackend, raise_server_exceptions: bool = True, root_path: str = "", ) -> None: @@ -271,7 +283,10 @@ async def send(message: Message) -> None: class WebSocketTestSession: def __init__( - self, app: ASGI3App, scope: Scope, async_backend: typing.Dict[str, typing.Any] + self, + app: ASGI3App, + scope: Scope, + async_backend: _AsyncBackend, ) -> None: self.app = app self.scope = scope @@ -381,11 +396,6 @@ def receive_json(self, mode: str = "text") -> typing.Any: class TestClient(requests.Session): __test__ = False # For pytest to not discover this up. - #: These are the default options for the constructor arguments - async_backend: typing.Dict[str, typing.Any] = { - "backend": "asyncio", - "backend_options": {}, - } task: "Future[None]" def __init__( @@ -394,14 +404,13 @@ def __init__( base_url: str = "http://testserver", raise_server_exceptions: bool = True, root_path: str = "", - backend: typing.Optional[str] = None, + backend: str = "asyncio", backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> None: super().__init__() - self.async_backend = { - "backend": backend or self.async_backend["backend"], - "backend_options": backend_options or self.async_backend["backend_options"], - } + self.async_backend = _AsyncBackend( + backend=backend, backend_options=backend_options or {} + ) if _is_asgi3(app): app = typing.cast(ASGI3App, app) asgi_app = app