From a90343bcf1f0184ac03187130891ff72bd6c628a Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Wed, 23 Jun 2021 11:35:25 -0500 Subject: [PATCH 1/5] TestClient accepts backend and backend_options as arguments to constructor --- docs/testclient.md | 23 +++++++++++++++-------- starlette/testclient.py | 10 +++++++--- tests/test_testclient.py | 19 +++++++++++++++++++ 3 files changed, 41 insertions(+), 11 deletions(-) diff --git a/docs/testclient.md b/docs/testclient.md index f37858401..a1861efec 100644 --- a/docs/testclient.md +++ b/docs/testclient.md @@ -33,18 +33,25 @@ case you should use `client = TestClient(app, raise_server_exceptions=False)`. ### Selecting the Async backend -`TestClient.async_backend` is a dictionary which allows you to set the options -for the backend used to run tests. These options are passed to -`anyio.start_blocking_portal()`. See the [anyio documentation](https://anyio.readthedocs.io/en/stable/basics.html#backend-options) -for more information about backend options. By default, `asyncio` is used. +`TestClient` takes arguments `backend` (a string) and `backend_options` (a dictionary). +These options are passed to `anyio.start_blocking_portal()`. See the [anyio documentation](https://anyio.readthedocs.io/en/stable/basics.html#backend-options) +for more information about the accepted backend options. +By default, `asyncio` is used with default options. -To run `Trio`, set `async_backend["backend"] = "trio"`, for example: +To run `Trio`, pass `backend="trio"`. For example: ```python def test_app() - client = TestClient(app) - client.async_backend["backend"] = "trio" - ... + with TestClient(app, backend="trio") as client: + ... +``` + +To run `asyncio` with `uvloop`, pass `backend_options={"use_uvloop": True}`. For example: + +```python +def test_app() + with TestClient(app, backend_options={"use_uvloop": True}) as client: + ... ``` ### Testing WebSocket sessions diff --git a/starlette/testclient.py b/starlette/testclient.py index c1c0fe165..7cf58e057 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -380,13 +380,11 @@ def receive_json(self, mode: str = "text") -> typing.Any: class TestClient(requests.Session): __test__ = False # For pytest to not discover this up. - - #: These options are passed to `anyio.start_blocking_portal()` + #: These are the default options for the constructor arguments async_backend: typing.Dict[str, typing.Any] = { "backend": "asyncio", "backend_options": {}, } - task: "Future[None]" def __init__( @@ -395,8 +393,14 @@ def __init__( base_url: str = "http://testserver", raise_server_exceptions: bool = True, root_path: str = "", + backend: typing.Optional[str] = None, + backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> None: super().__init__() + self.async_backend = { + "backend": backend or self.async_backend["backend"], + "backend_options": backend_options or self.async_backend["backend_options"], + } if _is_asgi3(app): app = typing.cast(ASGI3App, app) asgi_app = app diff --git a/tests/test_testclient.py b/tests/test_testclient.py index 86f36e172..44e3320a4 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -132,3 +132,22 @@ async def asgi(receive, send): with client.websocket_connect("/") as websocket: data = websocket.receive_json() assert data == {"message": "test"} + + +def test_backend_name(request): + """ + Test that the tests are defaulting to the correct backend and that a new + instance of TestClient can be created using different backend options. + """ + # client created using monkeypatched async_backend + client1 = TestClient(mock_service) + if "trio" in request.keywords: + client2 = TestClient(mock_service, backend="asyncio") + assert client1.async_backend["backend"] == "trio" + assert client2.async_backend["backend"] == "asyncio" + elif "asyncio" in request.keywords: + client2 = TestClient(mock_service, backend="trio") + assert client1.async_backend["backend"] == "asyncio" + assert client2.async_backend["backend"] == "trio" + else: + pytest.fail("Unknown backend") # pragma: nocover From 57f06312d8ccd26a344c80516b1f9a988c3cccde Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 23 Jun 2021 21:02:59 +0100 Subject: [PATCH 2/5] avoid monkeypatching TestClient --- tests/conftest.py | 27 +++---- tests/middleware/test_base.py | 21 +++--- tests/middleware/test_cors.py | 67 +++++++++-------- tests/middleware/test_errors.py | 21 +++--- tests/middleware/test_gzip.py | 17 +++-- tests/middleware/test_https_redirect.py | 13 ++-- tests/middleware/test_session.py | 19 +++-- tests/middleware/test_trusted_host.py | 13 ++-- tests/middleware/test_wsgi.py | 19 +++-- tests/test_applications.py | 64 +++++++++-------- tests/test_authentication.py | 21 +++--- tests/test_background.py | 13 ++-- tests/test_database.py | 15 ++-- tests/test_endpoints.py | 41 ++++++----- tests/test_exceptions.py | 22 +++--- tests/test_formparsers.py | 61 ++++++++-------- tests/test_graphql.py | 44 ++++++------ tests/test_requests.py | 96 +++++++++++++------------ tests/test_responses.py | 79 ++++++++++---------- tests/test_routing.py | 80 +++++++++++---------- tests/test_schemas.py | 5 +- tests/test_staticfiles.py | 67 +++++++++-------- tests/test_templates.py | 5 +- tests/test_testclient.py | 62 ++++++---------- tests/test_websockets.py | 77 ++++++++++---------- 25 files changed, 483 insertions(+), 486 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 9ed420305..acaea4c87 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import functools import sys import pytest @@ -7,22 +8,16 @@ collect_ignore = ["test_graphql.py"] if sys.version_info >= (3, 10) else [] -@pytest.fixture( - params=[ - pytest.param( - {"backend": "asyncio", "backend_options": {"use_uvloop": False}}, - id="asyncio", - ), - pytest.param({"backend": "trio", "backend_options": {}}, id="trio"), - ], - autouse=True, -) -def anyio_backend(request, monkeypatch): - monkeypatch.setattr(TestClient, "async_backend", request.param) - return request.param["backend"] +@pytest.fixture +def no_trio_support(anyio_backend_name): + if anyio_backend_name == "trio": + pytest.skip("Trio not supported (yet!)") @pytest.fixture -def no_trio_support(request): - if request.keywords.get("trio"): - pytest.skip("Trio not supported (yet!)") +def test_client_factory(anyio_backend_name, anyio_backend_options): + return functools.partial( + TestClient, + backend=anyio_backend_name, + backend_options=anyio_backend_options, + ) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index df8901934..8a8df4ea6 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -5,7 +5,6 @@ from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import PlainTextResponse from starlette.routing import Route -from starlette.testclient import TestClient class CustomMiddleware(BaseHTTPMiddleware): @@ -48,8 +47,8 @@ async def websocket_endpoint(session): await session.close() -def test_custom_middleware(): - client = TestClient(app) +def test_custom_middleware(test_client_factory): + client = test_client_factory(app) response = client.get("/") assert response.headers["Custom-Header"] == "Example" @@ -64,7 +63,7 @@ def test_custom_middleware(): assert text == "Hello, world!" -def test_middleware_decorator(): +def test_middleware_decorator(test_client_factory): app = Starlette() @app.route("/homepage") @@ -79,7 +78,7 @@ async def plaintext(request, call_next): response.headers["Custom"] = "Example" return response - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "OK" @@ -88,7 +87,7 @@ async def plaintext(request, call_next): assert response.headers["Custom"] == "Example" -def test_state_data_across_multiple_middlewares(): +def test_state_data_across_multiple_middlewares(test_client_factory): expected_value1 = "foo" expected_value2 = "bar" @@ -120,14 +119,14 @@ async def dispatch(self, request, call_next): def homepage(request): return PlainTextResponse("OK") - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "OK" assert response.headers["X-State-Foo"] == expected_value1 assert response.headers["X-State-Bar"] == expected_value2 -def test_app_middleware_argument(): +def test_app_middleware_argument(test_client_factory): def homepage(request): return PlainTextResponse("Homepage") @@ -135,7 +134,7 @@ def homepage(request): routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)] ) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.headers["Custom-Header"] == "Example" @@ -145,7 +144,7 @@ def test_middleware_repr(): assert repr(middleware) == "Middleware(CustomMiddleware)" -def test_fully_evaluated_response(): +def test_fully_evaluated_response(test_client_factory): # Test for https://github.com/encode/starlette/issues/1022 class CustomMiddleware(BaseHTTPMiddleware): async def dispatch(self, request, call_next): @@ -155,6 +154,6 @@ async def dispatch(self, request, call_next): app = Starlette() app.add_middleware(CustomMiddleware) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/does_not_exist") assert response.text == "Custom" diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py index 266ebca5b..65252e502 100644 --- a/tests/middleware/test_cors.py +++ b/tests/middleware/test_cors.py @@ -1,10 +1,9 @@ from starlette.applications import Starlette from starlette.middleware.cors import CORSMiddleware from starlette.responses import PlainTextResponse -from starlette.testclient import TestClient -def test_cors_allow_all(): +def test_cors_allow_all(test_client_factory): app = Starlette() app.add_middleware( @@ -20,7 +19,7 @@ def test_cors_allow_all(): def homepage(request): return PlainTextResponse("Homepage", status_code=200) - client = TestClient(app) + client = test_client_factory(app) # Test pre-flight response headers = { @@ -61,7 +60,7 @@ def homepage(request): assert "access-control-allow-origin" not in response.headers -def test_cors_allow_all_except_credentials(): +def test_cors_allow_all_except_credentials(test_client_factory): app = Starlette() app.add_middleware( @@ -76,7 +75,7 @@ def test_cors_allow_all_except_credentials(): def homepage(request): return PlainTextResponse("Homepage", status_code=200) - client = TestClient(app) + client = test_client_factory(app) # Test pre-flight response headers = { @@ -108,7 +107,7 @@ def homepage(request): assert "access-control-allow-origin" not in response.headers -def test_cors_allow_specific_origin(): +def test_cors_allow_specific_origin(test_client_factory): app = Starlette() app.add_middleware( @@ -121,7 +120,7 @@ def test_cors_allow_specific_origin(): def homepage(request): return PlainTextResponse("Homepage", status_code=200) - client = TestClient(app) + client = test_client_factory(app) # Test pre-flight response headers = { @@ -153,7 +152,7 @@ def homepage(request): assert "access-control-allow-origin" not in response.headers -def test_cors_disallowed_preflight(): +def test_cors_disallowed_preflight(test_client_factory): app = Starlette() app.add_middleware( @@ -166,7 +165,7 @@ def test_cors_disallowed_preflight(): def homepage(request): pass # pragma: no cover - client = TestClient(app) + client = test_client_factory(app) # Test pre-flight response headers = { @@ -190,7 +189,9 @@ def homepage(request): assert response.text == "Disallowed CORS headers" -def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed(): +def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed( + test_client_factory, +): app = Starlette() app.add_middleware( @@ -204,7 +205,7 @@ def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_all def homepage(request): return # pragma: no cover - client = TestClient(app) + client = test_client_factory(app) # Test pre-flight response headers = { @@ -221,7 +222,7 @@ def homepage(request): assert response.headers["vary"] == "Origin" -def test_cors_preflight_allow_all_methods(): +def test_cors_preflight_allow_all_methods(test_client_factory): app = Starlette() app.add_middleware( @@ -234,7 +235,7 @@ def test_cors_preflight_allow_all_methods(): def homepage(request): pass # pragma: no cover - client = TestClient(app) + client = test_client_factory(app) headers = { "Origin": "https://example.org", @@ -247,7 +248,7 @@ def homepage(request): assert method in response.headers["access-control-allow-methods"] -def test_cors_allow_all_methods(): +def test_cors_allow_all_methods(test_client_factory): app = Starlette() app.add_middleware( @@ -262,7 +263,7 @@ def test_cors_allow_all_methods(): def homepage(request): return PlainTextResponse("Homepage", status_code=200) - client = TestClient(app) + client = test_client_factory(app) headers = {"Origin": "https://example.org"} @@ -271,7 +272,7 @@ def homepage(request): assert response.status_code == 200 -def test_cors_allow_origin_regex(): +def test_cors_allow_origin_regex(test_client_factory): app = Starlette() app.add_middleware( @@ -285,7 +286,7 @@ def test_cors_allow_origin_regex(): def homepage(request): return PlainTextResponse("Homepage", status_code=200) - client = TestClient(app) + client = test_client_factory(app) # Test standard response headers = {"Origin": "https://example.org"} @@ -339,7 +340,7 @@ def homepage(request): assert "access-control-allow-origin" not in response.headers -def test_cors_allow_origin_regex_fullmatch(): +def test_cors_allow_origin_regex_fullmatch(test_client_factory): app = Starlette() app.add_middleware( @@ -352,7 +353,7 @@ def test_cors_allow_origin_regex_fullmatch(): def homepage(request): return PlainTextResponse("Homepage", status_code=200) - client = TestClient(app) + client = test_client_factory(app) # Test standard response headers = {"Origin": "https://subdomain.example.org"} @@ -373,7 +374,7 @@ def homepage(request): assert "access-control-allow-origin" not in response.headers -def test_cors_credentialed_requests_return_specific_origin(): +def test_cors_credentialed_requests_return_specific_origin(test_client_factory): app = Starlette() app.add_middleware(CORSMiddleware, allow_origins=["*"]) @@ -382,7 +383,7 @@ def test_cors_credentialed_requests_return_specific_origin(): def homepage(request): return PlainTextResponse("Homepage", status_code=200) - client = TestClient(app) + client = test_client_factory(app) # Test credentialed request headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"} @@ -393,7 +394,7 @@ def homepage(request): assert "access-control-allow-credentials" not in response.headers -def test_cors_vary_header_defaults_to_origin(): +def test_cors_vary_header_defaults_to_origin(test_client_factory): app = Starlette() app.add_middleware(CORSMiddleware, allow_origins=["https://example.org"]) @@ -404,14 +405,14 @@ def test_cors_vary_header_defaults_to_origin(): def homepage(request): return PlainTextResponse("Homepage", status_code=200) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers=headers) assert response.status_code == 200 assert response.headers["vary"] == "Origin" -def test_cors_vary_header_is_not_set_for_non_credentialed_request(): +def test_cors_vary_header_is_not_set_for_non_credentialed_request(test_client_factory): app = Starlette() app.add_middleware(CORSMiddleware, allow_origins=["*"]) @@ -422,14 +423,14 @@ def homepage(request): "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"} ) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"Origin": "https://someplace.org"}) assert response.status_code == 200 assert response.headers["vary"] == "Accept-Encoding" -def test_cors_vary_header_is_properly_set_for_credentialed_request(): +def test_cors_vary_header_is_properly_set_for_credentialed_request(test_client_factory): app = Starlette() app.add_middleware(CORSMiddleware, allow_origins=["*"]) @@ -440,7 +441,7 @@ def homepage(request): "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"} ) - client = TestClient(app) + client = test_client_factory(app) response = client.get( "/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"} @@ -449,7 +450,9 @@ def homepage(request): assert response.headers["vary"] == "Accept-Encoding, Origin" -def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard(): +def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard( + test_client_factory, +): app = Starlette() app.add_middleware(CORSMiddleware, allow_origins=["https://example.org"]) @@ -460,14 +463,16 @@ def homepage(request): "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"} ) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"Origin": "https://example.org"}) assert response.status_code == 200 assert response.headers["vary"] == "Accept-Encoding, Origin" -def test_cors_allowed_origin_does_not_leak_between_credentialed_requests(): +def test_cors_allowed_origin_does_not_leak_between_credentialed_requests( + test_client_factory, +): app = Starlette() app.add_middleware( @@ -478,7 +483,7 @@ def test_cors_allowed_origin_does_not_leak_between_credentialed_requests(): def homepage(request): return PlainTextResponse("Homepage", status_code=200) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"Origin": "https://someplace.org"}) assert response.headers["access-control-allow-origin"] == "*" assert "access-control-allow-credentials" not in response.headers diff --git a/tests/middleware/test_errors.py b/tests/middleware/test_errors.py index 28b2a7ba3..2c926a9b2 100644 --- a/tests/middleware/test_errors.py +++ b/tests/middleware/test_errors.py @@ -2,10 +2,9 @@ from starlette.middleware.errors import ServerErrorMiddleware from starlette.responses import JSONResponse, Response -from starlette.testclient import TestClient -def test_handler(): +def test_handler(test_client_factory): async def app(scope, receive, send): raise RuntimeError("Something went wrong") @@ -13,49 +12,49 @@ def error_500(request, exc): return JSONResponse({"detail": "Server Error"}, status_code=500) app = ServerErrorMiddleware(app, handler=error_500) - client = TestClient(app, raise_server_exceptions=False) + client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 500 assert response.json() == {"detail": "Server Error"} -def test_debug_text(): +def test_debug_text(test_client_factory): async def app(scope, receive, send): raise RuntimeError("Something went wrong") app = ServerErrorMiddleware(app, debug=True) - client = TestClient(app, raise_server_exceptions=False) + client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 500 assert response.headers["content-type"].startswith("text/plain") assert "RuntimeError: Something went wrong" in response.text -def test_debug_html(): +def test_debug_html(test_client_factory): async def app(scope, receive, send): raise RuntimeError("Something went wrong") app = ServerErrorMiddleware(app, debug=True) - client = TestClient(app, raise_server_exceptions=False) + client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/", headers={"Accept": "text/html, */*"}) assert response.status_code == 500 assert response.headers["content-type"].startswith("text/html") assert "RuntimeError" in response.text -def test_debug_after_response_sent(): +def test_debug_after_response_sent(test_client_factory): async def app(scope, receive, send): response = Response(b"", status_code=204) await response(scope, receive, send) raise RuntimeError("Something went wrong") app = ServerErrorMiddleware(app, debug=True) - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(RuntimeError): client.get("/") -def test_debug_not_http(): +def test_debug_not_http(test_client_factory): """ DebugMiddleware should just pass through any non-http messages as-is. """ @@ -66,6 +65,6 @@ async def app(scope, receive, send): app = ServerErrorMiddleware(app) with pytest.raises(RuntimeError): - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/"): pass # pragma: nocover diff --git a/tests/middleware/test_gzip.py b/tests/middleware/test_gzip.py index cd989b8c1..b917ea4db 100644 --- a/tests/middleware/test_gzip.py +++ b/tests/middleware/test_gzip.py @@ -1,10 +1,9 @@ from starlette.applications import Starlette from starlette.middleware.gzip import GZipMiddleware from starlette.responses import PlainTextResponse, StreamingResponse -from starlette.testclient import TestClient -def test_gzip_responses(): +def test_gzip_responses(test_client_factory): app = Starlette() app.add_middleware(GZipMiddleware) @@ -13,7 +12,7 @@ def test_gzip_responses(): def homepage(request): return PlainTextResponse("x" * 4000, status_code=200) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "gzip"}) assert response.status_code == 200 assert response.text == "x" * 4000 @@ -21,7 +20,7 @@ def homepage(request): assert int(response.headers["Content-Length"]) < 4000 -def test_gzip_not_in_accept_encoding(): +def test_gzip_not_in_accept_encoding(test_client_factory): app = Starlette() app.add_middleware(GZipMiddleware) @@ -30,7 +29,7 @@ def test_gzip_not_in_accept_encoding(): def homepage(request): return PlainTextResponse("x" * 4000, status_code=200) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "identity"}) assert response.status_code == 200 assert response.text == "x" * 4000 @@ -38,7 +37,7 @@ def homepage(request): assert int(response.headers["Content-Length"]) == 4000 -def test_gzip_ignored_for_small_responses(): +def test_gzip_ignored_for_small_responses(test_client_factory): app = Starlette() app.add_middleware(GZipMiddleware) @@ -47,7 +46,7 @@ def test_gzip_ignored_for_small_responses(): def homepage(request): return PlainTextResponse("OK", status_code=200) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "gzip"}) assert response.status_code == 200 assert response.text == "OK" @@ -55,7 +54,7 @@ def homepage(request): assert int(response.headers["Content-Length"]) == 2 -def test_gzip_streaming_response(): +def test_gzip_streaming_response(test_client_factory): app = Starlette() app.add_middleware(GZipMiddleware) @@ -69,7 +68,7 @@ async def generator(bytes, count): streaming = generator(bytes=b"x" * 400, count=10) return StreamingResponse(streaming, status_code=200) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "gzip"}) assert response.status_code == 200 assert response.text == "x" * 4000 diff --git a/tests/middleware/test_https_redirect.py b/tests/middleware/test_https_redirect.py index 757770b85..8db950634 100644 --- a/tests/middleware/test_https_redirect.py +++ b/tests/middleware/test_https_redirect.py @@ -1,10 +1,9 @@ from starlette.applications import Starlette from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware from starlette.responses import PlainTextResponse -from starlette.testclient import TestClient -def test_https_redirect_middleware(): +def test_https_redirect_middleware(test_client_factory): app = Starlette() app.add_middleware(HTTPSRedirectMiddleware) @@ -13,26 +12,26 @@ def test_https_redirect_middleware(): def homepage(request): return PlainTextResponse("OK", status_code=200) - client = TestClient(app, base_url="https://testserver") + client = test_client_factory(app, base_url="https://testserver") response = client.get("/") assert response.status_code == 200 - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", allow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "https://testserver/" - client = TestClient(app, base_url="http://testserver:80") + client = test_client_factory(app, base_url="http://testserver:80") response = client.get("/", allow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "https://testserver/" - client = TestClient(app, base_url="http://testserver:443") + client = test_client_factory(app, base_url="http://testserver:443") response = client.get("/", allow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "https://testserver/" - client = TestClient(app, base_url="http://testserver:123") + client = test_client_factory(app, base_url="http://testserver:123") response = client.get("/", allow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "https://testserver:123/" diff --git a/tests/middleware/test_session.py b/tests/middleware/test_session.py index 68cf36df9..314f2be58 100644 --- a/tests/middleware/test_session.py +++ b/tests/middleware/test_session.py @@ -3,7 +3,6 @@ from starlette.applications import Starlette from starlette.middleware.sessions import SessionMiddleware from starlette.responses import JSONResponse -from starlette.testclient import TestClient def view_session(request): @@ -29,10 +28,10 @@ def create_app(): return app -def test_session(): +def test_session(test_client_factory): app = create_app() app.add_middleware(SessionMiddleware, secret_key="example") - client = TestClient(app) + client = test_client_factory(app) response = client.get("/view_session") assert response.json() == {"session": {}} @@ -56,10 +55,10 @@ def test_session(): assert response.json() == {"session": {}} -def test_session_expires(): +def test_session_expires(test_client_factory): app = create_app() app.add_middleware(SessionMiddleware, secret_key="example", max_age=-1) - client = TestClient(app) + client = test_client_factory(app) response = client.post("/update_session", json={"some": "data"}) assert response.json() == {"session": {"some": "data"}} @@ -72,11 +71,11 @@ def test_session_expires(): assert response.json() == {"session": {}} -def test_secure_session(): +def test_secure_session(test_client_factory): app = create_app() app.add_middleware(SessionMiddleware, secret_key="example", https_only=True) - secure_client = TestClient(app, base_url="https://testserver") - unsecure_client = TestClient(app, base_url="http://testserver") + secure_client = test_client_factory(app, base_url="https://testserver") + unsecure_client = test_client_factory(app, base_url="http://testserver") response = unsecure_client.get("/view_session") assert response.json() == {"session": {}} @@ -103,12 +102,12 @@ def test_secure_session(): assert response.json() == {"session": {}} -def test_session_cookie_subpath(): +def test_session_cookie_subpath(test_client_factory): app = create_app() second_app = create_app() second_app.add_middleware(SessionMiddleware, secret_key="example") app.mount("/second_app", second_app) - client = TestClient(app, base_url="http://testserver/second_app") + client = test_client_factory(app, base_url="http://testserver/second_app") response = client.post("second_app/update_session", json={"some": "data"}) cookie = response.headers["set-cookie"] cookie_path = re.search(r"; path=(\S+);", cookie).groups()[0] diff --git a/tests/middleware/test_trusted_host.py b/tests/middleware/test_trusted_host.py index 934f2477b..de9c79e66 100644 --- a/tests/middleware/test_trusted_host.py +++ b/tests/middleware/test_trusted_host.py @@ -1,10 +1,9 @@ from starlette.applications import Starlette from starlette.middleware.trustedhost import TrustedHostMiddleware from starlette.responses import PlainTextResponse -from starlette.testclient import TestClient -def test_trusted_host_middleware(): +def test_trusted_host_middleware(test_client_factory): app = Starlette() app.add_middleware( @@ -15,15 +14,15 @@ def test_trusted_host_middleware(): def homepage(request): return PlainTextResponse("OK", status_code=200) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.status_code == 200 - client = TestClient(app, base_url="http://subdomain.testserver") + client = test_client_factory(app, base_url="http://subdomain.testserver") response = client.get("/") assert response.status_code == 200 - client = TestClient(app, base_url="http://invalidhost") + client = test_client_factory(app, base_url="http://invalidhost") response = client.get("/") assert response.status_code == 400 @@ -34,7 +33,7 @@ def test_default_allowed_hosts(): assert middleware.allowed_hosts == ["*"] -def test_www_redirect(): +def test_www_redirect(test_client_factory): app = Starlette() app.add_middleware(TrustedHostMiddleware, allowed_hosts=["www.example.com"]) @@ -43,7 +42,7 @@ def test_www_redirect(): def homepage(request): return PlainTextResponse("OK", status_code=200) - client = TestClient(app, base_url="https://example.com") + client = test_client_factory(app, base_url="https://example.com") response = client.get("/") assert response.status_code == 200 assert response.url == "https://www.example.com/" diff --git a/tests/middleware/test_wsgi.py b/tests/middleware/test_wsgi.py index 615805a94..bcb4cd6ff 100644 --- a/tests/middleware/test_wsgi.py +++ b/tests/middleware/test_wsgi.py @@ -3,7 +3,6 @@ import pytest from starlette.middleware.wsgi import WSGIMiddleware, build_environ -from starlette.testclient import TestClient def hello_world(environ, start_response): @@ -46,41 +45,41 @@ def return_exc_info(environ, start_response): return [output] -def test_wsgi_get(): +def test_wsgi_get(test_client_factory): app = WSGIMiddleware(hello_world) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.status_code == 200 assert response.text == "Hello World!\n" -def test_wsgi_post(): +def test_wsgi_post(test_client_factory): app = WSGIMiddleware(echo_body) - client = TestClient(app) + client = test_client_factory(app) response = client.post("/", json={"example": 123}) assert response.status_code == 200 assert response.text == '{"example": 123}' -def test_wsgi_exception(): +def test_wsgi_exception(test_client_factory): # Note that we're testing the WSGI app directly here. # The HTTP protocol implementations would catch this error and return 500. app = WSGIMiddleware(raise_exception) - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(RuntimeError): client.get("/") -def test_wsgi_exc_info(): +def test_wsgi_exc_info(test_client_factory): # Note that we're testing the WSGI app directly here. # The HTTP protocol implementations would catch this error and return 500. app = WSGIMiddleware(return_exc_info) - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(RuntimeError): response = client.get("/") app = WSGIMiddleware(return_exc_info) - client = TestClient(app, raise_server_exceptions=False) + client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 500 assert response.text == "Internal Server Error" diff --git a/tests/test_applications.py b/tests/test_applications.py index ad8504cbd..6cb490696 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -1,5 +1,7 @@ import os +import pytest + from starlette.applications import Starlette from starlette.endpoints import HTTPEndpoint from starlette.exceptions import HTTPException @@ -7,7 +9,6 @@ from starlette.responses import JSONResponse, PlainTextResponse from starlette.routing import Host, Mount, Route, Router, WebSocketRoute from starlette.staticfiles import StaticFiles -from starlette.testclient import TestClient app = Starlette() @@ -86,14 +87,17 @@ async def websocket_endpoint(session): await session.close() -client = TestClient(app) +@pytest.fixture +def client(test_client_factory): + with test_client_factory(app) as client: + yield client def test_url_path_for(): assert app.url_path_for("func_homepage") == "/func" -def test_func_route(): +def test_func_route(client): response = client.get("/func") assert response.status_code == 200 assert response.text == "Hello, world!" @@ -103,51 +107,51 @@ def test_func_route(): assert response.text == "" -def test_async_route(): +def test_async_route(client): response = client.get("/async") assert response.status_code == 200 assert response.text == "Hello, world!" -def test_class_route(): +def test_class_route(client): response = client.get("/class") assert response.status_code == 200 assert response.text == "Hello, world!" -def test_mounted_route(): +def test_mounted_route(client): response = client.get("/users/") assert response.status_code == 200 assert response.text == "Hello, everyone!" -def test_mounted_route_path_params(): +def test_mounted_route_path_params(client): response = client.get("/users/tomchristie") assert response.status_code == 200 assert response.text == "Hello, tomchristie!" -def test_subdomain_route(): - client = TestClient(app, base_url="https://foo.example.org/") +def test_subdomain_route(test_client_factory): + client = test_client_factory(app, base_url="https://foo.example.org/") response = client.get("/") assert response.status_code == 200 assert response.text == "Subdomain: foo" -def test_websocket_route(): +def test_websocket_route(client): with client.websocket_connect("/ws") as session: text = session.receive_text() assert text == "Hello, world!" -def test_400(): +def test_400(client): response = client.get("/404") assert response.status_code == 404 assert response.json() == {"detail": "Not Found"} -def test_405(): +def test_405(client): response = client.post("/func") assert response.status_code == 405 assert response.json() == {"detail": "Custom message"} @@ -157,15 +161,15 @@ def test_405(): assert response.json() == {"detail": "Custom message"} -def test_500(): - client = TestClient(app, raise_server_exceptions=False) +def test_500(test_client_factory): + client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/500") assert response.status_code == 500 assert response.json() == {"detail": "Server Error"} -def test_middleware(): - client = TestClient(app, base_url="http://incorrecthost") +def test_middleware(test_client_factory): + client = test_client_factory(app, base_url="http://incorrecthost") response = client.get("/func") assert response.status_code == 400 assert response.text == "Invalid host header" @@ -194,7 +198,7 @@ def test_routes(): ] -def test_app_mount(tmpdir): +def test_app_mount(tmpdir, test_client_factory): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") @@ -202,7 +206,7 @@ def test_app_mount(tmpdir): app = Starlette() app.mount("/static", StaticFiles(directory=tmpdir)) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/static/example.txt") assert response.status_code == 200 @@ -213,7 +217,7 @@ def test_app_mount(tmpdir): assert response.text == "Method Not Allowed" -def test_app_debug(): +def test_app_debug(test_client_factory): app = Starlette() app.debug = True @@ -221,27 +225,27 @@ def test_app_debug(): async def homepage(request): raise RuntimeError() - client = TestClient(app, raise_server_exceptions=False) + client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 500 assert "RuntimeError" in response.text assert app.debug -def test_app_add_route(): +def test_app_add_route(test_client_factory): app = Starlette() async def homepage(request): return PlainTextResponse("Hello, World!") app.add_route("/", homepage) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.status_code == 200 assert response.text == "Hello, World!" -def test_app_add_websocket_route(): +def test_app_add_websocket_route(test_client_factory): app = Starlette() async def websocket_endpoint(session): @@ -250,14 +254,14 @@ async def websocket_endpoint(session): await session.close() app.add_websocket_route("/ws", websocket_endpoint) - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/ws") as session: text = session.receive_text() assert text == "Hello, world!" -def test_app_add_event_handler(): +def test_app_add_event_handler(test_client_factory): startup_complete = False cleanup_complete = False app = Starlette() @@ -275,14 +279,14 @@ def run_cleanup(): assert not startup_complete assert not cleanup_complete - with TestClient(app): + with test_client_factory(app): assert startup_complete assert not cleanup_complete assert startup_complete assert cleanup_complete -def test_app_async_lifespan(): +def test_app_async_lifespan(test_client_factory): startup_complete = False cleanup_complete = False @@ -296,14 +300,14 @@ async def lifespan(app): assert not startup_complete assert not cleanup_complete - with TestClient(app): + with test_client_factory(app): assert startup_complete assert not cleanup_complete assert startup_complete assert cleanup_complete -def test_app_sync_lifespan(): +def test_app_sync_lifespan(test_client_factory): startup_complete = False cleanup_complete = False @@ -317,7 +321,7 @@ def lifespan(app): assert not startup_complete assert not cleanup_complete - with TestClient(app): + with test_client_factory(app): assert startup_complete assert not cleanup_complete assert startup_complete diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 8ee87932a..43c7ab96d 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -15,7 +15,6 @@ from starlette.middleware.authentication import AuthenticationMiddleware from starlette.requests import Request from starlette.responses import JSONResponse -from starlette.testclient import TestClient from starlette.websockets import WebSocketDisconnect @@ -195,8 +194,8 @@ def foo(): pass # pragma: nocover -def test_user_interface(): - with TestClient(app) as client: +def test_user_interface(test_client_factory): + with test_client_factory(app) as client: response = client.get("/") assert response.status_code == 200 assert response.json() == {"authenticated": False, "user": ""} @@ -206,8 +205,8 @@ def test_user_interface(): assert response.json() == {"authenticated": True, "user": "tomchristie"} -def test_authentication_required(): - with TestClient(app) as client: +def test_authentication_required(test_client_factory): + with test_client_factory(app) as client: response = client.get("/dashboard") assert response.status_code == 403 @@ -258,8 +257,8 @@ def test_authentication_required(): assert response.text == "Invalid basic auth credentials" -def test_websocket_authentication_required(): - with TestClient(app) as client: +def test_websocket_authentication_required(test_client_factory): + with test_client_factory(app) as client: with pytest.raises(WebSocketDisconnect): with client.websocket_connect("/ws"): pass # pragma: nocover @@ -297,8 +296,8 @@ def test_websocket_authentication_required(): } -def test_authentication_redirect(): - with TestClient(app) as client: +def test_authentication_redirect(test_client_factory): + with test_client_factory(app) as client: response = client.get("/admin") assert response.status_code == 200 assert response.url == "http://testserver/" @@ -337,8 +336,8 @@ def control_panel(request): ) -def test_custom_on_error(): - with TestClient(other_app) as client: +def test_custom_on_error(test_client_factory): + with test_client_factory(other_app) as client: response = client.get("/control-panel", auth=("tomchristie", "example")) assert response.status_code == 200 assert response.json() == {"authenticated": True, "user": "tomchristie"} diff --git a/tests/test_background.py b/tests/test_background.py index d9d7ddd87..e299ec362 100644 --- a/tests/test_background.py +++ b/tests/test_background.py @@ -1,9 +1,8 @@ from starlette.background import BackgroundTask, BackgroundTasks from starlette.responses import Response -from starlette.testclient import TestClient -def test_async_task(): +def test_async_task(test_client_factory): TASK_COMPLETE = False async def async_task(): @@ -16,13 +15,13 @@ async def app(scope, receive, send): response = Response("task initiated", media_type="text/plain", background=task) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "task initiated" assert TASK_COMPLETE -def test_sync_task(): +def test_sync_task(test_client_factory): TASK_COMPLETE = False def sync_task(): @@ -35,13 +34,13 @@ async def app(scope, receive, send): response = Response("task initiated", media_type="text/plain", background=task) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "task initiated" assert TASK_COMPLETE -def test_multiple_tasks(): +def test_multiple_tasks(test_client_factory): TASK_COUNTER = 0 def increment(amount): @@ -58,7 +57,7 @@ async def app(scope, receive, send): ) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "tasks initiated" assert TASK_COUNTER == 1 + 2 + 3 diff --git a/tests/test_database.py b/tests/test_database.py index f7280c2c7..1230fc8f6 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -4,7 +4,6 @@ from starlette.applications import Starlette from starlette.responses import JSONResponse -from starlette.testclient import TestClient DATABASE_URL = "sqlite:///test.db" @@ -90,8 +89,8 @@ async def read_note_text(request): return JSONResponse(result[0]) -def test_database(): - with TestClient(app) as client: +def test_database(test_client_factory): + with test_client_factory(app) as client: response = client.post( "/notes", json={"text": "buy the milk", "completed": True} ) @@ -125,8 +124,8 @@ def test_database(): assert response.json() == "buy the milk" -def test_database_execute_many(): - with TestClient(app) as client: +def test_database_execute_many(test_client_factory): + with test_client_factory(app) as client: response = client.get("/notes") data = [ @@ -144,11 +143,11 @@ def test_database_execute_many(): ] -def test_database_isolated_during_test_cases(): +def test_database_isolated_during_test_cases(test_client_factory): """ Using `TestClient` as a context manager """ - with TestClient(app) as client: + with test_client_factory(app) as client: response = client.post( "/notes", json={"text": "just one note", "completed": True} ) @@ -158,7 +157,7 @@ def test_database_isolated_during_test_cases(): assert response.status_code == 200 assert response.json() == [{"text": "just one note", "completed": True}] - with TestClient(app) as client: + with test_client_factory(app) as client: response = client.post( "/notes", json={"text": "just one note", "completed": True} ) diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index e491c085f..e57d47486 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -3,7 +3,6 @@ from starlette.endpoints import HTTPEndpoint, WebSocketEndpoint from starlette.responses import PlainTextResponse from starlette.routing import Route, Router -from starlette.testclient import TestClient class Homepage(HTTPEndpoint): @@ -18,46 +17,50 @@ async def get(self, request): routes=[Route("/", endpoint=Homepage), Route("/{username}", endpoint=Homepage)] ) -client = TestClient(app) +@pytest.fixture +def client(test_client_factory): + with test_client_factory(app) as client: + yield client -def test_http_endpoint_route(): + +def test_http_endpoint_route(client): response = client.get("/") assert response.status_code == 200 assert response.text == "Hello, world!" -def test_http_endpoint_route_path_params(): +def test_http_endpoint_route_path_params(client): response = client.get("/tomchristie") assert response.status_code == 200 assert response.text == "Hello, tomchristie!" -def test_http_endpoint_route_method(): +def test_http_endpoint_route_method(client): response = client.post("/") assert response.status_code == 405 assert response.text == "Method Not Allowed" -def test_websocket_endpoint_on_connect(): +def test_websocket_endpoint_on_connect(test_client_factory): class WebSocketApp(WebSocketEndpoint): async def on_connect(self, websocket): assert websocket["subprotocols"] == ["soap", "wamp"] await websocket.accept(subprotocol="wamp") - client = TestClient(WebSocketApp) + client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws", subprotocols=["soap", "wamp"]) as websocket: assert websocket.accepted_subprotocol == "wamp" -def test_websocket_endpoint_on_receive_bytes(): +def test_websocket_endpoint_on_receive_bytes(test_client_factory): class WebSocketApp(WebSocketEndpoint): encoding = "bytes" async def on_receive(self, websocket, data): await websocket.send_bytes(b"Message bytes was: " + data) - client = TestClient(WebSocketApp) + client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.send_bytes(b"Hello, world!") _bytes = websocket.receive_bytes() @@ -68,14 +71,14 @@ async def on_receive(self, websocket, data): websocket.send_text("Hello world") -def test_websocket_endpoint_on_receive_json(): +def test_websocket_endpoint_on_receive_json(test_client_factory): class WebSocketApp(WebSocketEndpoint): encoding = "json" async def on_receive(self, websocket, data): await websocket.send_json({"message": data}) - client = TestClient(WebSocketApp) + client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.send_json({"hello": "world"}) data = websocket.receive_json() @@ -86,28 +89,28 @@ async def on_receive(self, websocket, data): websocket.send_text("Hello world") -def test_websocket_endpoint_on_receive_json_binary(): +def test_websocket_endpoint_on_receive_json_binary(test_client_factory): class WebSocketApp(WebSocketEndpoint): encoding = "json" async def on_receive(self, websocket, data): await websocket.send_json({"message": data}, mode="binary") - client = TestClient(WebSocketApp) + client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.send_json({"hello": "world"}, mode="binary") data = websocket.receive_json(mode="binary") assert data == {"message": {"hello": "world"}} -def test_websocket_endpoint_on_receive_text(): +def test_websocket_endpoint_on_receive_text(test_client_factory): class WebSocketApp(WebSocketEndpoint): encoding = "text" async def on_receive(self, websocket, data): await websocket.send_text(f"Message text was: {data}") - client = TestClient(WebSocketApp) + client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.send_text("Hello, world!") _text = websocket.receive_text() @@ -118,26 +121,26 @@ async def on_receive(self, websocket, data): websocket.send_bytes(b"Hello world") -def test_websocket_endpoint_on_default(): +def test_websocket_endpoint_on_default(test_client_factory): class WebSocketApp(WebSocketEndpoint): encoding = None async def on_receive(self, websocket, data): await websocket.send_text(f"Message text was: {data}") - client = TestClient(WebSocketApp) + client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.send_text("Hello, world!") _text = websocket.receive_text() assert _text == "Message text was: Hello, world!" -def test_websocket_endpoint_on_disconnect(): +def test_websocket_endpoint_on_disconnect(test_client_factory): class WebSocketApp(WebSocketEndpoint): async def on_disconnect(self, websocket, close_code): assert close_code == 1001 await websocket.close(code=close_code) - client = TestClient(WebSocketApp) + client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.close(code=1001) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index bab6961b5..5fba9981b 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -3,7 +3,6 @@ from starlette.exceptions import ExceptionMiddleware, HTTPException from starlette.responses import PlainTextResponse from starlette.routing import Route, Router, WebSocketRoute -from starlette.testclient import TestClient def raise_runtime_error(request): @@ -37,28 +36,33 @@ async def __call__(self, scope, receive, send): app = ExceptionMiddleware(router) -client = TestClient(app) -def test_not_acceptable(): +@pytest.fixture +def client(test_client_factory): + with test_client_factory(app) as client: + yield client + + +def test_not_acceptable(client): response = client.get("/not_acceptable") assert response.status_code == 406 assert response.text == "Not Acceptable" -def test_not_modified(): +def test_not_modified(client): response = client.get("/not_modified") assert response.status_code == 304 assert response.text == "" -def test_websockets_should_raise(): +def test_websockets_should_raise(client): with pytest.raises(RuntimeError): with client.websocket_connect("/runtime_error"): pass # pragma: nocover -def test_handled_exc_after_response(): +def test_handled_exc_after_response(test_client_factory, client): # A 406 HttpException is raised *after* the response has already been sent. # The exception middleware should raise a RuntimeError. with pytest.raises(RuntimeError): @@ -66,17 +70,17 @@ def test_handled_exc_after_response(): # If `raise_server_exceptions=False` then the test client will still allow # us to see the response as it will have been seen by the client. - allow_200_client = TestClient(app, raise_server_exceptions=False) + allow_200_client = test_client_factory(app, raise_server_exceptions=False) response = allow_200_client.get("/handled_exc_after_response") assert response.status_code == 200 assert response.text == "OK" -def test_force_500_response(): +def test_force_500_response(test_client_factory): def app(scope): raise RuntimeError() - force_500_client = TestClient(app, raise_server_exceptions=False) + force_500_client = test_client_factory(app, raise_server_exceptions=False) response = force_500_client.get("/") assert response.status_code == 500 assert response.text == "" diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index 73a720fd1..8a1174e1d 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -3,7 +3,6 @@ from starlette.formparsers import UploadFile, _user_safe_decode from starlette.requests import Request from starlette.responses import JSONResponse -from starlette.testclient import TestClient class ForceMultipartDict(dict): @@ -70,18 +69,18 @@ async def app_read_body(scope, receive, send): await response(scope, receive, send) -def test_multipart_request_data(tmpdir): - client = TestClient(app) +def test_multipart_request_data(tmpdir, test_client_factory): + client = test_client_factory(app) response = client.post("/", data={"some": "data"}, files=FORCE_MULTIPART) assert response.json() == {"some": "data"} -def test_multipart_request_files(tmpdir): +def test_multipart_request_files(tmpdir, test_client_factory): path = os.path.join(tmpdir, "test.txt") with open(path, "wb") as file: file.write(b"") - client = TestClient(app) + client = test_client_factory(app) with open(path, "rb") as f: response = client.post("/", files={"test": f}) assert response.json() == { @@ -93,12 +92,12 @@ def test_multipart_request_files(tmpdir): } -def test_multipart_request_files_with_content_type(tmpdir): +def test_multipart_request_files_with_content_type(tmpdir, test_client_factory): path = os.path.join(tmpdir, "test.txt") with open(path, "wb") as file: file.write(b"") - client = TestClient(app) + client = test_client_factory(app) with open(path, "rb") as f: response = client.post("/", files={"test": ("test.txt", f, "text/plain")}) assert response.json() == { @@ -110,7 +109,7 @@ def test_multipart_request_files_with_content_type(tmpdir): } -def test_multipart_request_multiple_files(tmpdir): +def test_multipart_request_multiple_files(tmpdir, test_client_factory): path1 = os.path.join(tmpdir, "test1.txt") with open(path1, "wb") as file: file.write(b"") @@ -119,7 +118,7 @@ def test_multipart_request_multiple_files(tmpdir): with open(path2, "wb") as file: file.write(b"") - client = TestClient(app) + client = test_client_factory(app) with open(path1, "rb") as f1, open(path2, "rb") as f2: response = client.post( "/", files={"test1": f1, "test2": ("test2.txt", f2, "text/plain")} @@ -138,7 +137,7 @@ def test_multipart_request_multiple_files(tmpdir): } -def test_multi_items(tmpdir): +def test_multi_items(tmpdir, test_client_factory): path1 = os.path.join(tmpdir, "test1.txt") with open(path1, "wb") as file: file.write(b"") @@ -147,7 +146,7 @@ def test_multi_items(tmpdir): with open(path2, "wb") as file: file.write(b"") - client = TestClient(multi_items_app) + client = test_client_factory(multi_items_app) with open(path1, "rb") as f1, open(path2, "rb") as f2: response = client.post( "/", @@ -171,8 +170,8 @@ def test_multi_items(tmpdir): } -def test_multipart_request_mixed_files_and_data(tmpdir): - client = TestClient(app) +def test_multipart_request_mixed_files_and_data(tmpdir, test_client_factory): + client = test_client_factory(app) response = client.post( "/", data=( @@ -208,8 +207,8 @@ def test_multipart_request_mixed_files_and_data(tmpdir): } -def test_multipart_request_with_charset_for_filename(tmpdir): - client = TestClient(app) +def test_multipart_request_with_charset_for_filename(tmpdir, test_client_factory): + client = test_client_factory(app) response = client.post( "/", data=( @@ -236,8 +235,8 @@ def test_multipart_request_with_charset_for_filename(tmpdir): } -def test_multipart_request_without_charset_for_filename(tmpdir): - client = TestClient(app) +def test_multipart_request_without_charset_for_filename(tmpdir, test_client_factory): + client = test_client_factory(app) response = client.post( "/", data=( @@ -263,8 +262,8 @@ def test_multipart_request_without_charset_for_filename(tmpdir): } -def test_multipart_request_with_encoded_value(tmpdir): - client = TestClient(app) +def test_multipart_request_with_encoded_value(tmpdir, test_client_factory): + client = test_client_factory(app) response = client.post( "/", data=( @@ -284,38 +283,38 @@ def test_multipart_request_with_encoded_value(tmpdir): assert response.json() == {"value": "Transférer"} -def test_urlencoded_request_data(tmpdir): - client = TestClient(app) +def test_urlencoded_request_data(tmpdir, test_client_factory): + client = test_client_factory(app) response = client.post("/", data={"some": "data"}) assert response.json() == {"some": "data"} -def test_no_request_data(tmpdir): - client = TestClient(app) +def test_no_request_data(tmpdir, test_client_factory): + client = test_client_factory(app) response = client.post("/") assert response.json() == {} -def test_urlencoded_percent_encoding(tmpdir): - client = TestClient(app) +def test_urlencoded_percent_encoding(tmpdir, test_client_factory): + client = test_client_factory(app) response = client.post("/", data={"some": "da ta"}) assert response.json() == {"some": "da ta"} -def test_urlencoded_percent_encoding_keys(tmpdir): - client = TestClient(app) +def test_urlencoded_percent_encoding_keys(tmpdir, test_client_factory): + client = test_client_factory(app) response = client.post("/", data={"so me": "data"}) assert response.json() == {"so me": "data"} -def test_urlencoded_multi_field_app_reads_body(tmpdir): - client = TestClient(app_read_body) +def test_urlencoded_multi_field_app_reads_body(tmpdir, test_client_factory): + client = test_client_factory(app_read_body) response = client.post("/", data={"some": "data", "second": "key pair"}) assert response.json() == {"some": "data", "second": "key pair"} -def test_multipart_multi_field_app_reads_body(tmpdir): - client = TestClient(app_read_body) +def test_multipart_multi_field_app_reads_body(tmpdir, test_client_factory): + client = test_client_factory(app_read_body) response = client.post( "/", data={"some": "data", "second": "key pair"}, files=FORCE_MULTIPART ) diff --git a/tests/test_graphql.py b/tests/test_graphql.py index b945a5cfe..8492439f8 100644 --- a/tests/test_graphql.py +++ b/tests/test_graphql.py @@ -1,10 +1,10 @@ import graphene +import pytest from graphql.execution.executors.asyncio import AsyncioExecutor from starlette.applications import Starlette from starlette.datastructures import Headers from starlette.graphql import GraphQLApp -from starlette.testclient import TestClient class FakeAuthMiddleware: @@ -33,29 +33,33 @@ def resolve_whoami(self, info): schema = graphene.Schema(query=Query) -app = GraphQLApp(schema=schema, graphiql=True) -client = TestClient(app) -def test_graphql_get(): +@pytest.fixture +def client(test_client_factory): + app = GraphQLApp(schema=schema, graphiql=True) + return test_client_factory(app) + + +def test_graphql_get(client): response = client.get("/?query={ hello }") assert response.status_code == 200 assert response.json() == {"data": {"hello": "Hello stranger"}} -def test_graphql_post(): +def test_graphql_post(client): response = client.post("/?query={ hello }") assert response.status_code == 200 assert response.json() == {"data": {"hello": "Hello stranger"}} -def test_graphql_post_json(): +def test_graphql_post_json(client): response = client.post("/", json={"query": "{ hello }"}) assert response.status_code == 200 assert response.json() == {"data": {"hello": "Hello stranger"}} -def test_graphql_post_graphql(): +def test_graphql_post_graphql(client): response = client.post( "/", data="{ hello }", headers={"content-type": "application/graphql"} ) @@ -63,25 +67,25 @@ def test_graphql_post_graphql(): assert response.json() == {"data": {"hello": "Hello stranger"}} -def test_graphql_post_invalid_media_type(): +def test_graphql_post_invalid_media_type(client): response = client.post("/", data="{ hello }", headers={"content-type": "dummy"}) assert response.status_code == 415 assert response.text == "Unsupported Media Type" -def test_graphql_put(): +def test_graphql_put(client): response = client.put("/", json={"query": "{ hello }"}) assert response.status_code == 405 assert response.text == "Method Not Allowed" -def test_graphql_no_query(): +def test_graphql_no_query(client): response = client.get("/") assert response.status_code == 400 assert response.text == "No GraphQL query found in the request" -def test_graphql_invalid_field(): +def test_graphql_invalid_field(client): response = client.post("/", json={"query": "{ dummy }"}) assert response.status_code == 400 assert response.json() == { @@ -95,34 +99,34 @@ def test_graphql_invalid_field(): } -def test_graphiql_get(): +def test_graphiql_get(client): response = client.get("/", headers={"accept": "text/html"}) assert response.status_code == 200 assert "" in response.text -def test_graphiql_not_found(): +def test_graphiql_not_found(test_client_factory): app = GraphQLApp(schema=schema, graphiql=False) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"accept": "text/html"}) assert response.status_code == 404 assert response.text == "Not Found" -def test_add_graphql_route(): +def test_add_graphql_route(test_client_factory): app = Starlette() app.add_route("/", GraphQLApp(schema=schema)) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/?query={ hello }") assert response.status_code == 200 assert response.json() == {"data": {"hello": "Hello stranger"}} -def test_graphql_context(): +def test_graphql_context(test_client_factory): app = Starlette() app.add_middleware(FakeAuthMiddleware) app.add_route("/", GraphQLApp(schema=schema)) - client = TestClient(app) + client = test_client_factory(app) response = client.post( "/", json={"query": "{ whoami }"}, headers={"Authorization": "Bearer 123"} ) @@ -141,8 +145,8 @@ async def resolve_hello(self, info, name): async_app = GraphQLApp(schema=async_schema, executor_class=AsyncioExecutor) -def test_graphql_async(no_trio_support): - client = TestClient(async_app) +def test_graphql_async(no_trio_support, test_client_factory): + client = test_client_factory(async_app) response = client.get("/?query={ hello }") assert response.status_code == 200 assert response.json() == {"data": {"hello": "Hello stranger"}} diff --git a/tests/test_requests.py b/tests/test_requests.py index fee059ab2..d7c69fbeb 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -3,17 +3,16 @@ from starlette.requests import ClientDisconnect, Request, State from starlette.responses import JSONResponse, Response -from starlette.testclient import TestClient -def test_request_url(): +def test_request_url(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) data = {"method": request.method, "url": str(request.url)} response = JSONResponse(data) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/123?a=abc") assert response.json() == {"method": "GET", "url": "http://testserver/123?a=abc"} @@ -21,26 +20,26 @@ async def app(scope, receive, send): assert response.json() == {"method": "GET", "url": "https://example.org:123/"} -def test_request_query_params(): +def test_request_query_params(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) params = dict(request.query_params) response = JSONResponse({"params": params}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/?a=123&b=456") assert response.json() == {"params": {"a": "123", "b": "456"}} -def test_request_headers(): +def test_request_headers(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) headers = dict(request.headers) response = JSONResponse({"headers": headers}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"host": "example.org"}) assert response.json() == { "headers": { @@ -53,7 +52,7 @@ async def app(scope, receive, send): } -def test_request_client(): +def test_request_client(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) response = JSONResponse( @@ -61,19 +60,19 @@ async def app(scope, receive, send): ) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.json() == {"host": "testclient", "port": 50000} -def test_request_body(): +def test_request_body(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) body = await request.body() response = JSONResponse({"body": body.decode()}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.json() == {"body": ""} @@ -85,7 +84,7 @@ async def app(scope, receive, send): assert response.json() == {"body": "abc"} -def test_request_stream(): +def test_request_stream(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) body = b"" @@ -94,7 +93,7 @@ async def app(scope, receive, send): response = JSONResponse({"body": body.decode()}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.json() == {"body": ""} @@ -106,20 +105,20 @@ async def app(scope, receive, send): assert response.json() == {"body": "abc"} -def test_request_form_urlencoded(): +def test_request_form_urlencoded(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) form = await request.form() response = JSONResponse({"form": dict(form)}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.post("/", data={"abc": "123 @"}) assert response.json() == {"form": {"abc": "123 @"}} -def test_request_body_then_stream(): +def test_request_body_then_stream(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) body = await request.body() @@ -129,13 +128,13 @@ async def app(scope, receive, send): response = JSONResponse({"body": body.decode(), "stream": chunks.decode()}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.post("/", data="abc") assert response.json() == {"body": "abc", "stream": "abc"} -def test_request_stream_then_body(): +def test_request_stream_then_body(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) chunks = b"" @@ -148,20 +147,20 @@ async def app(scope, receive, send): response = JSONResponse({"body": body.decode(), "stream": chunks.decode()}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.post("/", data="abc") assert response.json() == {"body": "", "stream": "abc"} -def test_request_json(): +def test_request_json(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) data = await request.json() response = JSONResponse({"json": data}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.post("/", json={"a": "123"}) assert response.json() == {"json": {"a": "123"}} @@ -177,7 +176,7 @@ def test_request_scope_interface(): assert len(request) == 3 -def test_request_without_setting_receive(): +def test_request_without_setting_receive(test_client_factory): """ If Request is instantiated without the receive channel, then .body() is not available. @@ -192,12 +191,12 @@ async def app(scope, receive, send): response = JSONResponse({"json": data}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.post("/", json={"a": "123"}) assert response.json() == {"json": "Receive channel not available"} -def test_request_disconnect(): +def test_request_disconnect(anyio_backend_name, anyio_backend_options): """ If a client disconnect occurs while reading request body then ClientDisconnect should be raised. @@ -212,10 +211,17 @@ async def receiver(): scope = {"type": "http", "method": "POST", "path": "/"} with pytest.raises(ClientDisconnect): - anyio.run(app, scope, receiver, None) + anyio.run( + app, + scope, + receiver, + None, + backend=anyio_backend_name, + backend_options=anyio_backend_options, + ) -def test_request_is_disconnected(): +def test_request_is_disconnected(test_client_factory): """ If a client disconnect occurs while reading request body then ClientDisconnect should be raised. @@ -232,7 +238,7 @@ async def app(scope, receive, send): await response(scope, receive, send) disconnected_after_response = await request.is_disconnected() - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.json() == {"disconnected": False} assert disconnected_after_response @@ -252,19 +258,19 @@ def test_request_state_object(): s.new -def test_request_state(): +def test_request_state(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) request.state.example = 123 response = JSONResponse({"state.example": request.state.example}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/123?a=abc") assert response.json() == {"state.example": 123} -def test_request_cookies(): +def test_request_cookies(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) mycookie = request.cookies.get("mycookie") @@ -276,14 +282,14 @@ async def app(scope, receive, send): await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world!" response = client.get("/") assert response.text == "Hello, cookies!" -def test_cookie_lenient_parsing(): +def test_cookie_lenient_parsing(test_client_factory): """ The following test is based on a cookie set by Okta, a well-known authorization service. It turns out that it's common practice to set cookies that would be @@ -310,7 +316,7 @@ async def app(scope, receive, send): response = JSONResponse({"cookies": request.cookies}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"cookie": tough_cookie}) result = response.json() assert len(result["cookies"]) == 4 @@ -339,13 +345,13 @@ async def app(scope, receive, send): ("a=b; h=i; a=c", {"a": "c", "h": "i"}), ], ) -def test_cookies_edge_cases(set_cookie, expected): +def test_cookies_edge_cases(set_cookie, expected, test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) response = JSONResponse({"cookies": request.cookies}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"cookie": set_cookie}) result = response.json() assert result["cookies"] == expected @@ -374,7 +380,7 @@ async def app(scope, receive, send): # (" = b ; ; = ; c = ; ", {"": "b", "c": ""}), ], ) -def test_cookies_invalid(set_cookie, expected): +def test_cookies_invalid(set_cookie, expected, test_client_factory): """ Cookie strings that are against the RFC6265 spec but which browsers will send if set via document.cookie. @@ -385,20 +391,20 @@ async def app(scope, receive, send): response = JSONResponse({"cookies": request.cookies}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"cookie": set_cookie}) result = response.json() assert result["cookies"] == expected -def test_chunked_encoding(): +def test_chunked_encoding(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) body = await request.body() response = JSONResponse({"body": body.decode()}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) def post_body(): yield b"foo" @@ -408,7 +414,7 @@ def post_body(): assert response.json() == {"body": "foobar"} -def test_request_send_push_promise(): +def test_request_send_push_promise(test_client_factory): async def app(scope, receive, send): # the server is push-enabled scope["extensions"]["http.response.push"] = {} @@ -419,12 +425,12 @@ async def app(scope, receive, send): response = JSONResponse({"json": "OK"}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.json() == {"json": "OK"} -def test_request_send_push_promise_without_push_extension(): +def test_request_send_push_promise_without_push_extension(test_client_factory): """ If server does not support the `http.response.push` extension, .send_push_promise() does nothing. @@ -437,12 +443,12 @@ async def app(scope, receive, send): response = JSONResponse({"json": "OK"}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.json() == {"json": "OK"} -def test_request_send_push_promise_without_setting_send(): +def test_request_send_push_promise_without_setting_send(test_client_factory): """ If Request is instantiated without the send channel, then .send_push_promise() is not available. @@ -461,6 +467,6 @@ async def app(scope, receive, send): response = JSONResponse({"json": data}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.json() == {"json": "Send channel not available"} diff --git a/tests/test_responses.py b/tests/test_responses.py index 496e64c86..baba549ba 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -13,40 +13,39 @@ Response, StreamingResponse, ) -from starlette.testclient import TestClient -def test_text_response(): +def test_text_response(test_client_factory): async def app(scope, receive, send): response = Response("hello, world", media_type="text/plain") await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "hello, world" -def test_bytes_response(): +def test_bytes_response(test_client_factory): async def app(scope, receive, send): response = Response(b"xxxxx", media_type="image/png") await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.content == b"xxxxx" -def test_json_none_response(): +def test_json_none_response(test_client_factory): async def app(scope, receive, send): response = JSONResponse(None) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.json() is None -def test_redirect_response(): +def test_redirect_response(test_client_factory): async def app(scope, receive, send): if scope["path"] == "/": response = Response("hello, world", media_type="text/plain") @@ -54,13 +53,13 @@ async def app(scope, receive, send): response = RedirectResponse("/") await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/redirect") assert response.text == "hello, world" assert response.url == "http://testserver/" -def test_quoting_redirect_response(): +def test_quoting_redirect_response(test_client_factory): async def app(scope, receive, send): if scope["path"] == "/I ♥ Starlette/": response = Response("hello, world", media_type="text/plain") @@ -68,13 +67,13 @@ async def app(scope, receive, send): response = RedirectResponse("/I ♥ Starlette/") await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/redirect") assert response.text == "hello, world" assert response.url == "http://testserver/I%20%E2%99%A5%20Starlette/" -def test_streaming_response(): +def test_streaming_response(test_client_factory): filled_by_bg_task = "" async def app(scope, receive, send): @@ -98,13 +97,13 @@ async def numbers_for_cleanup(start=1, stop=5): await response(scope, receive, send) assert filled_by_bg_task == "" - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "1, 2, 3, 4, 5" assert filled_by_bg_task == "6, 7, 8, 9" -def test_streaming_response_custom_iterator(): +def test_streaming_response_custom_iterator(test_client_factory): async def app(scope, receive, send): class CustomAsyncIterator: def __init__(self): @@ -122,12 +121,12 @@ async def __anext__(self): response = StreamingResponse(CustomAsyncIterator(), media_type="text/plain") await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "12345" -def test_streaming_response_custom_iterable(): +def test_streaming_response_custom_iterable(test_client_factory): async def app(scope, receive, send): class CustomAsyncIterable: async def __aiter__(self): @@ -137,12 +136,12 @@ async def __aiter__(self): response = StreamingResponse(CustomAsyncIterable(), media_type="text/plain") await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "12345" -def test_sync_streaming_response(): +def test_sync_streaming_response(test_client_factory): async def app(scope, receive, send): def numbers(minimum, maximum): for i in range(minimum, maximum + 1): @@ -154,37 +153,37 @@ def numbers(minimum, maximum): response = StreamingResponse(generator, media_type="text/plain") await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "1, 2, 3, 4, 5" -def test_response_headers(): +def test_response_headers(test_client_factory): async def app(scope, receive, send): headers = {"x-header-1": "123", "x-header-2": "456"} response = Response("hello, world", media_type="text/plain", headers=headers) response.headers["x-header-2"] = "789" await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.headers["x-header-1"] == "123" assert response.headers["x-header-2"] == "789" -def test_response_phrase(): +def test_response_phrase(test_client_factory): app = Response(status_code=204) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.reason == "No Content" app = Response(b"", status_code=123) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.reason == "" -def test_file_response(tmpdir): +def test_file_response(tmpdir, test_client_factory): path = os.path.join(tmpdir, "xyz") content = b"" * 1000 with open(path, "wb") as file: @@ -213,7 +212,7 @@ async def app(scope, receive, send): await response(scope, receive, send) assert filled_by_bg_task == "" - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") expected_disposition = 'attachment; filename="example.png"' assert response.status_code == status.HTTP_200_OK @@ -226,31 +225,31 @@ async def app(scope, receive, send): assert filled_by_bg_task == "6, 7, 8, 9" -def test_file_response_with_directory_raises_error(tmpdir): +def test_file_response_with_directory_raises_error(tmpdir, test_client_factory): app = FileResponse(path=tmpdir, filename="example.png") - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(RuntimeError) as exc_info: client.get("/") assert "is not a file" in str(exc_info.value) -def test_file_response_with_missing_file_raises_error(tmpdir): +def test_file_response_with_missing_file_raises_error(tmpdir, test_client_factory): path = os.path.join(tmpdir, "404.txt") app = FileResponse(path=path, filename="404.txt") - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(RuntimeError) as exc_info: client.get("/") assert "does not exist" in str(exc_info.value) -def test_file_response_with_chinese_filename(tmpdir): +def test_file_response_with_chinese_filename(tmpdir, test_client_factory): content = b"file content" filename = "你好.txt" # probably "Hello.txt" in Chinese path = os.path.join(tmpdir, filename) with open(path, "wb") as f: f.write(content) app = FileResponse(path=path, filename=filename) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") expected_disposition = "attachment; filename*=utf-8''%E4%BD%A0%E5%A5%BD.txt" assert response.status_code == status.HTTP_200_OK @@ -258,7 +257,7 @@ def test_file_response_with_chinese_filename(tmpdir): assert response.headers["content-disposition"] == expected_disposition -def test_set_cookie(): +def test_set_cookie(test_client_factory): async def app(scope, receive, send): response = Response("Hello, world!", media_type="text/plain") response.set_cookie( @@ -274,12 +273,12 @@ async def app(scope, receive, send): ) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world!" -def test_delete_cookie(): +def test_delete_cookie(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) response = Response("Hello, world!", media_type="text/plain") @@ -289,24 +288,24 @@ async def app(scope, receive, send): response.set_cookie("mycookie", "myvalue") await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.cookies["mycookie"] response = client.get("/") assert not response.cookies.get("mycookie") -def test_populate_headers(): +def test_populate_headers(test_client_factory): app = Response(content="hi", headers={}, media_type="text/html") - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "hi" assert response.headers["content-length"] == "2" assert response.headers["content-type"] == "text/html; charset=utf-8" -def test_head_method(): +def test_head_method(test_client_factory): app = Response("hello, world", media_type="text/plain") - client = TestClient(app) + client = test_client_factory(app) response = client.head("/") assert response.text == "" diff --git a/tests/test_routing.py b/tests/test_routing.py index 1d8eb8d95..3c096125f 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -6,7 +6,6 @@ from starlette.applications import Starlette from starlette.responses import JSONResponse, PlainTextResponse, Response from starlette.routing import Host, Mount, NoMatchFound, Route, Router, WebSocketRoute -from starlette.testclient import TestClient from starlette.websockets import WebSocket, WebSocketDisconnect @@ -105,10 +104,13 @@ async def websocket_params(session): await session.close() -client = TestClient(app) +@pytest.fixture +def client(test_client_factory): + with test_client_factory(app) as client: + yield client -def test_router(): +def test_router(client): response = client.get("/") assert response.status_code == 200 assert response.text == "Hello, world" @@ -147,7 +149,7 @@ def test_router(): assert response.text == "xxxxx" -def test_route_converters(): +def test_route_converters(client): # Test integer conversion response = client.get("/int/5") assert response.status_code == 200 @@ -232,19 +234,19 @@ def test_url_for(): ) -def test_router_add_route(): +def test_router_add_route(client): response = client.get("/func") assert response.status_code == 200 assert response.text == "Hello, world!" -def test_router_duplicate_path(): +def test_router_duplicate_path(client): response = client.post("/func") assert response.status_code == 200 assert response.text == "Hello, POST!" -def test_router_add_websocket_route(): +def test_router_add_websocket_route(client): with client.websocket_connect("/ws") as session: text = session.receive_text() assert text == "Hello, world!" @@ -275,8 +277,8 @@ async def __call__(self, scope, receive, send): ) -def test_protocol_switch(): - client = TestClient(mixed_protocol_app) +def test_protocol_switch(test_client_factory): + client = test_client_factory(mixed_protocol_app) response = client.get("/") assert response.status_code == 200 @@ -293,9 +295,9 @@ def test_protocol_switch(): ok = PlainTextResponse("OK") -def test_mount_urls(): +def test_mount_urls(test_client_factory): mounted = Router([Mount("/users", ok, name="users")]) - client = TestClient(mounted) + client = test_client_factory(mounted) assert client.get("/users").status_code == 200 assert client.get("/users").url == "http://testserver/users/" assert client.get("/users/").status_code == 200 @@ -318,9 +320,9 @@ def test_reverse_mount_urls(): ) -def test_mount_at_root(): +def test_mount_at_root(test_client_factory): mounted = Router([Mount("/", ok, name="users")]) - client = TestClient(mounted) + client = test_client_factory(mounted) assert client.get("/").status_code == 200 @@ -348,8 +350,8 @@ def users_api(request): ) -def test_host_routing(): - client = TestClient(mixed_hosts_app, base_url="https://api.example.org/") +def test_host_routing(test_client_factory): + client = test_client_factory(mixed_hosts_app, base_url="https://api.example.org/") response = client.get("/users") assert response.status_code == 200 @@ -358,7 +360,7 @@ def test_host_routing(): response = client.get("/") assert response.status_code == 404 - client = TestClient(mixed_hosts_app, base_url="https://www.example.org/") + client = test_client_factory(mixed_hosts_app, base_url="https://www.example.org/") response = client.get("/users") assert response.status_code == 200 @@ -393,8 +395,8 @@ async def subdomain_app(scope, receive, send): ) -def test_subdomain_routing(): - client = TestClient(subdomain_app, base_url="https://foo.example.org/") +def test_subdomain_routing(test_client_factory): + client = test_client_factory(subdomain_app, base_url="https://foo.example.org/") response = client.get("/") assert response.status_code == 200 @@ -429,9 +431,11 @@ async def echo_urls(request): ] -def test_url_for_with_root_path(): +def test_url_for_with_root_path(test_client_factory): app = Starlette(routes=echo_url_routes) - client = TestClient(app, base_url="https://www.example.org/", root_path="/sub_path") + client = test_client_factory( + app, base_url="https://www.example.org/", root_path="/sub_path" + ) response = client.get("/") assert response.json() == { "index": "https://www.example.org/sub_path/", @@ -459,17 +463,17 @@ def test_url_for_with_double_mount(): assert url == "/mount/static/123" -def test_standalone_route_matches(): +def test_standalone_route_matches(test_client_factory): app = Route("/", PlainTextResponse("Hello, World!")) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.status_code == 200 assert response.text == "Hello, World!" -def test_standalone_route_does_not_match(): +def test_standalone_route_does_not_match(test_client_factory): app = Route("/", PlainTextResponse("Hello, World!")) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/invalid") assert response.status_code == 404 assert response.text == "Not Found" @@ -481,23 +485,23 @@ async def ws_helloworld(websocket): await websocket.close() -def test_standalone_ws_route_matches(): +def test_standalone_ws_route_matches(test_client_factory): app = WebSocketRoute("/", ws_helloworld) - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: text = websocket.receive_text() assert text == "Hello, world!" -def test_standalone_ws_route_does_not_match(): +def test_standalone_ws_route_does_not_match(test_client_factory): app = WebSocketRoute("/", ws_helloworld) - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(WebSocketDisconnect): with client.websocket_connect("/invalid"): pass # pragma: nocover -def test_lifespan_async(): +def test_lifespan_async(test_client_factory): startup_complete = False shutdown_complete = False @@ -520,7 +524,7 @@ async def run_shutdown(): assert not startup_complete assert not shutdown_complete - with TestClient(app) as client: + with test_client_factory(app) as client: assert startup_complete assert not shutdown_complete client.get("/") @@ -528,7 +532,7 @@ async def run_shutdown(): assert shutdown_complete -def test_lifespan_sync(): +def test_lifespan_sync(test_client_factory): startup_complete = False shutdown_complete = False @@ -551,7 +555,7 @@ def run_shutdown(): assert not startup_complete assert not shutdown_complete - with TestClient(app) as client: + with test_client_factory(app) as client: assert startup_complete assert not shutdown_complete client.get("/") @@ -559,7 +563,7 @@ def run_shutdown(): assert shutdown_complete -def test_raise_on_startup(): +def test_raise_on_startup(test_client_factory): def run_startup(): raise RuntimeError() @@ -576,19 +580,19 @@ async def _send(message): startup_failed = False with pytest.raises(RuntimeError): - with TestClient(app): + with test_client_factory(app): pass # pragma: nocover assert startup_failed -def test_raise_on_shutdown(): +def test_raise_on_shutdown(test_client_factory): def run_shutdown(): raise RuntimeError() app = Router(on_shutdown=[run_shutdown]) with pytest.raises(RuntimeError): - with TestClient(app): + with test_client_factory(app): pass # pragma: nocover @@ -615,8 +619,8 @@ async def _partial_async_endpoint(arg, request): ) -def test_partial_async_endpoint(): - test_client = TestClient(partial_async_app) +def test_partial_async_endpoint(test_client_factory): + test_client = test_client_factory(partial_async_app) response = test_client.get("/") assert response.status_code == 200 assert response.json() == {"arg": "foo"} diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 0ae43238f..28fe777f0 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -1,7 +1,6 @@ from starlette.applications import Starlette from starlette.endpoints import HTTPEndpoint from starlette.schemas import SchemaGenerator -from starlette.testclient import TestClient schemas = SchemaGenerator( {"openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}} @@ -213,8 +212,8 @@ def test_schema_generation(): """ -def test_schema_endpoint(): - client = TestClient(app) +def test_schema_endpoint(test_client_factory): + client = test_client_factory(app) response = client.get("/schema") assert response.headers["Content-Type"] == "application/vnd.oai.openapi" assert response.text.strip() == EXPECTED_SCHEMA.strip() diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index 3c8ff240e..d5ec1afc5 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -9,35 +9,34 @@ from starlette.requests import Request from starlette.routing import Mount from starlette.staticfiles import StaticFiles -from starlette.testclient import TestClient -def test_staticfiles(tmpdir): +def test_staticfiles(tmpdir, test_client_factory): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/example.txt") assert response.status_code == 200 assert response.text == "" -def test_staticfiles_with_pathlib(tmpdir): +def test_staticfiles_with_pathlib(tmpdir, test_client_factory): base_dir = pathlib.Path(tmpdir) path = base_dir / "example.txt" with open(path, "w") as file: file.write("") app = StaticFiles(directory=base_dir) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/example.txt") assert response.status_code == 200 assert response.text == "" -def test_staticfiles_head_with_middleware(tmpdir): +def test_staticfiles_head_with_middleware(tmpdir, test_client_factory): """ see https://github.com/encode/starlette/pull/935 """ @@ -53,51 +52,51 @@ async def does_nothing_middleware(request: Request, call_next): response = await call_next(request) return response - client = TestClient(app) + client = test_client_factory(app) response = client.head("/static/example.txt") assert response.status_code == 200 assert response.headers.get("content-length") == "100" -def test_staticfiles_with_package(): +def test_staticfiles_with_package(test_client_factory): app = StaticFiles(packages=["tests"]) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/example.txt") assert response.status_code == 200 assert response.text == "123\n" -def test_staticfiles_post(tmpdir): +def test_staticfiles_post(tmpdir, test_client_factory): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) - client = TestClient(app) + client = test_client_factory(app) response = client.post("/example.txt") assert response.status_code == 405 assert response.text == "Method Not Allowed" -def test_staticfiles_with_directory_returns_404(tmpdir): +def test_staticfiles_with_directory_returns_404(tmpdir, test_client_factory): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.status_code == 404 assert response.text == "Not Found" -def test_staticfiles_with_missing_file_returns_404(tmpdir): +def test_staticfiles_with_missing_file_returns_404(tmpdir, test_client_factory): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/404.txt") assert response.status_code == 404 assert response.text == "Not Found" @@ -110,30 +109,32 @@ def test_staticfiles_instantiated_with_missing_directory(tmpdir): assert "does not exist" in str(exc_info.value) -def test_staticfiles_configured_with_missing_directory(tmpdir): +def test_staticfiles_configured_with_missing_directory(tmpdir, test_client_factory): path = os.path.join(tmpdir, "no_such_directory") app = StaticFiles(directory=path, check_dir=False) - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(RuntimeError) as exc_info: client.get("/example.txt") assert "does not exist" in str(exc_info.value) -def test_staticfiles_configured_with_file_instead_of_directory(tmpdir): +def test_staticfiles_configured_with_file_instead_of_directory( + tmpdir, test_client_factory +): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=path, check_dir=False) - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(RuntimeError) as exc_info: client.get("/example.txt") assert "is not a directory" in str(exc_info.value) -def test_staticfiles_config_check_occurs_only_once(tmpdir): +def test_staticfiles_config_check_occurs_only_once(tmpdir, test_client_factory): app = StaticFiles(directory=tmpdir) - client = TestClient(app) + client = test_client_factory(app) assert not app.config_checked client.get("/") assert app.config_checked @@ -158,26 +159,26 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir): assert response.body == b"Not Found" -def test_staticfiles_never_read_file_for_head_method(tmpdir): +def test_staticfiles_never_read_file_for_head_method(tmpdir, test_client_factory): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) - client = TestClient(app) + client = test_client_factory(app) response = client.head("/example.txt") assert response.status_code == 200 assert response.content == b"" assert response.headers["content-length"] == "14" -def test_staticfiles_304_with_etag_match(tmpdir): +def test_staticfiles_304_with_etag_match(tmpdir, test_client_factory): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) - client = TestClient(app) + client = test_client_factory(app) first_resp = client.get("/example.txt") assert first_resp.status_code == 200 last_etag = first_resp.headers["etag"] @@ -186,7 +187,9 @@ def test_staticfiles_304_with_etag_match(tmpdir): assert second_resp.content == b"" -def test_staticfiles_304_with_last_modified_compare_last_req(tmpdir): +def test_staticfiles_304_with_last_modified_compare_last_req( + tmpdir, test_client_factory +): path = os.path.join(tmpdir, "example.txt") file_last_modified_time = time.mktime( time.strptime("2013-10-10 23:40:00", "%Y-%m-%d %H:%M:%S") @@ -196,7 +199,7 @@ def test_staticfiles_304_with_last_modified_compare_last_req(tmpdir): os.utime(path, (file_last_modified_time, file_last_modified_time)) app = StaticFiles(directory=tmpdir) - client = TestClient(app) + client = test_client_factory(app) # last modified less than last request, 304 response = client.get( "/example.txt", headers={"If-Modified-Since": "Thu, 11 Oct 2013 15:30:19 GMT"} @@ -211,7 +214,7 @@ def test_staticfiles_304_with_last_modified_compare_last_req(tmpdir): assert response.content == b"" -def test_staticfiles_html(tmpdir): +def test_staticfiles_html(tmpdir, test_client_factory): path = os.path.join(tmpdir, "404.html") with open(path, "w") as file: file.write("

Custom not found page

") @@ -222,7 +225,7 @@ def test_staticfiles_html(tmpdir): file.write("

Hello

") app = StaticFiles(directory=tmpdir, html=True) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/dir/") assert response.url == "http://testserver/dir/" @@ -244,7 +247,9 @@ def test_staticfiles_html(tmpdir): assert response.text == "

Custom not found page

" -def test_staticfiles_cache_invalidation_for_deleted_file_html_mode(tmpdir): +def test_staticfiles_cache_invalidation_for_deleted_file_html_mode( + tmpdir, test_client_factory +): path_404 = os.path.join(tmpdir, "404.html") with open(path_404, "w") as file: file.write("

404 file

") @@ -259,7 +264,7 @@ def test_staticfiles_cache_invalidation_for_deleted_file_html_mode(tmpdir): os.utime(path_some, (common_modified_time, common_modified_time)) app = StaticFiles(directory=tmpdir, html=True) - client = TestClient(app) + client = test_client_factory(app) resp_exists = client.get("/some.html") assert resp_exists.status_code == 200 diff --git a/tests/test_templates.py b/tests/test_templates.py index a0ab3e1b0..073482d65 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -4,10 +4,9 @@ from starlette.applications import Starlette from starlette.templating import Jinja2Templates -from starlette.testclient import TestClient -def test_templates(tmpdir): +def test_templates(tmpdir, test_client_factory): path = os.path.join(tmpdir, "index.html") with open(path, "w") as file: file.write("Hello, world") @@ -19,7 +18,7 @@ def test_templates(tmpdir): async def homepage(request): return templates.TemplateResponse("index.html", {"request": request}) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world" assert response.template.name == "index.html" diff --git a/tests/test_testclient.py b/tests/test_testclient.py index 44e3320a4..fd96f69a7 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -4,7 +4,6 @@ from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.responses import JSONResponse -from starlette.testclient import TestClient from starlette.websockets import WebSocket, WebSocketDisconnect mock_service = Starlette() @@ -15,14 +14,16 @@ def mock_service_endpoint(request): return JSONResponse({"mock": "example"}) -app = Starlette() +def create_app(test_client_factory): + app = Starlette() + @app.route("/") + def homepage(request): + client = test_client_factory(mock_service) + response = client.get("/") + return JSONResponse(response.json()) -@app.route("/") -def homepage(request): - client = TestClient(mock_service) - response = client.get("/") - return JSONResponse(response.json()) + return app startup_error_app = Starlette() @@ -33,30 +34,30 @@ def startup(): raise RuntimeError() -def test_use_testclient_in_endpoint(): +def test_use_testclient_in_endpoint(test_client_factory): """ We should be able to use the test client within applications. This is useful if we need to mock out other services, during tests or in development. """ - client = TestClient(app) + client = test_client_factory(create_app(test_client_factory)) response = client.get("/") assert response.json() == {"mock": "example"} -def test_use_testclient_as_contextmanager(): - with TestClient(app): +def test_use_testclient_as_contextmanager(test_client_factory): + with test_client_factory(create_app(test_client_factory)): pass -def test_error_on_startup(): +def test_error_on_startup(test_client_factory): with pytest.raises(RuntimeError): - with TestClient(startup_error_app): + with test_client_factory(startup_error_app): pass # pragma: no cover -def test_exception_in_middleware(): +def test_exception_in_middleware(test_client_factory): class MiddlewareException(Exception): pass @@ -70,11 +71,11 @@ async def __call__(self, scope, receive, send): broken_middleware = Starlette(middleware=[Middleware(BrokenMiddleware)]) with pytest.raises(MiddlewareException): - with TestClient(broken_middleware): + with test_client_factory(broken_middleware): pass # pragma: no cover -def test_testclient_asgi2(): +def test_testclient_asgi2(test_client_factory): def app(scope): async def inner(receive, send): await send( @@ -88,12 +89,12 @@ async def inner(receive, send): return inner - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world!" -def test_testclient_asgi3(): +def test_testclient_asgi3(test_client_factory): async def app(scope, receive, send): await send( { @@ -104,12 +105,12 @@ async def app(scope, receive, send): ) await send({"type": "http.response.body", "body": b"Hello, world!"}) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world!" -def test_websocket_blocking_receive(): +def test_websocket_blocking_receive(test_client_factory): def app(scope): async def respond(websocket): await websocket.send_json({"message": "test"}) @@ -128,26 +129,7 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: data = websocket.receive_json() assert data == {"message": "test"} - - -def test_backend_name(request): - """ - Test that the tests are defaulting to the correct backend and that a new - instance of TestClient can be created using different backend options. - """ - # client created using monkeypatched async_backend - client1 = TestClient(mock_service) - if "trio" in request.keywords: - client2 = TestClient(mock_service, backend="asyncio") - assert client1.async_backend["backend"] == "trio" - assert client2.async_backend["backend"] == "asyncio" - elif "asyncio" in request.keywords: - client2 = TestClient(mock_service, backend="trio") - assert client1.async_backend["backend"] == "asyncio" - assert client2.async_backend["backend"] == "trio" - else: - pytest.fail("Unknown backend") # pragma: nocover diff --git a/tests/test_websockets.py b/tests/test_websockets.py index 63ecd050a..bb073a011 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -2,11 +2,10 @@ import pytest from starlette import status -from starlette.testclient import TestClient from starlette.websockets import WebSocket, WebSocketDisconnect -def test_websocket_url(): +def test_websocket_url(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -16,13 +15,13 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/123?a=abc") as websocket: data = websocket.receive_json() assert data == {"url": "ws://testserver/123?a=abc"} -def test_websocket_binary_json(): +def test_websocket_binary_json(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -33,14 +32,14 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/123?a=abc") as websocket: websocket.send_json({"test": "data"}, mode="binary") data = websocket.receive_json(mode="binary") assert data == {"test": "data"} -def test_websocket_query_params(): +def test_websocket_query_params(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -51,13 +50,13 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/?a=abc&b=456") as websocket: data = websocket.receive_json() assert data == {"params": {"a": "abc", "b": "456"}} -def test_websocket_headers(): +def test_websocket_headers(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -68,7 +67,7 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: expected_headers = { "accept": "*/*", @@ -83,7 +82,7 @@ async def asgi(receive, send): assert data == {"headers": expected_headers} -def test_websocket_port(): +def test_websocket_port(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -93,13 +92,13 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("ws://example.com:123/123?a=abc") as websocket: data = websocket.receive_json() assert data == {"port": 123} -def test_websocket_send_and_receive_text(): +def test_websocket_send_and_receive_text(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -110,14 +109,14 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: websocket.send_text("Hello, world!") data = websocket.receive_text() assert data == "Message was: Hello, world!" -def test_websocket_send_and_receive_bytes(): +def test_websocket_send_and_receive_bytes(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -128,14 +127,14 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: websocket.send_bytes(b"Hello, world!") data = websocket.receive_bytes() assert data == b"Message was: Hello, world!" -def test_websocket_send_and_receive_json(): +def test_websocket_send_and_receive_json(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -146,14 +145,14 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: websocket.send_json({"hello": "world"}) data = websocket.receive_json() assert data == {"message": {"hello": "world"}} -def test_websocket_iter_text(): +def test_websocket_iter_text(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -163,14 +162,14 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: websocket.send_text("Hello, world!") data = websocket.receive_text() assert data == "Message was: Hello, world!" -def test_websocket_iter_bytes(): +def test_websocket_iter_bytes(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -180,14 +179,14 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: websocket.send_bytes(b"Hello, world!") data = websocket.receive_bytes() assert data == b"Message was: Hello, world!" -def test_websocket_iter_json(): +def test_websocket_iter_json(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -197,14 +196,14 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: websocket.send_json({"hello": "world"}) data = websocket.receive_json() assert data == {"message": {"hello": "world"}} -def test_websocket_concurrency_pattern(): +def test_websocket_concurrency_pattern(test_client_factory): def app(scope): stream_send, stream_receive = anyio.create_memory_object_stream() @@ -228,14 +227,14 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: websocket.send_json({"hello": "world"}) data = websocket.receive_json() assert data == {"hello": "world"} -def test_client_close(): +def test_client_close(test_client_factory): close_code = None def app(scope): @@ -250,13 +249,13 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: websocket.close(code=status.WS_1001_GOING_AWAY) assert close_code == status.WS_1001_GOING_AWAY -def test_application_close(): +def test_application_close(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -265,14 +264,14 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: with pytest.raises(WebSocketDisconnect) as exc: websocket.receive_text() assert exc.value.code == status.WS_1001_GOING_AWAY -def test_rejected_connection(): +def test_rejected_connection(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -280,14 +279,14 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(WebSocketDisconnect) as exc: with client.websocket_connect("/"): pass # pragma: nocover assert exc.value.code == status.WS_1001_GOING_AWAY -def test_subprotocol(): +def test_subprotocol(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -297,25 +296,25 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/", subprotocols=["soap", "wamp"]) as websocket: assert websocket.accepted_subprotocol == "wamp" -def test_websocket_exception(): +def test_websocket_exception(test_client_factory): def app(scope): async def asgi(receive, send): assert False return asgi - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(AssertionError): with client.websocket_connect("/123?a=abc"): pass # pragma: nocover -def test_duplicate_close(): +def test_duplicate_close(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -325,13 +324,13 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(RuntimeError): with client.websocket_connect("/"): pass # pragma: nocover -def test_duplicate_disconnect(): +def test_duplicate_disconnect(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -342,7 +341,7 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(RuntimeError): with client.websocket_connect("/") as websocket: websocket.close() From 3498c9e9610f7579b3d82da63fad86833d52f7bc Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 23 Jun 2021 21:14:39 +0100 Subject: [PATCH 3/5] remove monkeypatching TestClient interface --- setup.py | 5 ++++- starlette/testclient.py | 32 ++++++++++++++++++++------------ 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/setup.py b/setup.py index 978a606c7..ac6479746 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,10 @@ def get_long_description(): packages=find_packages(exclude=["tests*"]), package_data={"starlette": ["py.typed"]}, include_package_data=True, - install_requires=["anyio>=3.0.0,<4"], + install_requires=[ + "anyio>=3.0.0,<4", + "typing_extensions; python_version < '3.8'", + ], extras_require={ "full": [ "graphene; python_version<'3.10'", diff --git a/starlette/testclient.py b/starlette/testclient.py index 7201809e2..21f4dba90 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -6,6 +6,7 @@ import json import math import queue +import sys import types import typing from concurrent.futures import Future @@ -18,6 +19,11 @@ from starlette.types import Message, Receive, Scope, Send from starlette.websockets import WebSocketDisconnect +if sys.version_info >= (3, 8): # pragma: no cover + from typing import TypedDict # pragma: no cover +else: # pragma: no cover + from typing_extensions import TypedDict # pragma: no cover + # Annotations for `Session.request()` Cookies = typing.Union[ typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar @@ -91,11 +97,16 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await instance(receive, send) +class _AsyncBackend(TypedDict): + backend: str + backend_options: typing.Dict[str, typing.Any] + + class _ASGIAdapter(requests.adapters.HTTPAdapter): def __init__( self, app: ASGI3App, - async_backend: typing.Dict[str, typing.Any], + async_backend: _AsyncBackend, raise_server_exceptions: bool = True, root_path: str = "", ) -> None: @@ -271,7 +282,10 @@ async def send(message: Message) -> None: class WebSocketTestSession: def __init__( - self, app: ASGI3App, scope: Scope, async_backend: typing.Dict[str, typing.Any] + self, + app: ASGI3App, + scope: Scope, + async_backend: _AsyncBackend, ) -> None: self.app = app self.scope = scope @@ -381,11 +395,6 @@ def receive_json(self, mode: str = "text") -> typing.Any: class TestClient(requests.Session): __test__ = False # For pytest to not discover this up. - #: These are the default options for the constructor arguments - async_backend: typing.Dict[str, typing.Any] = { - "backend": "asyncio", - "backend_options": {}, - } task: "Future[None]" def __init__( @@ -394,14 +403,13 @@ def __init__( base_url: str = "http://testserver", raise_server_exceptions: bool = True, root_path: str = "", - backend: typing.Optional[str] = None, + backend: str = "asyncio", backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> None: super().__init__() - self.async_backend = { - "backend": backend or self.async_backend["backend"], - "backend_options": backend_options or self.async_backend["backend_options"], - } + self.async_backend = _AsyncBackend( + backend=backend, backend_options=backend_options or {} + ) if _is_asgi3(app): app = typing.cast(ASGI3App, app) asgi_app = app From 731d2bcc3574ec6a963eb8541fcad9b73dde771f Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Mon, 28 Jun 2021 21:18:30 +0100 Subject: [PATCH 4/5] document where anyio_backend_name comes from --- tests/conftest.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index acaea4c87..bb68aa5e2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,6 +16,8 @@ def no_trio_support(anyio_backend_name): @pytest.fixture def test_client_factory(anyio_backend_name, anyio_backend_options): + # anyio_backend_name defined by: + # https://anyio.readthedocs.io/en/stable/testing.html#specifying-the-backends-to-run-on return functools.partial( TestClient, backend=anyio_backend_name, From 581c95a7e1c750d9ba86b64dbd52d21b201bb49d Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Mon, 28 Jun 2021 21:28:02 +0100 Subject: [PATCH 5/5] Update starlette/testclient.py Co-authored-by: Jamie Hewland --- starlette/testclient.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 21f4dba90..33bb410d0 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -20,9 +20,9 @@ from starlette.websockets import WebSocketDisconnect if sys.version_info >= (3, 8): # pragma: no cover - from typing import TypedDict # pragma: no cover + from typing import TypedDict else: # pragma: no cover - from typing_extensions import TypedDict # pragma: no cover + from typing_extensions import TypedDict # Annotations for `Session.request()` Cookies = typing.Union[