Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create types module inside tests #2502

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
7 changes: 3 additions & 4 deletions tests/conftest.py
@@ -1,18 +1,17 @@
import functools
from typing import Any, Callable, Dict, Literal
from typing import Any, Dict, Literal

import pytest

from starlette.testclient import TestClient

TestClientFactory = Callable[..., TestClient]
from tests.types import ClientFactoryProtocol


@pytest.fixture
def test_client_factory(
anyio_backend_name: Literal["asyncio", "trio"],
anyio_backend_options: Dict[str, Any],
) -> TestClientFactory:
) -> ClientFactoryProtocol:
# anyio_backend_name defined by:
# https://anyio.readthedocs.io/en/stable/testing.html#specifying-the-backends-to-run-on
return functools.partial(
Expand Down
36 changes: 17 additions & 19 deletions tests/middleware/test_base.py
Expand Up @@ -3,7 +3,6 @@
from typing import (
Any,
AsyncGenerator,
Callable,
Generator,
List,
Type,
Expand All @@ -24,8 +23,7 @@
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket

TestClientFactory = Callable[[ASGIApp], TestClient]
from tests.types import ClientFactoryProtocol


class CustomMiddleware(BaseHTTPMiddleware):
Expand Down Expand Up @@ -90,7 +88,7 @@ async def websocket_endpoint(session: WebSocket) -> None:
)


def test_custom_middleware(test_client_factory: TestClientFactory) -> None:
def test_custom_middleware(test_client_factory: ClientFactoryProtocol) -> None:
client = test_client_factory(app)
response = client.get("/")
assert response.headers["Custom-Header"] == "Example"
Expand All @@ -112,7 +110,7 @@ def test_custom_middleware(test_client_factory: TestClientFactory) -> None:


def test_state_data_across_multiple_middlewares(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
expected_value1 = "foo"
expected_value2 = "bar"
Expand Down Expand Up @@ -167,7 +165,7 @@ def homepage(request: Request) -> PlainTextResponse:
assert response.headers["X-State-Bar"] == expected_value2


def test_app_middleware_argument(test_client_factory: TestClientFactory) -> None:
def test_app_middleware_argument(test_client_factory: ClientFactoryProtocol) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage")

Expand All @@ -180,7 +178,7 @@ def homepage(request: Request) -> PlainTextResponse:
assert response.headers["Custom-Header"] == "Example"


def test_fully_evaluated_response(test_client_factory: TestClientFactory) -> None:
def test_fully_evaluated_response(test_client_factory: ClientFactoryProtocol) -> None:
# Test for https://github.com/encode/starlette/issues/1022
class CustomMiddleware(BaseHTTPMiddleware):
async def dispatch(
Expand Down Expand Up @@ -240,7 +238,7 @@ async def dispatch(
],
)
def test_contextvars(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
middleware_cls: Type[_MiddlewareClass[Any]],
) -> None:
# this has to be an async endpoint because Starlette calls run_in_threadpool
Expand Down Expand Up @@ -446,7 +444,7 @@ async def send(message: Message) -> None:


def test_app_receives_http_disconnect_while_sending_if_discarded(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
class DiscardingMiddleware(BaseHTTPMiddleware):
async def dispatch(
Expand Down Expand Up @@ -524,7 +522,7 @@ async def cancel_on_disconnect(


def test_app_receives_http_disconnect_after_sending_if_discarded(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
class DiscardingMiddleware(BaseHTTPMiddleware):
async def dispatch(
Expand Down Expand Up @@ -574,7 +572,7 @@ async def downstream_app(


def test_read_request_stream_in_app_after_middleware_calls_stream(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
async def homepage(request: Request) -> PlainTextResponse:
expected = [b""]
Expand Down Expand Up @@ -606,7 +604,7 @@ async def dispatch(


def test_read_request_stream_in_app_after_middleware_calls_body(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
async def homepage(request: Request) -> PlainTextResponse:
expected = [b"a", b""]
Expand Down Expand Up @@ -635,7 +633,7 @@ async def dispatch(


def test_read_request_body_in_app_after_middleware_calls_stream(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
async def homepage(request: Request) -> PlainTextResponse:
assert await request.body() == b""
Expand Down Expand Up @@ -664,7 +662,7 @@ async def dispatch(


def test_read_request_body_in_app_after_middleware_calls_body(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
async def homepage(request: Request) -> PlainTextResponse:
assert await request.body() == b"a"
Expand All @@ -690,7 +688,7 @@ async def dispatch(


def test_read_request_stream_in_dispatch_after_app_calls_stream(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
async def homepage(request: Request) -> PlainTextResponse:
expected = [b"a", b""]
Expand Down Expand Up @@ -722,7 +720,7 @@ async def dispatch(


def test_read_request_stream_in_dispatch_after_app_calls_body(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
async def homepage(request: Request) -> PlainTextResponse:
assert await request.body() == b"a"
Expand Down Expand Up @@ -808,7 +806,7 @@ async def send(msg: Message) -> None:


def test_read_request_stream_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next( # noqa: E501
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
async def homepage(request: Request) -> PlainTextResponse:
assert await request.body() == b"a"
Expand Down Expand Up @@ -840,7 +838,7 @@ async def dispatch(


def test_read_request_body_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next( # noqa: E501
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
async def homepage(request: Request) -> PlainTextResponse:
assert await request.body() == b"a"
Expand Down Expand Up @@ -956,7 +954,7 @@ async def send(msg: Message) -> None:


def test_downstream_middleware_modifies_receive(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
"""If a downstream middleware modifies receive() the final ASGI app
should see the modified version.
Expand Down
37 changes: 16 additions & 21 deletions tests/middleware/test_cors.py
@@ -1,19 +1,14 @@
from typing import Callable

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.routing import Route
from starlette.testclient import TestClient
from starlette.types import ASGIApp

TestClientFactory = Callable[[ASGIApp], TestClient]
from tests.types import ClientFactoryProtocol


def test_cors_allow_all(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
Expand Down Expand Up @@ -74,7 +69,7 @@ def homepage(request: Request) -> PlainTextResponse:


def test_cors_allow_all_except_credentials(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
Expand Down Expand Up @@ -125,7 +120,7 @@ def homepage(request: Request) -> PlainTextResponse:


def test_cors_allow_specific_origin(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
Expand Down Expand Up @@ -174,7 +169,7 @@ def homepage(request: Request) -> PlainTextResponse:


def test_cors_disallowed_preflight(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
def homepage(request: Request) -> None:
pass # pragma: no cover
Expand Down Expand Up @@ -215,7 +210,7 @@ def homepage(request: Request) -> None:


def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
def homepage(request: Request) -> None:
return # pragma: no cover
Expand Down Expand Up @@ -250,7 +245,7 @@ def homepage(request: Request) -> None:


def test_cors_preflight_allow_all_methods(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
def homepage(request: Request) -> None:
pass # pragma: no cover
Expand All @@ -276,7 +271,7 @@ def homepage(request: Request) -> None:


def test_cors_allow_all_methods(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
Expand Down Expand Up @@ -307,7 +302,7 @@ def homepage(request: Request) -> PlainTextResponse:


def test_cors_allow_origin_regex(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
Expand Down Expand Up @@ -379,7 +374,7 @@ def homepage(request: Request) -> PlainTextResponse:


def test_cors_allow_origin_regex_fullmatch(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
Expand Down Expand Up @@ -417,7 +412,7 @@ def homepage(request: Request) -> PlainTextResponse:


def test_cors_credentialed_requests_return_specific_origin(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
Expand All @@ -438,7 +433,7 @@ def homepage(request: Request) -> PlainTextResponse:


def test_cors_vary_header_defaults_to_origin(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
Expand All @@ -458,7 +453,7 @@ def homepage(request: Request) -> PlainTextResponse:


def test_cors_vary_header_is_not_set_for_non_credentialed_request(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse(
Expand All @@ -477,7 +472,7 @@ def homepage(request: Request) -> PlainTextResponse:


def test_cors_vary_header_is_properly_set_for_credentialed_request(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse(
Expand All @@ -498,7 +493,7 @@ def homepage(request: Request) -> PlainTextResponse:


def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse(
Expand All @@ -519,7 +514,7 @@ def homepage(request: Request) -> PlainTextResponse:


def test_cors_allowed_origin_does_not_leak_between_credentialed_requests(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)
Expand Down
18 changes: 8 additions & 10 deletions tests/middleware/test_errors.py
@@ -1,4 +1,4 @@
from typing import Any, Callable
from typing import Any

import pytest

Expand All @@ -8,14 +8,12 @@
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import Route
from starlette.testclient import TestClient
from starlette.types import Receive, Scope, Send

TestClientFactory = Callable[..., TestClient]
from tests.types import ClientFactoryProtocol


def test_handler(
test_client_factory: TestClientFactory,
test_client_factory: ClientFactoryProtocol,
) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
raise RuntimeError("Something went wrong")
Expand All @@ -30,7 +28,7 @@ def error_500(request: Request, exc: Exception) -> JSONResponse:
assert response.json() == {"detail": "Server Error"}


def test_debug_text(test_client_factory: TestClientFactory) -> None:
def test_debug_text(test_client_factory: ClientFactoryProtocol) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
raise RuntimeError("Something went wrong")

Expand All @@ -42,7 +40,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
assert "RuntimeError: Something went wrong" in response.text


def test_debug_html(test_client_factory: TestClientFactory) -> None:
def test_debug_html(test_client_factory: ClientFactoryProtocol) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
raise RuntimeError("Something went wrong")

Expand All @@ -54,7 +52,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
assert "RuntimeError" in response.text


def test_debug_after_response_sent(test_client_factory: TestClientFactory) -> None:
def test_debug_after_response_sent(test_client_factory: ClientFactoryProtocol) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = Response(b"", status_code=204)
await response(scope, receive, send)
Expand All @@ -66,7 +64,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
client.get("/")


def test_debug_not_http(test_client_factory: TestClientFactory) -> None:
def test_debug_not_http(test_client_factory: ClientFactoryProtocol) -> None:
"""
DebugMiddleware should just pass through any non-http messages as-is.
"""
Expand All @@ -82,7 +80,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
pass # pragma: nocover


def test_background_task(test_client_factory: TestClientFactory) -> None:
def test_background_task(test_client_factory: ClientFactoryProtocol) -> None:
accessed_error_handler = False

def error_handler(request: Request, exc: Exception) -> Any:
Expand Down