Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Mount(..., middleware=[...]) #1649

Merged
merged 34 commits into from Sep 21, 2022
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
42548cf
Add Mount(..., middleware=[...])
adriangb May 24, 2022
01bbcdc
fmt
adriangb May 24, 2022
d8b626d
add missing test
adriangb May 24, 2022
73dc39e
fmt
adriangb May 24, 2022
14d3005
move docs
adriangb May 24, 2022
8eb2699
Merge branch 'master' into mount-middleware
adriangb May 24, 2022
baab334
Merge branch 'master' into mount-middleware
adriangb May 27, 2022
bbca389
Merge remote-tracking branch 'upstream/master' into mount-middleware
adriangb May 27, 2022
a75a523
replace basehttpmiddleware
adriangb May 27, 2022
578f618
Merge branch 'mount-middleware' of https://github.com/adriangb/starle…
adriangb May 27, 2022
ba59a35
Merge branch 'master' into mount-middleware
adriangb May 30, 2022
3cadfb2
Merge branch 'master' into mount-middleware
adriangb Jun 2, 2022
1eace2b
Merge branch 'master' into mount-middleware
adriangb Jun 8, 2022
ca50340
Merge branch 'master' into mount-middleware
adriangb Jun 15, 2022
f65dfc8
lint
adriangb Jun 15, 2022
5edb100
Merge branch 'master' into mount-middleware
adriangb Jul 2, 2022
1ef66e6
Merge branch 'master' into mount-middleware
adriangb Jul 12, 2022
fb93ef5
Merge branch 'master' into mount-middleware
adriangb Jul 20, 2022
273cc73
Merge remote-tracking branch 'upstream/master' into mount-middleware
adriangb Aug 13, 2022
4369ee7
add comment to docs
adriangb Aug 13, 2022
10f47ae
save
adriangb Aug 13, 2022
feeba5e
Merge branch 'master' into mount-middleware
Kludex Aug 18, 2022
0bb54e4
Merge branch 'master' into mount-middleware
adriangb Aug 31, 2022
d7c3f2a
Update docs/middleware.md
adriangb Sep 1, 2022
e6fad81
Update docs/middleware.md
adriangb Sep 1, 2022
adb52ab
Merge branch 'master' into mount-middleware
adriangb Sep 1, 2022
0f7a2c4
Merge branch 'master' into mount-middleware
Kludex Sep 3, 2022
f6de20f
Apply suggestions from code review
Kludex Sep 3, 2022
72fbaa6
fix duplicate parametrization
adriangb Sep 3, 2022
eee6a6f
fix bad test for url_path_for
adriangb Sep 3, 2022
aec580f
Add test explaining behavior
adriangb Sep 3, 2022
93ec37f
fix linting
adriangb Sep 6, 2022
5f936e9
Merge branch 'master' into mount-middleware
adriangb Sep 6, 2022
0079756
Merge branch 'master' into mount-middleware
abersheeran Sep 17, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 35 additions & 0 deletions docs/middleware.md
Expand Up @@ -683,6 +683,41 @@ to use the `middleware=<List of Middleware instances>` style, as it will:
* Ensure that everything remains wrapped in a single outermost `ServerErrorMiddleware`.
* Preserves the top-level `app` instance.

## Applying middleware to `Mount`s

Middleware can also be added to `Mount`, which allows you to apply middleware to a single route, a group of routes or any mounted ASGI application:

```python
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.routing import Mount, Route


routes = [
Mount(
"/",
routes=[
Route(
"/example",
endpoint=...,
)
],
middleware=[Middleware(GZipMiddleware)]
)
]

app = Starlette(routes=routes)
```

Note that middleware used in this way is *not* wrapped in exception handling middleware like the middleware applied to the `Starlette` application is.
This is often not a problem because it only applies to middleware that inspect or modify the `Response`, and even then you probably don't want to apply this logic to error responses.
If you do want to apply the middelware logic to error responses only on some routes you have a couple of options:
Kludex marked this conversation as resolved.
Show resolved Hide resolved

* Add an `ExceptionMiddleware` onto the `Mount`
* Add a `try/except` block to your middleware and return an error response from there
* Split up marking and processing into two middlewares, one that gets put on `Mount` which simply marks the response as needing processing (for example by setting `scope["log-response"] = True`) and another applied to the `Starlette` application that does the heavy lifting.
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm... This works... 🤔

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.responses import Response
from starlette.routing import Mount, Route


async def home(request):
    return Response("Hi there!")


async def exception_handler(request, exc):
    return Response("I have a 400 for you!", status_code=400)


class PotatoException(Exception):
    ...


class CustomMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        raise PotatoException()


routes = [
    Mount(
        "/",
        routes=[
            Route(
                "/example",
                endpoint=home,
            )
        ],
        middleware=[Middleware(CustomMiddleware)],
    ),
]
app = Starlette(routes=routes, exception_handlers={PotatoException: exception_handler})

Copy link
Member Author

@adriangb adriangb Sep 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Try this example:

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.responses import Response
from starlette.routing import Mount, Route

class PotatoException(Exception):
    ...

async def bad_endpoint(request):
    raise PotatoException

async def exception_handler(request, exc):
    return Response("I have a 400 for you!", status_code=400)

class CustomMiddleware:
    def __init__(self, app, name: str):
        self.app = app
        self.name = name

    async def __call__(self, scope, receive, send):
        async def wrapped_send(msg):
            if msg["type"] == "http.response":
                print(f"{self.name} called for {scope['raw_path']}")
            await send(msg)
        await self.app(scope, receive, wrapped_send)

routes = [
    Mount(
        "/mount",
        routes=[
            Route(
                "/bad",
                endpoint=bad_endpoint,
            )
        ],
        middleware=[Middleware(CustomMiddleware, name="on_mount")],
    ),
    Route(
        "/good",
        endpoint=bad_endpoint,
    )
]

app = Starlette(
    routes=routes,
    exception_handlers={PotatoException: exception_handler},
    middleware=[Middleware(CustomMiddleware, name="on_app")],
)

The "on_mount" middleware will never be called because PotatoException tears through Mount's middleware stack. On the other hand "on_app" always gets called because it's "protected" by ErrorMiddleware.

Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, got it. But that's not what I got from what's written here...

Do you think an image with middleware/app/mount as "blocks" would be helpful understanding this? 🤔

(I can help if you think it makes sense, but I lack design skills to do it 😎 👍)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds like a good idea, please give it a shot! I won't have time until next week (and also don't have the design skills 😆). Maybe do it with ascii text instead of an image since it will be easier to embed?

Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a blocker. 👍
I'll not have time to do this, if someone eventually wants, feel free to open a PR with it. 👍

(Please do not resolve this conversation so ppl see it)

Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's right.

Adrian's example was missing something (I had to add it myself before, and I forgot)...

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.responses import Response
from starlette.routing import Mount, Route

class PotatoException(Exception):
    ...

async def bad_endpoint(request):
    raise PotatoException

async def exception_handler(request, exc):
    return Response("I have a 400 for you!", status_code=400)

class CustomMiddleware:
    def __init__(self, app, name: str):
        self.app = app
        self.name = name

    async def __call__(self, scope, receive, send):
        async def wrapped_send(msg):
            if msg["type"] == "http.response.start":  # IT WAS MISSING `.start`
                print(f"{self.name} called for {scope['raw_path']}")
            await send(msg)
        await self.app(scope, receive, wrapped_send)

routes = [
    Mount(
        "/mount",
        routes=[
            Route(
                "/bad",
                endpoint=bad_endpoint,
            )
        ],
        middleware=[Middleware(CustomMiddleware, name="on_mount")],
    ),
    Route(
        "/good",
        endpoint=bad_endpoint,
    )
]

app = Starlette(
    routes=routes,
    exception_handlers={PotatoException: exception_handler},
    middleware=[Middleware(CustomMiddleware, name="on_app")],
)

Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can use @florimondmanca 's drawing here for the explanation 👍

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the broken example 😫, but yes you got it right Florimond!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a test in aec580f

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@florimondmanca I'm realizing now that it was not clear if this was completely resolved, sorry if I missed it. Do you think we should tweak the docs more, maybe adding something along the lines of your explanation in #1649 (comment)?

Kludex marked this conversation as resolved.
Show resolved Hide resolved

## Third party middleware

#### [asgi-auth-github](https://github.com/simonw/asgi-auth-github)
Expand Down
13 changes: 10 additions & 3 deletions starlette/routing.py
Expand Up @@ -14,6 +14,7 @@
from starlette.convertors import CONVERTOR_TYPES, Convertor
from starlette.datastructures import URL, Headers, URLPath
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse, RedirectResponse
from starlette.types import ASGIApp, Receive, Scope, Send
Expand Down Expand Up @@ -348,24 +349,30 @@ def __init__(
app: typing.Optional[ASGIApp] = None,
routes: typing.Optional[typing.Sequence[BaseRoute]] = None,
name: typing.Optional[str] = None,
*,
middleware: typing.Optional[typing.Sequence[Middleware]] = None,
) -> None:
assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
assert (
app is not None or routes is not None
), "Either 'app=...', or 'routes=' must be specified"
self.path = path.rstrip("/")
if app is not None:
self.app: ASGIApp = app
self._base_app: ASGIApp = app
else:
self.app = Router(routes=routes)
self._base_app = Router(routes=routes)
self.app = self._base_app
if middleware is not None:
for cls, options in reversed(middleware):
self.app = cls(app=self.app, **options)
self.name = name
self.path_regex, self.path_format, self.param_convertors = compile_path(
self.path + "/{path:path}"
)

@property
def routes(self) -> typing.List[BaseRoute]:
return getattr(self.app, "routes", [])
return getattr(self._base_app, "routes", [])

def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
if scope["type"] in ("http", "websocket"):
Expand Down
109 changes: 108 additions & 1 deletion tests/test_routing.py
Expand Up @@ -5,8 +5,20 @@
import pytest

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse, Response
from starlette.routing import Host, Mount, NoMatchFound, Route, Router, WebSocketRoute
from starlette.routing import (
BaseRoute,
Host,
Mount,
NoMatchFound,
Route,
Router,
WebSocketRoute,
)
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketDisconnect


Expand Down Expand Up @@ -768,6 +780,101 @@ def test_route_name(endpoint: typing.Callable, expected_name: str):
assert Route(path="/", endpoint=endpoint).name == expected_name


class AddHeadersMiddleware:
def __init__(self, app: ASGIApp) -> None:
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
scope["add_headers_middleware"] = True

async def modified_send(msg: Message) -> None:
if msg["type"] == "http.response.start":
msg["headers"].append((b"X-Test", b"Set by middleware"))
await send(msg)

await self.app(scope, receive, modified_send)


def assert_middleware_header_route(request: Request) -> Response:
assert request.scope["add_headers_middleware"] is True
return Response()


mounted_routes_with_middleware = Mount(
"/http",
routes=[
Route(
"/",
endpoint=assert_middleware_header_route,
methods=["GET"],
name="route",
),
],
middleware=[Middleware(AddHeadersMiddleware)],
)


mounted_app_with_middleware = Mount(
"/http",
app=Route(
"/",
endpoint=assert_middleware_header_route,
methods=["GET"],
name="route",
),
middleware=[Middleware(AddHeadersMiddleware)],
)


@pytest.mark.parametrize(
"route",
[
mounted_routes_with_middleware,
mounted_routes_with_middleware,
Kludex marked this conversation as resolved.
Show resolved Hide resolved
mounted_app_with_middleware,
],
)
def test_mount_middleware(
test_client_factory: typing.Callable[..., TestClient],
route: BaseRoute,
) -> None:
test_client = test_client_factory(Router([route]))
response = test_client.get("/http")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also assert that calling a route not under the Mount does not apply the middleware? E.g. testing a test_client.get("/").

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in eee6a6f

assert response.status_code == 200
assert response.headers["X-Test"] == "Set by middleware"


@pytest.mark.parametrize(
"route",
[
mounted_routes_with_middleware,
mounted_routes_with_middleware,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicated as well, so I think we can get rid of the @parametrize.

Copy link
Member Author

@adriangb adriangb Sep 3, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in eee6a6f

],
)
def test_mount_middleware_url_path_for(route: BaseRoute) -> None:
"""Checks that url_path_for still works with middleware on Mounts"""
router = Router([route])
assert router.url_path_for("route") == "/http/"


def test_add_route_to_app_after_mount(
test_client_factory: typing.Callable[..., TestClient],
) -> None:
"""Checks that Mount will pick up routes
added to the underlying app after it is mounted
"""
inner_app = Router()
app = Mount("/http", app=inner_app)
inner_app.add_route(
"/inner",
endpoint=lambda request: Response(),
methods=["GET"],
)
client = test_client_factory(app)
response = client.get("/http/inner")
assert response.status_code == 200


def test_exception_on_mounted_apps(test_client_factory):
def exc(request):
raise Exception("Exc")
Expand Down