diff --git a/docs/middleware.md b/docs/middleware.md index 5d6a32f3a..e42ada493 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -686,6 +686,41 @@ to use the `middleware=` style, as it will: * Ensure that everything remains wrapped in a single outermost `ServerErrorMiddleware`. * Preserves the top-level `app` instance. +## Applying middleware to `Mount`s + +Middleware can also be added to `Mount`, which allows you to apply middleware to a single route, a group of routes or any mounted ASGI application: + +```python +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.gzip import GZipMiddleware +from starlette.routing import Mount, Route + + +routes = [ + Mount( + "/", + routes=[ + Route( + "/example", + endpoint=..., + ) + ], + middleware=[Middleware(GZipMiddleware)] + ) +] + +app = Starlette(routes=routes) +``` + +Note that middleware used in this way is *not* wrapped in exception handling middleware like the middleware applied to the `Starlette` application is. +This is often not a problem because it only applies to middleware that inspect or modify the `Response`, and even then you probably don't want to apply this logic to error responses. +If you do want to apply the middleware logic to error responses only on some routes you have a couple of options: + +* Add an `ExceptionMiddleware` onto the `Mount` +* Add a `try/except` block to your middleware and return an error response from there +* Split up marking and processing into two middlewares, one that gets put on `Mount` which marks the response as needing processing (for example by setting `scope["log-response"] = True`) and another applied to the `Starlette` application that does the heavy lifting. + ## Third party middleware #### [asgi-auth-github](https://github.com/simonw/asgi-auth-github) diff --git a/starlette/routing.py b/starlette/routing.py index 1aa2cdb6d..2c6965be0 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -14,6 +14,7 @@ from starlette.convertors import CONVERTOR_TYPES, Convertor from starlette.datastructures import URL, Headers, URLPath from starlette.exceptions import HTTPException +from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import PlainTextResponse, RedirectResponse from starlette.types import ASGIApp, Receive, Scope, Send @@ -348,6 +349,8 @@ def __init__( app: typing.Optional[ASGIApp] = None, routes: typing.Optional[typing.Sequence[BaseRoute]] = None, name: typing.Optional[str] = None, + *, + middleware: typing.Optional[typing.Sequence[Middleware]] = None, ) -> None: assert path == "" or path.startswith("/"), "Routed paths must start with '/'" assert ( @@ -355,9 +358,13 @@ def __init__( ), "Either 'app=...', or 'routes=' must be specified" self.path = path.rstrip("/") if app is not None: - self.app: ASGIApp = app + self._base_app: ASGIApp = app else: - self.app = Router(routes=routes) + self._base_app = Router(routes=routes) + self.app = self._base_app + if middleware is not None: + for cls, options in reversed(middleware): + self.app = cls(app=self.app, **options) self.name = name self.path_regex, self.path_format, self.param_convertors = compile_path( self.path + "/{path:path}" @@ -365,7 +372,7 @@ def __init__( @property def routes(self) -> typing.List[BaseRoute]: - return getattr(self.app, "routes", []) + return getattr(self._base_app, "routes", []) def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: if scope["type"] in ("http", "websocket"): diff --git a/tests/test_routing.py b/tests/test_routing.py index e3b1e412a..750f32496 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -5,8 +5,13 @@ import pytest from starlette.applications import Starlette +from starlette.exceptions import HTTPException +from starlette.middleware import Middleware +from starlette.requests import Request from starlette.responses import JSONResponse, PlainTextResponse, Response from starlette.routing import Host, Mount, NoMatchFound, Route, Router, WebSocketRoute +from starlette.testclient import TestClient +from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketDisconnect @@ -768,6 +773,115 @@ def test_route_name(endpoint: typing.Callable, expected_name: str): assert Route(path="/", endpoint=endpoint).name == expected_name +class AddHeadersMiddleware: + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + scope["add_headers_middleware"] = True + + async def modified_send(msg: Message) -> None: + if msg["type"] == "http.response.start": + msg["headers"].append((b"X-Test", b"Set by middleware")) + await send(msg) + + await self.app(scope, receive, modified_send) + + +def assert_middleware_header_route(request: Request) -> Response: + assert request.scope["add_headers_middleware"] is True + return Response() + + +mounted_routes_with_middleware = Starlette( + routes=[ + Mount( + "/http", + routes=[ + Route( + "/", + endpoint=assert_middleware_header_route, + methods=["GET"], + name="route", + ), + ], + middleware=[Middleware(AddHeadersMiddleware)], + ), + Route("/home", homepage), + ] +) + + +mounted_app_with_middleware = Starlette( + routes=[ + Mount( + "/http", + app=Route( + "/", + endpoint=assert_middleware_header_route, + methods=["GET"], + name="route", + ), + middleware=[Middleware(AddHeadersMiddleware)], + ), + Route("/home", homepage), + ] +) + + +@pytest.mark.parametrize( + "app", + [ + mounted_routes_with_middleware, + mounted_app_with_middleware, + ], +) +def test_mount_middleware( + test_client_factory: typing.Callable[..., TestClient], + app: Starlette, +) -> None: + test_client = test_client_factory(app) + + response = test_client.get("/home") + assert response.status_code == 200 + assert "X-Test" not in response.headers + + response = test_client.get("/http") + assert response.status_code == 200 + assert response.headers["X-Test"] == "Set by middleware" + + +def test_mount_routes_with_middleware_url_path_for() -> None: + """Checks that url_path_for still works with mounted routes with Middleware""" + assert mounted_routes_with_middleware.url_path_for("route") == "/http/" + + +def test_mount_asgi_app_with_middleware_url_path_for() -> None: + """Mounted ASGI apps do not work with url path for, + middleware does not change this + """ + with pytest.raises(NoMatchFound): + mounted_app_with_middleware.url_path_for("route") + + +def test_add_route_to_app_after_mount( + test_client_factory: typing.Callable[..., TestClient], +) -> None: + """Checks that Mount will pick up routes + added to the underlying app after it is mounted + """ + inner_app = Router() + app = Mount("/http", app=inner_app) + inner_app.add_route( + "/inner", + endpoint=homepage, + methods=["GET"], + ) + client = test_client_factory(app) + response = client.get("/http/inner") + assert response.status_code == 200 + + def test_exception_on_mounted_apps(test_client_factory): def exc(request): raise Exception("Exc") @@ -779,3 +893,62 @@ def exc(request): with pytest.raises(Exception) as ctx: client.get("/sub/") assert str(ctx.value) == "Exc" + + +def test_mounted_middleware_does_not_catch_exception( + test_client_factory: typing.Callable[..., TestClient], +) -> None: + # https://github.com/encode/starlette/pull/1649#discussion_r960236107 + def exc(request: Request) -> Response: + raise HTTPException(status_code=403, detail="auth") + + class NamedMiddleware: + def __init__(self, app: ASGIApp, name: str) -> None: + self.app = app + self.name = name + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + async def modified_send(msg: Message) -> None: + if msg["type"] == "http.response.start": + msg["headers"].append((f"X-{self.name}".encode(), b"true")) + await send(msg) + + await self.app(scope, receive, modified_send) + + app = Starlette( + routes=[ + Mount( + "/mount", + routes=[ + Route("/err", exc), + Route("/home", homepage), + ], + middleware=[Middleware(NamedMiddleware, name="Mounted")], + ), + Route("/err", exc), + Route("/home", homepage), + ], + middleware=[Middleware(NamedMiddleware, name="Outer")], + ) + + client = test_client_factory(app) + + resp = client.get("/home") + assert resp.status_code == 200, resp.content + assert "X-Outer" in resp.headers + + resp = client.get("/err") + assert resp.status_code == 403, resp.content + assert "X-Outer" in resp.headers + + resp = client.get("/mount/home") + assert resp.status_code == 200, resp.content + assert "X-Mounted" in resp.headers + + # this is the "surprising" behavior bit + # the middleware on the mount never runs because there + # is nothing to catch the HTTPException + # since Mount middlweare is not wrapped by ExceptionMiddleware + resp = client.get("/mount/err") + assert resp.status_code == 403, resp.content + assert "X-Mounted" not in resp.headers