Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create types module inside tests #2502

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
7 changes: 3 additions & 4 deletions tests/conftest.py
@@ -1,13 +1,12 @@
from __future__ import annotations

import functools
from typing import Any, Callable, Literal
from typing import Any, Literal

import pytest

from starlette.testclient import TestClient

TestClientFactory = Callable[..., TestClient]
from tests.types import TestClientFactory


@pytest.fixture
Expand All @@ -21,4 +20,4 @@ def test_client_factory(
TestClient,
backend=anyio_backend_name,
backend_options=anyio_backend_options,
)
) # type: ignore
4 changes: 1 addition & 3 deletions tests/middleware/test_base.py
Expand Up @@ -5,7 +5,6 @@
from typing import (
Any,
AsyncGenerator,
Callable,
Generator,
)

Expand All @@ -23,8 +22,7 @@
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket

TestClientFactory = Callable[[ASGIApp], TestClient]
from tests.types import TestClientFactory


class CustomMiddleware(BaseHTTPMiddleware):
Expand Down
7 changes: 1 addition & 6 deletions tests/middleware/test_cors.py
@@ -1,15 +1,10 @@
from typing import Callable

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.routing import Route
from starlette.testclient import TestClient
from starlette.types import ASGIApp

TestClientFactory = Callable[[ASGIApp], TestClient]
from tests.types import TestClientFactory


def test_cors_allow_all(
Expand Down
6 changes: 2 additions & 4 deletions tests/middleware/test_errors.py
@@ -1,4 +1,4 @@
from typing import Any, Callable
from typing import Any

import pytest

Expand All @@ -8,10 +8,8 @@
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import Route
from starlette.testclient import TestClient
from starlette.types import Receive, Scope, Send

TestClientFactory = Callable[..., TestClient]
from tests.types import TestClientFactory


def test_handler(
Expand Down
11 changes: 4 additions & 7 deletions tests/middleware/test_gzip.py
@@ -1,15 +1,10 @@
from typing import Callable

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.requests import Request
from starlette.responses import ContentStream, PlainTextResponse, StreamingResponse
from starlette.routing import Route
from starlette.testclient import TestClient
from starlette.types import ASGIApp

TestClientFactory = Callable[[ASGIApp], TestClient]
from tests.types import TestClientFactory


def test_gzip_responses(test_client_factory: TestClientFactory) -> None:
Expand All @@ -29,7 +24,9 @@ def homepage(request: Request) -> PlainTextResponse:
assert int(response.headers["Content-Length"]) < 4000


def test_gzip_not_in_accept_encoding(test_client_factory: TestClientFactory) -> None:
def test_gzip_not_in_accept_encoding(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("x" * 4000, status_code=200)

Expand Down
6 changes: 1 addition & 5 deletions tests/middleware/test_https_redirect.py
@@ -1,14 +1,10 @@
from typing import Callable

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.routing import Route
from starlette.testclient import TestClient

TestClientFactory = Callable[..., TestClient]
from tests.types import TestClientFactory


def test_https_redirect_middleware(test_client_factory: TestClientFactory) -> None:
Expand Down
4 changes: 1 addition & 3 deletions tests/middleware/test_session.py
@@ -1,5 +1,4 @@
import re
from typing import Callable

from starlette.applications import Starlette
from starlette.middleware import Middleware
Expand All @@ -8,8 +7,7 @@
from starlette.responses import JSONResponse
from starlette.routing import Mount, Route
from starlette.testclient import TestClient

TestClientFactory = Callable[..., TestClient]
from tests.types import TestClientFactory


def view_session(request: Request) -> JSONResponse:
Expand Down
6 changes: 1 addition & 5 deletions tests/middleware/test_trusted_host.py
@@ -1,14 +1,10 @@
from typing import Callable

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.routing import Route
from starlette.testclient import TestClient

TestClientFactory = Callable[..., TestClient]
from tests.types import TestClientFactory


def test_trusted_host_middleware(test_client_factory: TestClientFactory) -> None:
Expand Down
3 changes: 1 addition & 2 deletions tests/middleware/test_wsgi.py
Expand Up @@ -5,10 +5,9 @@

from starlette._utils import collapse_excgroups
from starlette.middleware.wsgi import WSGIMiddleware, build_environ
from starlette.testclient import TestClient
from tests.types import TestClientFactory

WSGIResponse = Iterable[bytes]
TestClientFactory = Callable[..., TestClient]
StartResponse = Callable[..., Any]
Environment = Dict[str, Any]

Expand Down
9 changes: 5 additions & 4 deletions tests/test_applications.py
@@ -1,7 +1,7 @@
import os
from contextlib import asynccontextmanager
from pathlib import Path
from typing import AsyncGenerator, AsyncIterator, Callable, Generator
from typing import AsyncGenerator, AsyncIterator, Generator

import anyio
import pytest
Expand All @@ -20,8 +20,7 @@
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket

TestClientFactory = Callable[..., TestClient]
from tests.types import TestClientFactory


async def error_500(request: Request, exc: HTTPException) -> JSONResponse:
Expand Down Expand Up @@ -132,7 +131,9 @@ def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException) ->


@pytest.fixture
def client(test_client_factory: TestClientFactory) -> Generator[TestClient, None, None]:
def client(
test_client_factory: TestClientFactory,
) -> Generator[TestClient, None, None]:
with test_client_factory(app) as client:
yield client

Expand Down
3 changes: 1 addition & 2 deletions tests/test_authentication.py
Expand Up @@ -21,10 +21,9 @@
from starlette.requests import HTTPConnection, Request
from starlette.responses import JSONResponse, Response
from starlette.routing import Route, WebSocketRoute
from starlette.testclient import TestClient
from starlette.websockets import WebSocket, WebSocketDisconnect
from tests.types import TestClientFactory

TestClientFactory = Callable[..., TestClient]
AsyncEndpoint = Callable[..., Awaitable[Response]]
SyncEndpoint = Callable[..., Response]

Expand Down
6 changes: 1 addition & 5 deletions tests/test_background.py
@@ -1,13 +1,9 @@
from typing import Callable

import pytest

from starlette.background import BackgroundTask, BackgroundTasks
from starlette.responses import Response
from starlette.testclient import TestClient
from starlette.types import Receive, Scope, Send

TestClientFactory = Callable[..., TestClient]
from tests.types import TestClientFactory


def test_async_task(test_client_factory: TestClientFactory) -> None:
Expand Down
6 changes: 2 additions & 4 deletions tests/test_concurrency.py
@@ -1,5 +1,5 @@
from contextvars import ContextVar
from typing import Callable, Iterator
from typing import Iterator

import anyio
import pytest
Expand All @@ -9,9 +9,7 @@
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Route
from starlette.testclient import TestClient

TestClientFactory = Callable[..., TestClient]
from tests.types import TestClientFactory


@pytest.mark.anyio
Expand Down
6 changes: 2 additions & 4 deletions tests/test_convertors.py
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Callable, Iterator
from typing import Iterator

import pytest

Expand All @@ -8,9 +8,7 @@
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route, Router
from starlette.testclient import TestClient

TestClientFactory = Callable[..., TestClient]
from tests.types import TestClientFactory


@pytest.fixture(scope="module", autouse=True)
Expand Down
13 changes: 8 additions & 5 deletions tests/test_endpoints.py
@@ -1,4 +1,4 @@
from typing import Callable, Iterator
from typing import Iterator

import pytest

Expand All @@ -8,8 +8,7 @@
from starlette.routing import Route, Router
from starlette.testclient import TestClient
from starlette.websockets import WebSocket

TestClientFactory = Callable[..., TestClient]
from tests.types import TestClientFactory


class Homepage(HTTPEndpoint):
Expand Down Expand Up @@ -50,7 +49,9 @@ def test_http_endpoint_route_method(client: TestClient) -> None:
assert response.headers["allow"] == "GET"


def test_websocket_endpoint_on_connect(test_client_factory: TestClientFactory) -> None:
def test_websocket_endpoint_on_connect(
test_client_factory: TestClientFactory,
) -> None:
class WebSocketApp(WebSocketEndpoint):
async def on_connect(self, websocket: WebSocket) -> None:
assert websocket["subprotocols"] == ["soap", "wamp"]
Expand Down Expand Up @@ -137,7 +138,9 @@ async def on_receive(self, websocket: WebSocket, data: str) -> None:
websocket.send_bytes(b"Hello world")


def test_websocket_endpoint_on_default(test_client_factory: TestClientFactory) -> None:
def test_websocket_endpoint_on_default(
test_client_factory: TestClientFactory,
) -> None:
class WebSocketApp(WebSocketEndpoint):
encoding = None

Expand Down
9 changes: 5 additions & 4 deletions tests/test_exceptions.py
@@ -1,5 +1,5 @@
import warnings
from typing import Callable, Generator
from typing import Generator

import pytest

Expand All @@ -10,8 +10,7 @@
from starlette.routing import Route, Router, WebSocketRoute
from starlette.testclient import TestClient
from starlette.types import Receive, Scope, Send

TestClientFactory = Callable[..., TestClient]
from tests.types import TestClientFactory


def raise_runtime_error(request: Request) -> None:
Expand Down Expand Up @@ -82,7 +81,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:


@pytest.fixture
def client(test_client_factory: TestClientFactory) -> Generator[TestClient, None, None]:
def client(
test_client_factory: TestClientFactory,
) -> Generator[TestClient, None, None]:
with test_client_factory(app) as client:
yield client

Expand Down
4 changes: 1 addition & 3 deletions tests/test_formparsers.py
Expand Up @@ -13,10 +13,8 @@
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Mount
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Receive, Scope, Send

TestClientFactory = typing.Callable[..., TestClient]
from tests.types import TestClientFactory


class ForceMultipartDict(typing.Dict[typing.Any, typing.Any]):
Expand Down
10 changes: 5 additions & 5 deletions tests/test_requests.py
@@ -1,18 +1,16 @@
from __future__ import annotations

import sys
from typing import Any, Callable, Iterator
from typing import Any, Iterator

import anyio
import pytest

from starlette.datastructures import Address, State
from starlette.requests import ClientDisconnect, Request
from starlette.responses import JSONResponse, PlainTextResponse, Response
from starlette.testclient import TestClient
from starlette.types import Message, Receive, Scope, Send

TestClientFactory = Callable[..., TestClient]
from tests.types import TestClientFactory


def test_request_url(test_client_factory: TestClientFactory) -> None:
Expand Down Expand Up @@ -133,7 +131,9 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
assert response.json() == {"form": {"abc": "123 @"}}


def test_request_form_context_manager(test_client_factory: TestClientFactory) -> None:
def test_request_form_context_manager(
test_client_factory: TestClientFactory,
) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
async with request.form() as form:
Expand Down
9 changes: 5 additions & 4 deletions tests/test_responses.py
Expand Up @@ -5,7 +5,7 @@
import time
from http.cookies import SimpleCookie
from pathlib import Path
from typing import AsyncIterator, Callable, Iterator
from typing import AsyncIterator, Iterator

import anyio
import pytest
Expand All @@ -23,8 +23,7 @@
)
from starlette.testclient import TestClient
from starlette.types import Message, Receive, Scope, Send

TestClientFactory = Callable[..., TestClient]
from tests.types import TestClientFactory


def test_text_response(test_client_factory: TestClientFactory) -> None:
Expand Down Expand Up @@ -532,7 +531,9 @@ def test_streaming_response_unknown_size(
assert "content-length" not in response.headers


def test_streaming_response_known_size(test_client_factory: TestClientFactory) -> None:
def test_streaming_response_known_size(
test_client_factory: TestClientFactory,
) -> None:
app = StreamingResponse(
content=iter(["hello", "world"]), headers={"content-length": "10"}
)
Expand Down