Skip to content

Commit

Permalink
remove monkeypatching TestClient interface
Browse files Browse the repository at this point in the history
  • Loading branch information
graingert committed Jun 23, 2021
1 parent 57f0631 commit 070fba9
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 13 deletions.
5 changes: 4 additions & 1 deletion setup.py
Expand Up @@ -37,7 +37,10 @@ def get_long_description():
packages=find_packages(exclude=["tests*"]),
package_data={"starlette": ["py.typed"]},
include_package_data=True,
install_requires=["anyio>=3.0.0,<4"],
install_requires=[
"anyio>=3.0.0,<4",
"typing_extensions; python_version < '3.8'",
],
extras_require={
"full": [
"graphene; python_version<'3.10'",
Expand Down
32 changes: 20 additions & 12 deletions starlette/testclient.py
Expand Up @@ -6,6 +6,7 @@
import json
import math
import queue
import sys
import types
import typing
from concurrent.futures import Future
Expand All @@ -18,6 +19,11 @@
from starlette.types import Message, Receive, Scope, Send
from starlette.websockets import WebSocketDisconnect

if sys.version_info >= (3, 8):
from typing import TypedDict
else:
from typing_extensions import TypedDict

# Annotations for `Session.request()`
Cookies = typing.Union[
typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar
Expand Down Expand Up @@ -91,11 +97,16 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await instance(receive, send)


class _AsyncBackend(TypedDict):
backend: str
backend_options: typing.Dict[str, typing.Any]


class _ASGIAdapter(requests.adapters.HTTPAdapter):
def __init__(
self,
app: ASGI3App,
async_backend: typing.Dict[str, typing.Any],
async_backend: _AsyncBackend,
raise_server_exceptions: bool = True,
root_path: str = "",
) -> None:
Expand Down Expand Up @@ -271,7 +282,10 @@ async def send(message: Message) -> None:

class WebSocketTestSession:
def __init__(
self, app: ASGI3App, scope: Scope, async_backend: typing.Dict[str, typing.Any]
self,
app: ASGI3App,
scope: Scope,
async_backend: _AsyncBackend,
) -> None:
self.app = app
self.scope = scope
Expand Down Expand Up @@ -381,11 +395,6 @@ def receive_json(self, mode: str = "text") -> typing.Any:

class TestClient(requests.Session):
__test__ = False # For pytest to not discover this up.
#: These are the default options for the constructor arguments
async_backend: typing.Dict[str, typing.Any] = {
"backend": "asyncio",
"backend_options": {},
}
task: "Future[None]"

def __init__(
Expand All @@ -394,14 +403,13 @@ def __init__(
base_url: str = "http://testserver",
raise_server_exceptions: bool = True,
root_path: str = "",
backend: typing.Optional[str] = None,
backend: str = "asyncio",
backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> None:
super().__init__()
self.async_backend = {
"backend": backend or self.async_backend["backend"],
"backend_options": backend_options or self.async_backend["backend_options"],
}
self.async_backend = _AsyncBackend(
backend=backend, backend_options=backend_options or {}
)
if _is_asgi3(app):
app = typing.cast(ASGI3App, app)
asgi_app = app
Expand Down

0 comments on commit 070fba9

Please sign in to comment.