diff --git a/starlette/applications.py b/starlette/applications.py index a46cbaa0e..c68ad864a 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -65,7 +65,7 @@ def __init__( on_startup is None and on_shutdown is None ), "Use either 'lifespan' or 'on_startup'/'on_shutdown', not both." - self._debug = debug + self.debug = debug self.state = State() self.router = Router( routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan @@ -74,7 +74,7 @@ def __init__( {} if exception_handlers is None else dict(exception_handlers) ) self.user_middleware = [] if middleware is None else list(middleware) - self.middleware_stack = self.build_middleware_stack() + self.middleware_stack: typing.Optional[ASGIApp] = None def build_middleware_stack(self) -> ASGIApp: debug = self.debug @@ -108,20 +108,13 @@ def build_middleware_stack(self) -> ASGIApp: def routes(self) -> typing.List[BaseRoute]: return self.router.routes - @property - def debug(self) -> bool: - return self._debug - - @debug.setter - def debug(self, value: bool) -> None: - self._debug = value - self.middleware_stack = self.build_middleware_stack() - def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath: return self.router.url_path_for(name, **path_params) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: scope["app"] = self + if self.middleware_stack is None: + self.middleware_stack = self.build_middleware_stack() await self.middleware_stack(scope, receive, send) def on_event(self, event_type: str) -> typing.Callable: # pragma: nocover @@ -137,11 +130,10 @@ def host( ) -> None: # pragma: no cover self.router.host(host, app=app, name=name) - def add_middleware( - self, middleware_class: type, **options: typing.Any - ) -> None: # pragma: no cover + def add_middleware(self, middleware_class: type, **options: typing.Any) -> None: + if self.middleware_stack is not None: # pragma: no cover + raise RuntimeError("Cannot add middleware after an application has started") self.user_middleware.insert(0, Middleware(middleware_class, **options)) - self.middleware_stack = self.build_middleware_stack() def add_exception_handler( self, @@ -149,7 +141,6 @@ def add_exception_handler( handler: typing.Callable, ) -> None: # pragma: no cover self.exception_handlers[exc_class_or_status_code] = handler - self.middleware_stack = self.build_middleware_stack() def add_event_handler( self, event_type: str, func: typing.Callable diff --git a/tests/test_applications.py b/tests/test_applications.py index fcacbe633..ba10aff8e 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -1,7 +1,9 @@ import os from contextlib import asynccontextmanager +from typing import Any, Callable import anyio +import httpx import pytest from starlette import status @@ -13,6 +15,7 @@ from starlette.responses import JSONResponse, PlainTextResponse from starlette.routing import Host, Mount, Route, Router, WebSocketRoute from starlette.staticfiles import StaticFiles +from starlette.types import ASGIApp from starlette.websockets import WebSocket @@ -486,3 +489,45 @@ async def startup(): app.on_event("startup")(startup) assert len(record) == 1 + + +def test_middleware_stack_init(test_client_factory: Callable[[ASGIApp], httpx.Client]): + class NoOpMiddleware: + def __init__(self, app: ASGIApp): + self.app = app + + async def __call__(self, *args: Any): + await self.app(*args) + + class SimpleInitializableMiddleware: + counter = 0 + + def __init__(self, app: ASGIApp): + self.app = app + SimpleInitializableMiddleware.counter += 1 + + async def __call__(self, *args: Any): + await self.app(*args) + + def get_app() -> ASGIApp: + app = Starlette() + app.add_middleware(SimpleInitializableMiddleware) + app.add_middleware(NoOpMiddleware) + return app + + app = get_app() + + with test_client_factory(app): + pass + + assert SimpleInitializableMiddleware.counter == 1 + + test_client_factory(app).get("/foo") + + assert SimpleInitializableMiddleware.counter == 1 + + app = get_app() + + test_client_factory(app).get("/foo") + + assert SimpleInitializableMiddleware.counter == 2