Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TestClient accepts backend and backend_options as arguments to constructor #1211

Merged
merged 8 commits into from Jun 28, 2021
23 changes: 15 additions & 8 deletions docs/testclient.md
Expand Up @@ -33,18 +33,25 @@ case you should use `client = TestClient(app, raise_server_exceptions=False)`.

### Selecting the Async backend

`TestClient.async_backend` is a dictionary which allows you to set the options
for the backend used to run tests. These options are passed to
`anyio.start_blocking_portal()`. See the [anyio documentation](https://anyio.readthedocs.io/en/stable/basics.html#backend-options)
for more information about backend options. By default, `asyncio` is used.
`TestClient` takes arguments `backend` (a string) and `backend_options` (a dictionary).
These options are passed to `anyio.start_blocking_portal()`. See the [anyio documentation](https://anyio.readthedocs.io/en/stable/basics.html#backend-options)
for more information about the accepted backend options.
By default, `asyncio` is used with default options.

To run `Trio`, set `async_backend["backend"] = "trio"`, for example:
To run `Trio`, pass `backend="trio"`. For example:

```python
def test_app()
client = TestClient(app)
client.async_backend["backend"] = "trio"
...
with TestClient(app, backend="trio") as client:
...
```

To run `asyncio` with `uvloop`, pass `backend_options={"use_uvloop": True}`. For example:

```python
def test_app()
with TestClient(app, backend_options={"use_uvloop": True}) as client:
...
```

### Testing WebSocket sessions
Expand Down
5 changes: 4 additions & 1 deletion setup.py
Expand Up @@ -37,7 +37,10 @@ def get_long_description():
packages=find_packages(exclude=["tests*"]),
package_data={"starlette": ["py.typed"]},
include_package_data=True,
install_requires=["anyio>=3.0.0,<4"],
install_requires=[
"anyio>=3.0.0,<4",
"typing_extensions; python_version < '3.8'",
graingert marked this conversation as resolved.
Show resolved Hide resolved
],
extras_require={
"full": [
"graphene; python_version<'3.10'",
Expand Down
30 changes: 21 additions & 9 deletions starlette/testclient.py
Expand Up @@ -6,6 +6,7 @@
import json
import math
import queue
import sys
import types
import typing
from concurrent.futures import Future
Expand All @@ -18,6 +19,11 @@
from starlette.types import Message, Receive, Scope, Send
from starlette.websockets import WebSocketDisconnect

if sys.version_info >= (3, 8): # pragma: no cover
from typing import TypedDict # pragma: no cover
else: # pragma: no cover
from typing_extensions import TypedDict # pragma: no cover
graingert marked this conversation as resolved.
Show resolved Hide resolved

# Annotations for `Session.request()`
Cookies = typing.Union[
typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar
Expand Down Expand Up @@ -91,11 +97,16 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await instance(receive, send)


class _AsyncBackend(TypedDict):
graingert marked this conversation as resolved.
Show resolved Hide resolved
backend: str
backend_options: typing.Dict[str, typing.Any]


class _ASGIAdapter(requests.adapters.HTTPAdapter):
def __init__(
self,
app: ASGI3App,
async_backend: typing.Dict[str, typing.Any],
async_backend: _AsyncBackend,
raise_server_exceptions: bool = True,
root_path: str = "",
) -> None:
Expand Down Expand Up @@ -271,7 +282,10 @@ async def send(message: Message) -> None:

class WebSocketTestSession:
def __init__(
self, app: ASGI3App, scope: Scope, async_backend: typing.Dict[str, typing.Any]
self,
app: ASGI3App,
scope: Scope,
async_backend: _AsyncBackend,
) -> None:
self.app = app
self.scope = scope
Expand Down Expand Up @@ -381,13 +395,6 @@ def receive_json(self, mode: str = "text") -> typing.Any:

class TestClient(requests.Session):
__test__ = False # For pytest to not discover this up.

#: These options are passed to `anyio.start_blocking_portal()`
async_backend: typing.Dict[str, typing.Any] = {
"backend": "asyncio",
"backend_options": {},
}

task: "Future[None]"

def __init__(
Expand All @@ -396,8 +403,13 @@ def __init__(
base_url: str = "http://testserver",
raise_server_exceptions: bool = True,
root_path: str = "",
backend: str = "asyncio",
backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> None:
super().__init__()
self.async_backend = _AsyncBackend(
backend=backend, backend_options=backend_options or {}
)
if _is_asgi3(app):
app = typing.cast(ASGI3App, app)
asgi_app = app
Expand Down
27 changes: 11 additions & 16 deletions tests/conftest.py
@@ -1,3 +1,4 @@
import functools
import sys

import pytest
Expand All @@ -7,22 +8,16 @@
collect_ignore = ["test_graphql.py"] if sys.version_info >= (3, 10) else []


@pytest.fixture(
params=[
graingert marked this conversation as resolved.
Show resolved Hide resolved
pytest.param(
{"backend": "asyncio", "backend_options": {"use_uvloop": False}},
id="asyncio",
),
pytest.param({"backend": "trio", "backend_options": {}}, id="trio"),
],
autouse=True,
)
def anyio_backend(request, monkeypatch):
monkeypatch.setattr(TestClient, "async_backend", request.param)
return request.param["backend"]
@pytest.fixture
def no_trio_support(anyio_backend_name):
if anyio_backend_name == "trio":
pytest.skip("Trio not supported (yet!)")
graingert marked this conversation as resolved.
Show resolved Hide resolved


@pytest.fixture
def no_trio_support(request):
if request.keywords.get("trio"):
pytest.skip("Trio not supported (yet!)")
def test_client_factory(anyio_backend_name, anyio_backend_options):
graingert marked this conversation as resolved.
Show resolved Hide resolved
return functools.partial(
TestClient,
backend=anyio_backend_name,
backend_options=anyio_backend_options,
)
21 changes: 10 additions & 11 deletions tests/middleware/test_base.py
Expand Up @@ -5,7 +5,6 @@
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse
from starlette.routing import Route
from starlette.testclient import TestClient


class CustomMiddleware(BaseHTTPMiddleware):
Expand Down Expand Up @@ -48,8 +47,8 @@ async def websocket_endpoint(session):
await session.close()


def test_custom_middleware():
client = TestClient(app)
def test_custom_middleware(test_client_factory):
client = test_client_factory(app)
response = client.get("/")
assert response.headers["Custom-Header"] == "Example"

Expand All @@ -64,7 +63,7 @@ def test_custom_middleware():
assert text == "Hello, world!"


def test_middleware_decorator():
def test_middleware_decorator(test_client_factory):
app = Starlette()

@app.route("/homepage")
Expand All @@ -79,7 +78,7 @@ async def plaintext(request, call_next):
response.headers["Custom"] = "Example"
return response

client = TestClient(app)
client = test_client_factory(app)
response = client.get("/")
assert response.text == "OK"

Expand All @@ -88,7 +87,7 @@ async def plaintext(request, call_next):
assert response.headers["Custom"] == "Example"


def test_state_data_across_multiple_middlewares():
def test_state_data_across_multiple_middlewares(test_client_factory):
expected_value1 = "foo"
expected_value2 = "bar"

Expand Down Expand Up @@ -120,22 +119,22 @@ async def dispatch(self, request, call_next):
def homepage(request):
return PlainTextResponse("OK")

client = TestClient(app)
client = test_client_factory(app)
response = client.get("/")
assert response.text == "OK"
assert response.headers["X-State-Foo"] == expected_value1
assert response.headers["X-State-Bar"] == expected_value2


def test_app_middleware_argument():
def test_app_middleware_argument(test_client_factory):
def homepage(request):
return PlainTextResponse("Homepage")

app = Starlette(
routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)]
)

client = TestClient(app)
client = test_client_factory(app)
response = client.get("/")
assert response.headers["Custom-Header"] == "Example"

Expand All @@ -145,7 +144,7 @@ def test_middleware_repr():
assert repr(middleware) == "Middleware(CustomMiddleware)"


def test_fully_evaluated_response():
def test_fully_evaluated_response(test_client_factory):
# Test for https://github.com/encode/starlette/issues/1022
class CustomMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
Expand All @@ -155,6 +154,6 @@ async def dispatch(self, request, call_next):
app = Starlette()
app.add_middleware(CustomMiddleware)

client = TestClient(app)
client = test_client_factory(app)
response = client.get("/does_not_exist")
assert response.text == "Custom"