From 275b3e3f229ae223e7731030648e3971bd2fdc4a Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 5 Oct 2022 22:34:37 +0200 Subject: [PATCH] Replace deprecation by removal --- setup.cfg | 2 - starlette/applications.py | 160 ------------- starlette/middleware/base.py | 122 ---------- starlette/routing.py | 116 ---------- starlette/testclient.py | 88 +++---- tests/middleware/test_base.py | 415 ---------------------------------- tests/test_responses.py | 2 +- tests/test_routing.py | 18 -- tests/test_staticfiles.py | 25 -- 9 files changed, 33 insertions(+), 915 deletions(-) delete mode 100644 starlette/middleware/base.py delete mode 100644 tests/middleware/test_base.py diff --git a/setup.cfg b/setup.cfg index c52cbed097..840eb55450 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,9 +32,7 @@ filterwarnings= ignore: Async generator 'starlette\.requests\.Request\.stream' was garbage collected before it had been exhausted.*:ResourceWarning ignore: path is deprecated.*:DeprecationWarning:certifi ignore: Use 'content=<...>' to upload raw bytes/text content.:DeprecationWarning - ignore: The `allow_redirects` argument is deprecated. Use `follow_redirects` instead.:DeprecationWarning ignore: 'cgi' is deprecated and slated for removal in Python 3\.13:DeprecationWarning - ignore: The 'BaseHTTPMiddleware' is deprecated, and will be removed in version 2\.0\.0.*:DeprecationWarning [coverage:run] source_pkgs = starlette, tests diff --git a/starlette/applications.py b/starlette/applications.py index c3daade5cf..5a8f69a5ab 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -2,7 +2,6 @@ from starlette.datastructures import State, URLPath from starlette.middleware import Middleware -from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.errors import ServerErrorMiddleware from starlette.middleware.exceptions import ExceptionMiddleware from starlette.requests import Request @@ -122,162 +121,3 @@ def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: scope["app"] = self await self.middleware_stack(scope, receive, send) - - # The following usages are now discouraged in favour of configuration - # during Starlette.__init__(...) - def on_event(self, event_type: str) -> typing.Callable: # pragma: nocover - return self.router.on_event(event_type) - - def mount( - self, path: str, app: ASGIApp, name: typing.Optional[str] = None - ) -> None: # pragma: nocover - """ - We no longer document this API, and its usage is discouraged. - Instead you should use the following approach: - - routes = [ - Mount(path, ...), - ... - ] - - app = Starlette(routes=routes) - """ - - self.router.mount(path, app=app, name=name) - - def host( - self, host: str, app: ASGIApp, name: typing.Optional[str] = None - ) -> None: # pragma: no cover - """ - We no longer document this API, and its usage is discouraged. - Instead you should use the following approach: - - routes = [ - Host(path, ...), - ... - ] - - app = Starlette(routes=routes) - """ - - self.router.host(host, app=app, name=name) - - def add_middleware( - self, middleware_class: type, **options: typing.Any - ) -> None: # pragma: no cover - self.user_middleware.insert(0, Middleware(middleware_class, **options)) - self.middleware_stack = self.build_middleware_stack() - - def add_exception_handler( - self, - exc_class_or_status_code: typing.Union[int, typing.Type[Exception]], - handler: typing.Callable, - ) -> None: # pragma: no cover - self.exception_handlers[exc_class_or_status_code] = handler - self.middleware_stack = self.build_middleware_stack() - - def add_event_handler( - self, event_type: str, func: typing.Callable - ) -> None: # pragma: no cover - self.router.add_event_handler(event_type, func) - - def add_route( - self, - path: str, - route: typing.Callable, - methods: typing.Optional[typing.List[str]] = None, - name: typing.Optional[str] = None, - include_in_schema: bool = True, - ) -> None: # pragma: no cover - self.router.add_route( - path, route, methods=methods, name=name, include_in_schema=include_in_schema - ) - - def add_websocket_route( - self, path: str, route: typing.Callable, name: typing.Optional[str] = None - ) -> None: # pragma: no cover - self.router.add_websocket_route(path, route, name=name) - - def exception_handler( - self, exc_class_or_status_code: typing.Union[int, typing.Type[Exception]] - ) -> typing.Callable: # pragma: nocover - def decorator(func: typing.Callable) -> typing.Callable: - self.add_exception_handler(exc_class_or_status_code, func) - return func - - return decorator - - def route( - self, - path: str, - methods: typing.Optional[typing.List[str]] = None, - name: typing.Optional[str] = None, - include_in_schema: bool = True, - ) -> typing.Callable: # pragma: nocover - """ - We no longer document this decorator style API, and its usage is discouraged. - Instead you should use the following approach: - - routes = [ - Route(path, endpoint=..., ...), - ... - ] - - app = Starlette(routes=routes) - """ - - def decorator(func: typing.Callable) -> typing.Callable: - self.router.add_route( - path, - func, - methods=methods, - name=name, - include_in_schema=include_in_schema, - ) - return func - - return decorator - - def websocket_route( - self, path: str, name: typing.Optional[str] = None - ) -> typing.Callable: # pragma: nocover - """ - We no longer document this decorator style API, and its usage is discouraged. - Instead you should use the following approach: - - routes = [ - WebSocketRoute(path, endpoint=..., ...), - ... - ] - - app = Starlette(routes=routes) - """ - - def decorator(func: typing.Callable) -> typing.Callable: - self.router.add_websocket_route(path, func, name=name) - return func - - return decorator - - def middleware(self, middleware_type: str) -> typing.Callable: # pragma: nocover - """ - We no longer document this decorator style API, and its usage is discouraged. - Instead you should use the following approach: - - middleware = [ - Middleware(...), - ... - ] - - app = Starlette(middleware=middleware) - """ - - assert ( - middleware_type == "http" - ), 'Currently only middleware("http") is supported.' - - def decorator(func: typing.Callable) -> typing.Callable: - self.add_middleware(BaseHTTPMiddleware, dispatch=func) - return func - - return decorator diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py deleted file mode 100644 index bc71aecafe..0000000000 --- a/starlette/middleware/base.py +++ /dev/null @@ -1,122 +0,0 @@ -import typing -import warnings - -import anyio - -from starlette.requests import Request -from starlette.responses import Response, StreamingResponse -from starlette.types import ASGIApp, Message, Receive, Scope, Send - -RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] -DispatchFunction = typing.Callable[ - [Request, RequestResponseEndpoint], typing.Awaitable[Response] -] -T = typing.TypeVar("T") - -warnings.warn( - "The 'BaseHTTPMiddleware' is deprecated, and will be removed in version 2.0.0. " - "Refer to https://www.starlette.io/middleware/#pure-asgi-middleware to learn " - "how to create middlewares.\nIf you need help, please create a discussion on: " - "https://github.com/encode/starlette/discussions.", - DeprecationWarning, -) - - -class BaseHTTPMiddleware: - def __init__( - self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None - ) -> None: - self.app = app - self.dispatch_func = self.dispatch if dispatch is None else dispatch - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - if scope["type"] != "http": - await self.app(scope, receive, send) - return - - response_sent = anyio.Event() - - async def call_next(request: Request) -> Response: - app_exc: typing.Optional[Exception] = None - send_stream, recv_stream = anyio.create_memory_object_stream() - - async def receive_or_disconnect() -> Message: - if response_sent.is_set(): - return {"type": "http.disconnect"} - - async with anyio.create_task_group() as task_group: - - async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T: - result = await func() - task_group.cancel_scope.cancel() - return result - - task_group.start_soon(wrap, response_sent.wait) - message = await wrap(request.receive) - - if response_sent.is_set(): - return {"type": "http.disconnect"} - - return message - - async def close_recv_stream_on_response_sent() -> None: - await response_sent.wait() - recv_stream.close() - - async def send_no_error(message: Message) -> None: - try: - await send_stream.send(message) - except anyio.BrokenResourceError: - # recv_stream has been closed, i.e. response_sent has been set. - return - - async def coro() -> None: - nonlocal app_exc - - async with send_stream: - try: - await self.app(scope, receive_or_disconnect, send_no_error) - except Exception as exc: - app_exc = exc - - task_group.start_soon(close_recv_stream_on_response_sent) - task_group.start_soon(coro) - - try: - message = await recv_stream.receive() - except anyio.EndOfStream: - if app_exc is not None: - raise app_exc - raise RuntimeError("No response returned.") - - assert message["type"] == "http.response.start" - - async def body_stream() -> typing.AsyncGenerator[bytes, None]: - async with recv_stream: - async for message in recv_stream: - assert message["type"] == "http.response.body" - body = message.get("body", b"") - if body: - yield body - if not message.get("more_body", False): - break - - if app_exc is not None: - raise app_exc - - response = StreamingResponse( - status_code=message["status"], content=body_stream() - ) - response.raw_headers = message["headers"] - return response - - async with anyio.create_task_group() as task_group: - request = Request(scope, receive=receive) - response = await self.dispatch_func(request, call_next) - await response(scope, receive, send) - response_sent.set() - - async def dispatch( - self, request: Request, call_next: RequestResponseEndpoint - ) -> Response: - raise NotImplementedError() # pragma: no cover diff --git a/starlette/routing.py b/starlette/routing.py index 54c8f9a7e2..4186c014e5 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -721,119 +721,3 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: def __eq__(self, other: typing.Any) -> bool: return isinstance(other, Router) and self.routes == other.routes - - # The following usages are now discouraged in favour of configuration - #  during Router.__init__(...) - def mount( - self, path: str, app: ASGIApp, name: typing.Optional[str] = None - ) -> None: # pragma: nocover - warnings.warn( - "The 'mount' method is now deprecated, and will be removed in version " - "2.0.0. Refer to https://www.starlette.io/routing/#submounting-routes " - "for recommended approach.", - DeprecationWarning, - ) - route = Mount(path, app=app, name=name) - self.routes.append(route) - - def host( - self, host: str, app: ASGIApp, name: typing.Optional[str] = None - ) -> None: # pragma: no cover - warnings.warn( - "The 'host' method is deprecated, and will be removed in version 2.0.0." - "Refer to https://www.starlette.io/routing/#host-based-routing for the " - "recommended approach.", - DeprecationWarning, - ) - route = Host(host, app=app, name=name) - self.routes.append(route) - - def add_route( - self, - path: str, - endpoint: typing.Callable, - methods: typing.Optional[typing.List[str]] = None, - name: typing.Optional[str] = None, - include_in_schema: bool = True, - ) -> None: # pragma: nocover - route = Route( - path, - endpoint=endpoint, - methods=methods, - name=name, - include_in_schema=include_in_schema, - ) - self.routes.append(route) - - def add_websocket_route( - self, path: str, endpoint: typing.Callable, name: typing.Optional[str] = None - ) -> None: # pragma: no cover - route = WebSocketRoute(path, endpoint=endpoint, name=name) - self.routes.append(route) - - def route( - self, - path: str, - methods: typing.Optional[typing.List[str]] = None, - name: typing.Optional[str] = None, - include_in_schema: bool = True, - ) -> typing.Callable: # pragma: nocover - warnings.warn( - "The `route` decorator is deprecated, and will be removed in version 2.0.0." - "Refer to https://www.starlette.io/routing/#http-routing for the " - "recommended approach.", - DeprecationWarning, - ) - - def decorator(func: typing.Callable) -> typing.Callable: - self.add_route( - path, - func, - methods=methods, - name=name, - include_in_schema=include_in_schema, - ) - return func - - return decorator - - def websocket_route( - self, path: str, name: typing.Optional[str] = None - ) -> typing.Callable: # pragma: nocover - warnings.warn( - "The `websocket_route` decorator is deprecated, and will be removed in " - "version 2.0.0. Refer to " - "https://www.starlette.io/routing/#websocket-routing for the recommended " - "approach.", - DeprecationWarning, - ) - - def decorator(func: typing.Callable) -> typing.Callable: - self.add_websocket_route(path, func, name=name) - return func - - return decorator - - def add_event_handler( - self, event_type: str, func: typing.Callable - ) -> None: # pragma: no cover - assert event_type in ("startup", "shutdown") - - if event_type == "startup": - self.on_startup.append(func) - else: - self.on_shutdown.append(func) - - def on_event(self, event_type: str) -> typing.Callable: # pragma: nocover - warnings.warn( - "The `on_event` decorator is deprecated, and will be removed in version " - "2.0.0. Refer to https://www.starlette.io/events/#registering-events for " - "recommended approach.", - DeprecationWarning, - ) - - def decorator(func: typing.Callable) -> typing.Callable: - self.add_event_handler(event_type, func) - return func - - return decorator diff --git a/starlette/testclient.py b/starlette/testclient.py index 978f8f506d..558e1c7809 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -6,7 +6,6 @@ import queue import sys import typing -import warnings from concurrent.futures import Future from types import GeneratorType from urllib.parse import unquote, urljoin @@ -401,29 +400,6 @@ def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, No with anyio.start_blocking_portal(**self.async_backend) as portal: yield portal - def _choose_redirect_arg( - self, - follow_redirects: typing.Optional[bool], - allow_redirects: typing.Optional[bool], - ) -> typing.Union[bool, httpx._client.UseClientDefault]: - redirect: typing.Union[ - bool, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT - if allow_redirects is not None: - message = ( - "The `allow_redirects` argument is deprecated. " - "Use `follow_redirects` instead." - ) - warnings.warn(message, DeprecationWarning) - redirect = allow_redirects - if follow_redirects is not None: - redirect = follow_redirects - elif allow_redirects is not None and follow_redirects is not None: - raise RuntimeError( # pragma: no cover - "Cannot use both `allow_redirects` and `follow_redirects`." - ) - return redirect - def request( # type: ignore[override] self, method: str, @@ -439,15 +415,15 @@ def request( # type: ignore[override] auth: typing.Union[ httpx._types.AuthTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, - follow_redirects: typing.Optional[bool] = None, - allow_redirects: typing.Optional[bool] = None, + follow_redirects: typing.Union[ + bool, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, timeout: typing.Union[ httpx._client.TimeoutTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> httpx.Response: url = self.base_url.join(url) - redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) return super().request( method, url, @@ -459,7 +435,7 @@ def request( # type: ignore[override] headers=headers, cookies=cookies, auth=auth, - follow_redirects=redirect, + follow_redirects=follow_redirects, timeout=timeout, extensions=extensions, ) @@ -474,21 +450,21 @@ def get( # type: ignore[override] auth: typing.Union[ httpx._types.AuthTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, - follow_redirects: typing.Optional[bool] = None, - allow_redirects: typing.Optional[bool] = None, + follow_redirects: typing.Union[ + bool, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, timeout: typing.Union[ httpx._client.TimeoutTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> httpx.Response: - redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) return super().get( url, params=params, headers=headers, cookies=cookies, auth=auth, - follow_redirects=redirect, + follow_redirects=follow_redirects, timeout=timeout, extensions=extensions, ) @@ -503,21 +479,21 @@ def options( # type: ignore[override] auth: typing.Union[ httpx._types.AuthTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, - follow_redirects: typing.Optional[bool] = None, - allow_redirects: typing.Optional[bool] = None, + follow_redirects: typing.Union[ + bool, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, timeout: typing.Union[ httpx._client.TimeoutTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> httpx.Response: - redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) return super().options( url, params=params, headers=headers, cookies=cookies, auth=auth, - follow_redirects=redirect, + follow_redirects=follow_redirects, timeout=timeout, extensions=extensions, ) @@ -532,21 +508,21 @@ def head( # type: ignore[override] auth: typing.Union[ httpx._types.AuthTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, - follow_redirects: typing.Optional[bool] = None, - allow_redirects: typing.Optional[bool] = None, + follow_redirects: typing.Union[ + bool, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, timeout: typing.Union[ httpx._client.TimeoutTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> httpx.Response: - redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) return super().head( url, params=params, headers=headers, cookies=cookies, auth=auth, - follow_redirects=redirect, + follow_redirects=follow_redirects, timeout=timeout, extensions=extensions, ) @@ -565,14 +541,14 @@ def post( # type: ignore[override] auth: typing.Union[ httpx._types.AuthTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, - follow_redirects: typing.Optional[bool] = None, - allow_redirects: typing.Optional[bool] = None, + follow_redirects: typing.Union[ + bool, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, timeout: typing.Union[ httpx._client.TimeoutTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> httpx.Response: - redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) return super().post( url, content=content, @@ -583,7 +559,7 @@ def post( # type: ignore[override] headers=headers, cookies=cookies, auth=auth, - follow_redirects=redirect, + follow_redirects=follow_redirects, timeout=timeout, extensions=extensions, ) @@ -602,14 +578,14 @@ def put( # type: ignore[override] auth: typing.Union[ httpx._types.AuthTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, - follow_redirects: typing.Optional[bool] = None, - allow_redirects: typing.Optional[bool] = None, + follow_redirects: typing.Union[ + bool, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, timeout: typing.Union[ httpx._client.TimeoutTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> httpx.Response: - redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) return super().put( url, content=content, @@ -620,7 +596,7 @@ def put( # type: ignore[override] headers=headers, cookies=cookies, auth=auth, - follow_redirects=redirect, + follow_redirects=follow_redirects, timeout=timeout, extensions=extensions, ) @@ -639,14 +615,14 @@ def patch( # type: ignore[override] auth: typing.Union[ httpx._types.AuthTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, - follow_redirects: typing.Optional[bool] = None, - allow_redirects: typing.Optional[bool] = None, + follow_redirects: typing.Union[ + bool, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, timeout: typing.Union[ httpx._client.TimeoutTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> httpx.Response: - redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) return super().patch( url, content=content, @@ -657,7 +633,7 @@ def patch( # type: ignore[override] headers=headers, cookies=cookies, auth=auth, - follow_redirects=redirect, + follow_redirects=follow_redirects, timeout=timeout, extensions=extensions, ) @@ -672,21 +648,21 @@ def delete( # type: ignore[override] auth: typing.Union[ httpx._types.AuthTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, - follow_redirects: typing.Optional[bool] = None, - allow_redirects: typing.Optional[bool] = None, + follow_redirects: typing.Union[ + bool, httpx._client.UseClientDefault + ] = httpx._client.USE_CLIENT_DEFAULT, timeout: typing.Union[ httpx._client.TimeoutTypes, httpx._client.UseClientDefault ] = httpx._client.USE_CLIENT_DEFAULT, extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> httpx.Response: - redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) return super().delete( url, params=params, headers=headers, cookies=cookies, auth=auth, - follow_redirects=redirect, + follow_redirects=follow_redirects, timeout=timeout, extensions=extensions, ) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py deleted file mode 100644 index ed0734bd38..0000000000 --- a/tests/middleware/test_base.py +++ /dev/null @@ -1,415 +0,0 @@ -import contextvars -from contextlib import AsyncExitStack - -import anyio -import pytest - -from starlette.applications import Starlette -from starlette.background import BackgroundTask -from starlette.middleware import Middleware -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.responses import PlainTextResponse, StreamingResponse -from starlette.routing import Route, WebSocketRoute -from starlette.types import ASGIApp, Receive, Scope, Send - - -class CustomMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request, call_next): - response = await call_next(request) - response.headers["Custom-Header"] = "Example" - return response - - -def homepage(request): - return PlainTextResponse("Homepage") - - -def exc(request): - raise Exception("Exc") - - -def exc_stream(request): - return StreamingResponse(_generate_faulty_stream()) - - -def _generate_faulty_stream(): - yield b"Ok" - raise Exception("Faulty Stream") - - -class NoResponse: - def __init__(self, scope, receive, send): - pass - - def __await__(self): - return self.dispatch().__await__() - - async def dispatch(self): - pass - - -async def websocket_endpoint(session): - await session.accept() - await session.send_text("Hello, world!") - await session.close() - - -app = Starlette( - routes=[ - Route("/", endpoint=homepage), - Route("/exc", endpoint=exc), - Route("/exc-stream", endpoint=exc_stream), - Route("/no-response", endpoint=NoResponse), - WebSocketRoute("/ws", endpoint=websocket_endpoint), - ], - middleware=[Middleware(CustomMiddleware)], -) - - -def test_custom_middleware(test_client_factory): - client = test_client_factory(app) - response = client.get("/") - assert response.headers["Custom-Header"] == "Example" - - with pytest.raises(Exception) as ctx: - response = client.get("/exc") - assert str(ctx.value) == "Exc" - - with pytest.raises(Exception) as ctx: - response = client.get("/exc-stream") - assert str(ctx.value) == "Faulty Stream" - - with pytest.raises(RuntimeError): - response = client.get("/no-response") - - with client.websocket_connect("/ws") as session: - text = session.receive_text() - assert text == "Hello, world!" - - -def test_state_data_across_multiple_middlewares(test_client_factory): - expected_value1 = "foo" - expected_value2 = "bar" - - class aMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request, call_next): - request.state.foo = expected_value1 - response = await call_next(request) - return response - - class bMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request, call_next): - request.state.bar = expected_value2 - response = await call_next(request) - response.headers["X-State-Foo"] = request.state.foo - return response - - class cMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request, call_next): - response = await call_next(request) - response.headers["X-State-Bar"] = request.state.bar - return response - - def homepage(request): - return PlainTextResponse("OK") - - app = Starlette( - routes=[Route("/", homepage)], - middleware=[ - Middleware(aMiddleware), - Middleware(bMiddleware), - Middleware(cMiddleware), - ], - ) - - client = test_client_factory(app) - response = client.get("/") - assert response.text == "OK" - assert response.headers["X-State-Foo"] == expected_value1 - assert response.headers["X-State-Bar"] == expected_value2 - - -def test_app_middleware_argument(test_client_factory): - def homepage(request): - return PlainTextResponse("Homepage") - - app = Starlette( - routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)] - ) - - client = test_client_factory(app) - response = client.get("/") - assert response.headers["Custom-Header"] == "Example" - - -def test_fully_evaluated_response(test_client_factory): - # Test for https://github.com/encode/starlette/issues/1022 - class CustomMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request, call_next): - await call_next(request) - return PlainTextResponse("Custom") - - app = Starlette(middleware=[Middleware(CustomMiddleware)]) - - client = test_client_factory(app) - response = client.get("/does_not_exist") - assert response.text == "Custom" - - -ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar") - - -class CustomMiddlewareWithoutBaseHTTPMiddleware: - def __init__(self, app: ASGIApp) -> None: - self.app = app - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - ctxvar.set("set by middleware") - await self.app(scope, receive, send) - assert ctxvar.get() == "set by endpoint" - - -class CustomMiddlewareUsingBaseHTTPMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request, call_next): - ctxvar.set("set by middleware") - resp = await call_next(request) - assert ctxvar.get() == "set by endpoint" - return resp # pragma: no cover - - -@pytest.mark.parametrize( - "middleware_cls", - [ - CustomMiddlewareWithoutBaseHTTPMiddleware, - pytest.param( - CustomMiddlewareUsingBaseHTTPMiddleware, - marks=pytest.mark.xfail( - reason=( - "BaseHTTPMiddleware creates a TaskGroup which copies the context" - "and erases any changes to it made within the TaskGroup" - ), - raises=AssertionError, - ), - ), - ], -) -def test_contextvars(test_client_factory, middleware_cls: type): - # this has to be an async endpoint because Starlette calls run_in_threadpool - # on sync endpoints which has it's own set of peculiarities w.r.t propagating - # contextvars (it propagates them forwards but not backwards) - async def homepage(request): - assert ctxvar.get() == "set by middleware" - ctxvar.set("set by endpoint") - return PlainTextResponse("Homepage") - - app = Starlette( - middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)] - ) - - client = test_client_factory(app) - response = client.get("/") - assert response.status_code == 200, response.content - - -@pytest.mark.anyio -async def test_run_background_tasks_even_if_client_disconnects(): - # test for https://github.com/encode/starlette/issues/1438 - request_body_sent = False - response_complete = anyio.Event() - background_task_run = anyio.Event() - - async def sleep_and_set(): - # small delay to give BaseHTTPMiddleware a chance to cancel us - # this is required to make the test fail prior to fixing the issue - # so do not be surprised if you remove it and the test still passes - await anyio.sleep(0.1) - background_task_run.set() - - async def endpoint_with_background_task(_): - return PlainTextResponse(background=BackgroundTask(sleep_and_set)) - - async def passthrough(request, call_next): - return await call_next(request) - - app = Starlette( - middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)], - routes=[Route("/", endpoint_with_background_task)], - ) - - scope = { - "type": "http", - "version": "3", - "method": "GET", - "path": "/", - } - - async def receive(): - nonlocal request_body_sent - if not request_body_sent: - request_body_sent = True - return {"type": "http.request", "body": b"", "more_body": False} - # We simulate a client that disconnects immediately after receiving the response - await response_complete.wait() - return {"type": "http.disconnect"} - - async def send(message): - if message["type"] == "http.response.body": - if not message.get("more_body", False): - response_complete.set() - - await app(scope, receive, send) - - assert background_task_run.is_set() - - -@pytest.mark.anyio -async def test_run_context_manager_exit_even_if_client_disconnects(): - # test for https://github.com/encode/starlette/issues/1678#issuecomment-1172916042 - request_body_sent = False - response_complete = anyio.Event() - context_manager_exited = anyio.Event() - - async def sleep_and_set(): - # small delay to give BaseHTTPMiddleware a chance to cancel us - # this is required to make the test fail prior to fixing the issue - # so do not be surprised if you remove it and the test still passes - await anyio.sleep(0.1) - context_manager_exited.set() - - class ContextManagerMiddleware: - def __init__(self, app): - self.app = app - - async def __call__(self, scope: Scope, receive: Receive, send: Send): - async with AsyncExitStack() as stack: - stack.push_async_callback(sleep_and_set) - await self.app(scope, receive, send) - - async def simple_endpoint(_): - return PlainTextResponse(background=BackgroundTask(sleep_and_set)) - - async def passthrough(request, call_next): - return await call_next(request) - - app = Starlette( - middleware=[ - Middleware(BaseHTTPMiddleware, dispatch=passthrough), - Middleware(ContextManagerMiddleware), - ], - routes=[Route("/", simple_endpoint)], - ) - - scope = { - "type": "http", - "version": "3", - "method": "GET", - "path": "/", - } - - async def receive(): - nonlocal request_body_sent - if not request_body_sent: - request_body_sent = True - return {"type": "http.request", "body": b"", "more_body": False} - # We simulate a client that disconnects immediately after receiving the response - await response_complete.wait() - return {"type": "http.disconnect"} - - async def send(message): - if message["type"] == "http.response.body": - if not message.get("more_body", False): - response_complete.set() - - await app(scope, receive, send) - - assert context_manager_exited.is_set() - - -def test_app_receives_http_disconnect_while_sending_if_discarded(test_client_factory): - class DiscardingMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request, call_next): - await call_next(request) - return PlainTextResponse("Custom") - - async def downstream_app(scope, receive, send): - await send( - { - "type": "http.response.start", - "status": 200, - "headers": [ - (b"content-type", b"text/plain"), - ], - } - ) - async with anyio.create_task_group() as task_group: - - async def cancel_on_disconnect(): - while True: - message = await receive() - if message["type"] == "http.disconnect": - task_group.cancel_scope.cancel() - break - - task_group.start_soon(cancel_on_disconnect) - - # A timeout is set for 0.1 second in order to ensure that - # cancel_on_disconnect is scheduled by the event loop - with anyio.move_on_after(0.1): - while True: - await send( - { - "type": "http.response.body", - "body": b"chunk ", - "more_body": True, - } - ) - - pytest.fail( - "http.disconnect should have been received and canceled the scope" - ) # pragma: no cover - - app = DiscardingMiddleware(downstream_app) - - client = test_client_factory(app) - response = client.get("/does_not_exist") - assert response.text == "Custom" - - -def test_app_receives_http_disconnect_after_sending_if_discarded(test_client_factory): - class DiscardingMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request, call_next): - await call_next(request) - return PlainTextResponse("Custom") - - async def downstream_app(scope, receive, send): - await send( - { - "type": "http.response.start", - "status": 200, - "headers": [ - (b"content-type", b"text/plain"), - ], - } - ) - await send( - { - "type": "http.response.body", - "body": b"first chunk, ", - "more_body": True, - } - ) - await send( - { - "type": "http.response.body", - "body": b"second chunk", - "more_body": True, - } - ) - message = await receive() - assert message["type"] == "http.disconnect" - - app = DiscardingMiddleware(downstream_app) - - client = test_client_factory(app) - response = client.get("/does_not_exist") - assert response.text == "Custom" diff --git a/tests/test_responses.py b/tests/test_responses.py index 608842da2e..b9105f8e65 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -84,7 +84,7 @@ async def app(scope, receive, send): await response(scope, receive, send) client: TestClient = test_client_factory(app) - response = client.request("GET", "/redirect", allow_redirects=False) + response = client.request("GET", "/redirect", follow_redirects=False) assert response.url == "http://testserver/redirect" assert response.headers["content-length"] == "0" diff --git a/tests/test_routing.py b/tests/test_routing.py index e2c2eb2c4b..cd4af37022 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -864,24 +864,6 @@ def test_mount_asgi_app_with_middleware_url_path_for() -> None: mounted_app_with_middleware.url_path_for("route") -def test_add_route_to_app_after_mount( - test_client_factory: typing.Callable[..., TestClient], -) -> None: - """Checks that Mount will pick up routes - added to the underlying app after it is mounted - """ - inner_app = Router() - app = Mount("/http", app=inner_app) - inner_app.add_route( - "/inner", - endpoint=homepage, - methods=["GET"], - ) - client = test_client_factory(app) - response = client.get("/http/inner") - assert response.status_code == 200 - - def test_exception_on_mounted_apps(test_client_factory): def exc(request): raise Exception("Exc") diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index 142c2a00b5..c3f74b468b 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -8,9 +8,6 @@ from starlette.applications import Starlette from starlette.exceptions import HTTPException -from starlette.middleware import Middleware -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.requests import Request from starlette.routing import Mount from starlette.staticfiles import StaticFiles @@ -40,28 +37,6 @@ def test_staticfiles_with_pathlib(tmpdir, test_client_factory): assert response.text == "" -def test_staticfiles_head_with_middleware(tmpdir, test_client_factory): - """ - see https://github.com/encode/starlette/pull/935 - """ - path = os.path.join(tmpdir, "example.txt") - with open(path, "w") as file: - file.write("x" * 100) - - async def does_nothing_middleware(request: Request, call_next): - response = await call_next(request) - return response - - routes = [Mount("/static", app=StaticFiles(directory=tmpdir), name="static")] - middleware = [Middleware(BaseHTTPMiddleware, dispatch=does_nothing_middleware)] - app = Starlette(routes=routes, middleware=middleware) - - client = test_client_factory(app) - response = client.head("/static/example.txt") - assert response.status_code == 200 - assert response.headers.get("content-length") == "100" - - def test_staticfiles_with_package(test_client_factory): app = StaticFiles(packages=["tests"]) client = test_client_factory(app)