From 453cb8e95230887ffcf48bc6cbf36762676200f1 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 30 Mar 2023 14:34:37 +0200 Subject: [PATCH] Use a proxy to modify middleware chain in place Superseeds #2017 and fixes https://github.com/encode/starlette/pull/2017#issuecomment-1488537008 --- starlette/applications.py | 75 ++++++++++++++++++++++++++-------- starlette/middleware/errors.py | 2 +- tests/test_applications.py | 49 +++++++++++++++++++++- 3 files changed, 107 insertions(+), 19 deletions(-) diff --git a/starlette/applications.py b/starlette/applications.py index 013364be3..8063fd9e1 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -14,6 +14,18 @@ AppType = typing.TypeVar("AppType", bound="Starlette") +class _ASGIAppProxy: + """A proxy into an ASGI app that we can re-assign to point to a new + app without modifying any references to the _ASGIAppProxy itself. + """ + + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + await self.app(scope, receive, send) + + class Starlette: """ Creates an application instance. @@ -76,8 +88,18 @@ def __init__( self.exception_handlers = ( {} if exception_handlers is None else dict(exception_handlers) ) + self._pending_exception_handlers = False self.user_middleware = [] if middleware is None else list(middleware) - self.middleware_stack: typing.Optional[ASGIApp] = None + self._pending_user_middlewares = self.user_middleware.copy() + # wrap ExceptionMiddleware in a proxy so that + # we can re-build it when exception handlers get added + self._exception_middleware = _ASGIAppProxy( + ExceptionMiddleware(self.router, debug=self.debug) + ) + + self._user_middleware_outer: _ASGIAppProxy | None = None + self._user_middleware_inner = _ASGIAppProxy(self._exception_middleware.app) + self.middleware_stack: ASGIApp | None = None def build_middleware_stack(self) -> ASGIApp: debug = self.debug @@ -92,20 +114,34 @@ def build_middleware_stack(self) -> ASGIApp: else: exception_handlers[key] = value - middleware = ( - [Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)] - + self.user_middleware - + [ - Middleware( - ExceptionMiddleware, handlers=exception_handlers, debug=debug - ) - ] + self._exception_middleware.app = ExceptionMiddleware( + self.router, handlers=exception_handlers, debug=debug + ) + self._pending_exception_handlers = False + + user_middleware = self._pending_user_middlewares.copy() + self._pending_user_middlewares.clear() + + user_middleware_outer: ASGIApp + if user_middleware: + # build a new middleware chain that wraps self._exception_middleware + app: ASGIApp + app = new_user_middleware_inner = _ASGIAppProxy(self._exception_middleware) + for cls, options in reversed(user_middleware): + app = cls(app=app, **options) + # and set the .app for the previous innermost user middleware to point + # to this new chain + self._user_middleware_inner.app = app + # then replace our innermost user middleware + self._user_middleware_inner = new_user_middleware_inner + + user_middleware_outer = self._user_middleware_outer or app + else: + user_middleware_outer = self._exception_middleware + + return ServerErrorMiddleware( + app=user_middleware_outer, handler=error_handler, debug=debug ) - - app = self.router - for cls, options in reversed(middleware): - app = cls(app=app, **options) - return app @property def routes(self) -> typing.List[BaseRoute]: @@ -117,8 +153,13 @@ def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: scope["app"] = self - if self.middleware_stack is None: + if ( + self.middleware_stack is None + or self._pending_exception_handlers + or self._pending_user_middlewares + ): self.middleware_stack = self.build_middleware_stack() + assert self.middleware_stack is not None await self.middleware_stack(scope, receive, send) def on_event(self, event_type: str) -> typing.Callable: # pragma: nocover @@ -135,8 +176,7 @@ def host( self.router.host(host, app=app, name=name) def add_middleware(self, middleware_class: type, **options: typing.Any) -> None: - if self.middleware_stack is not None: # pragma: no cover - raise RuntimeError("Cannot add middleware after an application has started") + self._pending_user_middlewares.append(Middleware(middleware_class, **options)) self.user_middleware.insert(0, Middleware(middleware_class, **options)) def add_exception_handler( @@ -144,6 +184,7 @@ def add_exception_handler( exc_class_or_status_code: typing.Union[int, typing.Type[Exception]], handler: typing.Callable, ) -> None: # pragma: no cover + self._pending_exception_handlers = True self.exception_handlers[exc_class_or_status_code] = handler def add_event_handler( diff --git a/starlette/middleware/errors.py b/starlette/middleware/errors.py index b9d9c6910..19c39b967 100644 --- a/starlette/middleware/errors.py +++ b/starlette/middleware/errors.py @@ -137,7 +137,7 @@ class ServerErrorMiddleware: def __init__( self, app: ASGIApp, - handler: typing.Optional[typing.Callable] = None, + handler: typing.Optional[typing.Callable[..., typing.Any]] = None, debug: bool = False, ) -> None: self.app = app diff --git a/tests/test_applications.py b/tests/test_applications.py index e30ec9295..fc7cd0bf7 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -1,6 +1,6 @@ import os from contextlib import asynccontextmanager -from typing import Any, AsyncIterator, Callable +from typing import Any, AsyncIterator, Callable, Set import anyio import httpx @@ -548,3 +548,50 @@ async def lifespan(app: App) -> AsyncIterator[None]: # pragma: no cover yield App(lifespan=lifespan) + + +def test_lifespan_modifies_exc_handlers( + test_client_factory: Callable[[ASGIApp], httpx.Client] +): + class NoOpMiddleware: + def __init__(self, app: ASGIApp): + self.app = app + + async def __call__(self, *args: Any): + await self.app(*args) + + class SimpleInitializableMiddleware: + counter = 0 + + def __init__(self, app: ASGIApp): + self.app = app + instances.add(self) + SimpleInitializableMiddleware.counter += 1 + + async def __call__(self, *args: Any): + await self.app(*args) + + instances: Set[SimpleInitializableMiddleware] = set() + + @asynccontextmanager + async def lifespan(app: Starlette) -> AsyncIterator[None]: + app.add_middleware( + SimpleInitializableMiddleware + ) # should cause the stack to be re-built + yield + + def get_app() -> ASGIApp: + app = Starlette(lifespan=lifespan) + app.add_middleware(SimpleInitializableMiddleware) + app.add_middleware(NoOpMiddleware) + return app + + app = get_app() + + with test_client_factory(app) as client: + assert len(instances) == 1 + assert SimpleInitializableMiddleware.counter == 1 + # next request rebuilds + client.get("/does-not-matter") + assert len(instances) == 2 + assert SimpleInitializableMiddleware.counter == 2