Skip to content

Commit

Permalink
TestClient accepts backend and backend_options as arguments to constr…
Browse files Browse the repository at this point in the history
…uctor (#1211)

as opposed to ClassVar assignment 

Co-authored-by: Jamie Hewland <jhewland@gmail.com>
Co-authored-by: Jordan Speicher <jordan@jspeicher.com>
Co-authored-by: Jordan Speicher <uSpike@users.noreply.github.com>
  • Loading branch information
4 people committed Jun 28, 2021
1 parent 906e907 commit d222b87
Show file tree
Hide file tree
Showing 28 changed files with 525 additions and 485 deletions.
23 changes: 15 additions & 8 deletions docs/testclient.md
Expand Up @@ -33,18 +33,25 @@ case you should use `client = TestClient(app, raise_server_exceptions=False)`.

### Selecting the Async backend

`TestClient.async_backend` is a dictionary which allows you to set the options
for the backend used to run tests. These options are passed to
`anyio.start_blocking_portal()`. See the [anyio documentation](https://anyio.readthedocs.io/en/stable/basics.html#backend-options)
for more information about backend options. By default, `asyncio` is used.
`TestClient` takes arguments `backend` (a string) and `backend_options` (a dictionary).
These options are passed to `anyio.start_blocking_portal()`. See the [anyio documentation](https://anyio.readthedocs.io/en/stable/basics.html#backend-options)
for more information about the accepted backend options.
By default, `asyncio` is used with default options.

To run `Trio`, set `async_backend["backend"] = "trio"`, for example:
To run `Trio`, pass `backend="trio"`. For example:

```python
def test_app()
client = TestClient(app)
client.async_backend["backend"] = "trio"
...
with TestClient(app, backend="trio") as client:
...
```

To run `asyncio` with `uvloop`, pass `backend_options={"use_uvloop": True}`. For example:

```python
def test_app()
with TestClient(app, backend_options={"use_uvloop": True}) as client:
...
```

### Testing WebSocket sessions
Expand Down
5 changes: 4 additions & 1 deletion setup.py
Expand Up @@ -37,7 +37,10 @@ def get_long_description():
packages=find_packages(exclude=["tests*"]),
package_data={"starlette": ["py.typed"]},
include_package_data=True,
install_requires=["anyio>=3.0.0,<4"],
install_requires=[
"anyio>=3.0.0,<4",
"typing_extensions; python_version < '3.8'",
],
extras_require={
"full": [
"graphene; python_version<'3.10'",
Expand Down
30 changes: 21 additions & 9 deletions starlette/testclient.py
Expand Up @@ -6,6 +6,7 @@
import json
import math
import queue
import sys
import types
import typing
from concurrent.futures import Future
Expand All @@ -18,6 +19,11 @@
from starlette.types import Message, Receive, Scope, Send
from starlette.websockets import WebSocketDisconnect

if sys.version_info >= (3, 8): # pragma: no cover
from typing import TypedDict
else: # pragma: no cover
from typing_extensions import TypedDict

# Annotations for `Session.request()`
Cookies = typing.Union[
typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar
Expand Down Expand Up @@ -91,11 +97,16 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await instance(receive, send)


class _AsyncBackend(TypedDict):
backend: str
backend_options: typing.Dict[str, typing.Any]


class _ASGIAdapter(requests.adapters.HTTPAdapter):
def __init__(
self,
app: ASGI3App,
async_backend: typing.Dict[str, typing.Any],
async_backend: _AsyncBackend,
raise_server_exceptions: bool = True,
root_path: str = "",
) -> None:
Expand Down Expand Up @@ -271,7 +282,10 @@ async def send(message: Message) -> None:

class WebSocketTestSession:
def __init__(
self, app: ASGI3App, scope: Scope, async_backend: typing.Dict[str, typing.Any]
self,
app: ASGI3App,
scope: Scope,
async_backend: _AsyncBackend,
) -> None:
self.app = app
self.scope = scope
Expand Down Expand Up @@ -381,13 +395,6 @@ def receive_json(self, mode: str = "text") -> typing.Any:

class TestClient(requests.Session):
__test__ = False # For pytest to not discover this up.

#: These options are passed to `anyio.start_blocking_portal()`
async_backend: typing.Dict[str, typing.Any] = {
"backend": "asyncio",
"backend_options": {},
}

task: "Future[None]"

def __init__(
Expand All @@ -396,8 +403,13 @@ def __init__(
base_url: str = "http://testserver",
raise_server_exceptions: bool = True,
root_path: str = "",
backend: str = "asyncio",
backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> None:
super().__init__()
self.async_backend = _AsyncBackend(
backend=backend, backend_options=backend_options or {}
)
if _is_asgi3(app):
app = typing.cast(ASGI3App, app)
asgi_app = app
Expand Down
29 changes: 13 additions & 16 deletions tests/conftest.py
@@ -1,3 +1,4 @@
import functools
import sys

import pytest
Expand All @@ -7,22 +8,18 @@
collect_ignore = ["test_graphql.py"] if sys.version_info >= (3, 10) else []


@pytest.fixture(
params=[
pytest.param(
{"backend": "asyncio", "backend_options": {"use_uvloop": False}},
id="asyncio",
),
pytest.param({"backend": "trio", "backend_options": {}}, id="trio"),
],
autouse=True,
)
def anyio_backend(request, monkeypatch):
monkeypatch.setattr(TestClient, "async_backend", request.param)
return request.param["backend"]
@pytest.fixture
def no_trio_support(anyio_backend_name):
if anyio_backend_name == "trio":
pytest.skip("Trio not supported (yet!)")


@pytest.fixture
def no_trio_support(request):
if request.keywords.get("trio"):
pytest.skip("Trio not supported (yet!)")
def test_client_factory(anyio_backend_name, anyio_backend_options):
# anyio_backend_name defined by:
# https://anyio.readthedocs.io/en/stable/testing.html#specifying-the-backends-to-run-on
return functools.partial(
TestClient,
backend=anyio_backend_name,
backend_options=anyio_backend_options,
)
21 changes: 10 additions & 11 deletions tests/middleware/test_base.py
Expand Up @@ -5,7 +5,6 @@
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse
from starlette.routing import Route
from starlette.testclient import TestClient


class CustomMiddleware(BaseHTTPMiddleware):
Expand Down Expand Up @@ -48,8 +47,8 @@ async def websocket_endpoint(session):
await session.close()


def test_custom_middleware():
client = TestClient(app)
def test_custom_middleware(test_client_factory):
client = test_client_factory(app)
response = client.get("/")
assert response.headers["Custom-Header"] == "Example"

Expand All @@ -64,7 +63,7 @@ def test_custom_middleware():
assert text == "Hello, world!"


def test_middleware_decorator():
def test_middleware_decorator(test_client_factory):
app = Starlette()

@app.route("/homepage")
Expand All @@ -79,7 +78,7 @@ async def plaintext(request, call_next):
response.headers["Custom"] = "Example"
return response

client = TestClient(app)
client = test_client_factory(app)
response = client.get("/")
assert response.text == "OK"

Expand All @@ -88,7 +87,7 @@ async def plaintext(request, call_next):
assert response.headers["Custom"] == "Example"


def test_state_data_across_multiple_middlewares():
def test_state_data_across_multiple_middlewares(test_client_factory):
expected_value1 = "foo"
expected_value2 = "bar"

Expand Down Expand Up @@ -120,22 +119,22 @@ async def dispatch(self, request, call_next):
def homepage(request):
return PlainTextResponse("OK")

client = TestClient(app)
client = test_client_factory(app)
response = client.get("/")
assert response.text == "OK"
assert response.headers["X-State-Foo"] == expected_value1
assert response.headers["X-State-Bar"] == expected_value2


def test_app_middleware_argument():
def test_app_middleware_argument(test_client_factory):
def homepage(request):
return PlainTextResponse("Homepage")

app = Starlette(
routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)]
)

client = TestClient(app)
client = test_client_factory(app)
response = client.get("/")
assert response.headers["Custom-Header"] == "Example"

Expand All @@ -145,7 +144,7 @@ def test_middleware_repr():
assert repr(middleware) == "Middleware(CustomMiddleware)"


def test_fully_evaluated_response():
def test_fully_evaluated_response(test_client_factory):
# Test for https://github.com/encode/starlette/issues/1022
class CustomMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
Expand All @@ -155,6 +154,6 @@ async def dispatch(self, request, call_next):
app = Starlette()
app.add_middleware(CustomMiddleware)

client = TestClient(app)
client = test_client_factory(app)
response = client.get("/does_not_exist")
assert response.text == "Custom"

0 comments on commit d222b87

Please sign in to comment.