Lazily build middleware stack #2017

Merged
merged 11 commits into from Feb 6, 2023
11 changes: 7 additions & 4 deletions starlette/applications.py
@@ -74,7 +74,7 @@ def __init__(
{} if exception_handlers is None else dict(exception_handlers)
)
self.user_middleware = [] if middleware is None else list(middleware)
- self.middleware_stack = self.build_middleware_stack()
+ self.middleware_stack: typing.Optional[ASGIApp] = None
Member

IMO, build_middleware_stack is a very cheap function to call. It makes sense to call it here and call it again whenever you need it, avoiding nullable attributes (and null checks).

Member

Moreover, it is not expected to be called often at runtime and affects only startup (configuration) time.
So calling it multiple times seems ok to me.

Member Author

@adriangb adriangb Jan 30, 2023

I'm not sure whether you are advocating for the existing approach or the one proposed here. Performance is not the goal; the goal is to avoid instantiating the user's middleware repeatedly, as reported in #2002.
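
To make the concern from #2002 concrete, here is a minimal sketch that does not use Starlette itself. `EagerApp`, `LazyApp`, `CountingMiddleware`, `router`, and `count_instantiations` are hypothetical stand-ins introduced only for illustration. It assumes the eager variant rebuilds the stack inside `add_middleware` (mirroring the line this PR removes) and the lazy variant builds once on start.

```python
from typing import Any, Callable, List, Optional

Handler = Callable[..., Any]


def router(*args: Any) -> None:
    """Stand-in for the innermost app (the router)."""


class NoOpMiddleware:
    def __init__(self, app: Handler) -> None:
        self.app = app


class CountingMiddleware:
    """Hypothetical user middleware that counts its own instantiations."""

    instances = 0

    def __init__(self, app: Handler) -> None:
        self.app = app
        CountingMiddleware.instances += 1


class EagerApp:
    """Rebuilds the whole stack inside add_middleware."""

    def __init__(self) -> None:
        self.user_middleware: List[type] = []
        self.middleware_stack: Optional[Handler] = self.build_middleware_stack()

    def build_middleware_stack(self) -> Handler:
        app: Handler = router
        for cls in reversed(self.user_middleware):
            app = cls(app)  # every build instantiates every middleware again
        return app

    def add_middleware(self, cls: type) -> None:
        self.user_middleware.insert(0, cls)
        self.middleware_stack = self.build_middleware_stack()

    def start(self) -> None:
        """Nothing to do: the stack already exists."""


class LazyApp(EagerApp):
    """Defers the build until the app is first started."""

    def __init__(self) -> None:
        self.user_middleware = []
        self.middleware_stack = None

    def add_middleware(self, cls: type) -> None:
        self.user_middleware.insert(0, cls)

    def start(self) -> None:
        if self.middleware_stack is None:
            self.middleware_stack = self.build_middleware_stack()


def count_instantiations(app_cls: type) -> int:
    """Register two middleware, start the app, and report the counter."""
    CountingMiddleware.instances = 0
    app = app_cls()
    app.add_middleware(CountingMiddleware)
    app.add_middleware(NoOpMiddleware)
    app.start()
    return CountingMiddleware.instances
```

Under these assumptions, `count_instantiations(EagerApp)` returns 2 because the second `add_middleware` call rebuilds the stack, while `count_instantiations(LazyApp)` returns 1.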


def build_middleware_stack(self) -> ASGIApp:
debug = self.debug
@@ -115,13 +115,14 @@ def debug(self) -> bool:
@debug.setter
def debug(self, value: bool) -> None:
self._debug = value
adriangb marked this conversation as resolved.
- self.middleware_stack = self.build_middleware_stack()

def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
return self.router.url_path_for(name, **path_params)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
scope["app"] = self
+ if self.middleware_stack is None:
+     self.middleware_stack = self.build_middleware_stack()
Comment on lines +116 to +117
Member Author

@adriangb adriangb Jan 23, 2023

Note that most of the time this will run when the lifespan is triggered, but if the lifespan is disabled it will run on the first request. So not checking scope["type"] is by design.
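
As a rough illustration of that design choice, the sketch below records which scope type triggers the one-time build. `LazyApp`, `built_on`, and `drive` are hypothetical names, and the stack is reduced to a no-op coroutine; it is not Starlette's implementation.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List, Optional

Scope = Dict[str, Any]
ASGIApp = Callable[[Scope, Any, Any], Awaitable[None]]


class LazyApp:
    """Hypothetical app that builds its stack on the first call of any type."""

    def __init__(self) -> None:
        self.middleware_stack: Optional[ASGIApp] = None
        self.built_on: List[str] = []  # scope types that triggered a build

    def build_middleware_stack(self) -> ASGIApp:
        async def stack(scope: Scope, receive: Any, send: Any) -> None:
            return None  # stand-in for the real middleware chain

        return stack

    async def __call__(self, scope: Scope, receive: Any, send: Any) -> None:
        # No check on scope["type"]: whichever call arrives first builds.
        if self.middleware_stack is None:
            self.built_on.append(scope["type"])
            self.middleware_stack = self.build_middleware_stack()
        await self.middleware_stack(scope, receive, send)


async def drive(scope_types: List[str]) -> List[str]:
    """Send one call per scope type and report what triggered the build."""
    app = LazyApp()
    for scope_type in scope_types:
        await app({"type": scope_type}, None, None)
    return app.built_on


with_lifespan = asyncio.run(drive(["lifespan", "http", "http"]))
without_lifespan = asyncio.run(drive(["http", "http"]))
```

With a lifespan call first, the build is triggered by the `"lifespan"` scope; without it, the first `"http"` call triggers it. In both cases the stack is built exactly once.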

Sponsor Member

I've mentioned this in the description as well. 👍

Comment on lines +116 to +117
Sponsor Member

Suggested change
- if self.middleware_stack is None:
-     self.middleware_stack = self.build_middleware_stack()
+ if scope["type"] != "lifespan" and self.middleware_stack is None:
+     self.middleware_stack = self.build_middleware_stack()

Member Author

What if the middleware runs something on lifespan? Maybe we need to build once before running the lifespan and then, if it's changed, we re-build it after running the lifespan?

Sponsor Member

Actually we want both scenarios...

Member Author

What do you mean?

Sponsor Member

Yeah :P

Sponsor Member

What do you mean?

The same thing as your first message.

But actually... We don't want to rebuild, otherwise we introduce the issue that motivated this PR again.

Member Author

I think I have a solution that allows dynamically adding middleware at any point in time and never re-building:

class ASGIAppProxy:
    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        await self.app(scope, receive, send)

The reason we need to re-build in the first place is that the last user middleware needs to point to self.router/ExceptionMiddleware. Say you have a single user middleware A: the "stack" looks like ServerErrorMiddleware -> UserMiddlewareA -> ExceptionMiddleware -> Router. If you then add UserMiddlewareB, it should become ServerErrorMiddleware -> UserMiddlewareA -> UserMiddlewareB -> ExceptionMiddleware -> Router. So you need to change the ASGIApp that UserMiddlewareA points to from ExceptionMiddleware to UserMiddlewareB. We can't just assign an .app attribute on the user's middleware, because that middleware could be anything. By introducing the above wrapper we can do something like:

exc_middleware = ExceptionMiddleware(...)
tail = ASGIAppProxy(exc_middleware)
user_middleware_tail = UserMiddlewareA(tail)
# add UserMiddlewareB
new_tail = ASGIAppProxy(exc_middleware)
new_user_middleware_tail = UserMiddlewareB(new_tail)
tail.app = new_user_middleware_tail
tail = new_tail
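
A runnable version of this idea might look like the sketch below. `ASGIAppProxy` follows the class proposed above; `TaggingMiddleware`, `exc_middleware`, and the `"trace"` key in the scope are hypothetical helpers added only to make the call order observable. It is a sketch of the proposal, not code from the merged PR.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List, Tuple

Scope = Dict[str, Any]
ASGIApp = Callable[[Scope, Any, Any], Awaitable[None]]


class ASGIAppProxy:
    """Indirection whose target can be swapped after construction."""

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Any, send: Any) -> None:
        await self.app(scope, receive, send)


class TaggingMiddleware:
    """Hypothetical user middleware that records its name, then delegates."""

    def __init__(self, app: ASGIApp, name: str) -> None:
        self.app = app
        self.name = name

    async def __call__(self, scope: Scope, receive: Any, send: Any) -> None:
        scope["trace"].append(self.name)
        await self.app(scope, receive, send)


async def exc_middleware(scope: Scope, receive: Any, send: Any) -> None:
    """Stand-in for ExceptionMiddleware -> Router."""
    scope["trace"].append("ExceptionMiddleware")


async def main() -> Tuple[List[str], List[str]]:
    tail = ASGIAppProxy(exc_middleware)
    stack = TaggingMiddleware(tail, "A")  # A -> proxy -> ExceptionMiddleware

    before: Scope = {"trace": []}
    await stack(before, None, None)

    # Splice B in behind A without instantiating A again.
    new_tail = ASGIAppProxy(exc_middleware)
    tail.app = TaggingMiddleware(new_tail, "B")
    tail = new_tail  # the next middleware would be spliced in here

    after: Scope = {"trace": []}
    await stack(after, None, None)
    return before["trace"], after["trace"]


before_trace, after_trace = asyncio.run(main())
```

The first call passes through A and then the exception middleware; after the splice, the same `stack` object passes through A, B, and then the exception middleware, and A was constructed only once.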

await self.middleware_stack(scope, receive, send)

def on_event(self, event_type: str) -> typing.Callable: # pragma: nocover
@@ -140,16 +141,18 @@ def host(
def add_middleware(
self, middleware_class: type, **options: typing.Any
) -> None: # pragma: no cover
adriangb marked this conversation as resolved.
+ if self.middleware_stack is not None:
Member

This method is deprecated, isn't it @Kludex ?

Member Author

Yes it is. #2002 still came up last meeting as a problem.

Sponsor Member

It's not...

We didn't deprecate it. We only deprecated the decorators. 🤔

And "what came up in the meeting" was that I mentioned add_middleware as an issue, because of the linked GitHub issue above. The idea was to deprecate it and later remove it, but Adrian came up with this alternative solution.

Member Author

Apologies, thank you for clarifying 😄

+     raise RuntimeError(
+         "Cannot add middlewares after an application has started"
+     )
self.user_middleware.insert(0, Middleware(middleware_class, **options))
- self.middleware_stack = self.build_middleware_stack()
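
The effect of this guard can be sketched in isolation. `GuardedApp`, its `start` method, and `NoOpMiddleware` are hypothetical stand-ins; only the error message and the `middleware_stack is not None` check are taken from the diff above.

```python
from typing import Any, Callable, List, Optional

Handler = Callable[..., Any]


class NoOpMiddleware:
    def __init__(self, app: Handler) -> None:
        self.app = app


class GuardedApp:
    """Hypothetical app that refuses new middleware once the stack exists."""

    def __init__(self) -> None:
        self.user_middleware: List[type] = []
        self.middleware_stack: Optional[Handler] = None

    def build_middleware_stack(self) -> Handler:
        app: Handler = lambda *args: None  # stand-in for the router
        for cls in reversed(self.user_middleware):
            app = cls(app)
        return app

    def add_middleware(self, cls: type) -> None:
        if self.middleware_stack is not None:
            raise RuntimeError(
                "Cannot add middlewares after an application has started"
            )
        self.user_middleware.insert(0, cls)

    def start(self) -> None:
        """Stand-in for the first ASGI call, which builds the stack."""
        if self.middleware_stack is None:
            self.middleware_stack = self.build_middleware_stack()


app = GuardedApp()
app.add_middleware(NoOpMiddleware)  # allowed: nothing has been built yet
app.start()

try:
    app.add_middleware(NoOpMiddleware)  # rejected: the stack already exists
except RuntimeError as exc:
    error_message: Optional[str] = str(exc)
else:
    error_message = None
```

Because the stack is never rebuilt after start, rejecting late registrations keeps `user_middleware` and the built stack from drifting apart.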

def add_exception_handler(
self,
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
handler: typing.Callable,
) -> None: # pragma: no cover
self.exception_handlers[exc_class_or_status_code] = handler
- self.middleware_stack = self.build_middleware_stack()

def add_event_handler(
self, event_type: str, func: typing.Callable
45 changes: 45 additions & 0 deletions tests/test_applications.py
@@ -1,7 +1,9 @@
import os
from contextlib import asynccontextmanager
from typing import Any, Callable

import anyio
import httpx
import pytest

from starlette import status
@@ -13,6 +15,7 @@
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles
from starlette.types import ASGIApp
from starlette.websockets import WebSocket


@@ -486,3 +489,45 @@ async def startup():

app.on_event("startup")(startup)
assert len(record) == 1


def test_middleware_stack_init(test_client_factory: Callable[[ASGIApp], httpx.Client]):
Sponsor Member

This would be the first time httpx.Client is used in this annotation in the test suite 👀

Suggested change
- def test_middleware_stack_init(test_client_factory: Callable[[ASGIApp], httpx.Client]):
+ def test_middleware_stack_init(test_client_factory: Callable[[ASGIApp], TestClient]):

Member Author

Hmm yeah I should have used TestClient. Luckily it's just a type annotation.

class NoOpMiddleware:
def __init__(self, app: ASGIApp):
self.app = app

async def __call__(self, *args: Any):
await self.app(*args)

class SimpleInitializableMiddleware:
counter = 0

def __init__(self, app: ASGIApp):
self.app = app
SimpleInitializableMiddleware.counter += 1

async def __call__(self, *args: Any):
await self.app(*args)

def get_app() -> ASGIApp:
app = Starlette()
app.add_middleware(SimpleInitializableMiddleware)
app.add_middleware(NoOpMiddleware)
return app

app = get_app()

with test_client_factory(app):
pass

assert SimpleInitializableMiddleware.counter == 1

test_client_factory(app).get("/foo")

assert SimpleInitializableMiddleware.counter == 1

app = get_app()

test_client_factory(app).get("/foo")

assert SimpleInitializableMiddleware.counter == 2
Comment on lines +529 to +533
Sponsor Member

How relevant is this part for the test? 🤔