Skip to content

Commit

Permalink
Lazily build middleware stack (#2017)
Browse files Browse the repository at this point in the history
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
  • Loading branch information
2 people authored and aminalaee committed Feb 13, 2023
1 parent d04e3f8 commit 3ae673b
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 16 deletions.
23 changes: 7 additions & 16 deletions starlette/applications.py
Expand Up @@ -65,7 +65,7 @@ def __init__(
on_startup is None and on_shutdown is None
), "Use either 'lifespan' or 'on_startup'/'on_shutdown', not both."

self._debug = debug
self.debug = debug
self.state = State()
self.router = Router(
routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan
Expand All @@ -74,7 +74,7 @@ def __init__(
{} if exception_handlers is None else dict(exception_handlers)
)
self.user_middleware = [] if middleware is None else list(middleware)
self.middleware_stack = self.build_middleware_stack()
self.middleware_stack: typing.Optional[ASGIApp] = None

def build_middleware_stack(self) -> ASGIApp:
debug = self.debug
Expand Down Expand Up @@ -108,20 +108,13 @@ def build_middleware_stack(self) -> ASGIApp:
def routes(self) -> typing.List[BaseRoute]:
return self.router.routes

@property
def debug(self) -> bool:
return self._debug

@debug.setter
def debug(self, value: bool) -> None:
self._debug = value
self.middleware_stack = self.build_middleware_stack()

def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
return self.router.url_path_for(name, **path_params)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
scope["app"] = self
if self.middleware_stack is None:
self.middleware_stack = self.build_middleware_stack()
await self.middleware_stack(scope, receive, send)

def on_event(self, event_type: str) -> typing.Callable: # pragma: nocover
Expand All @@ -137,19 +130,17 @@ def host(
) -> None: # pragma: no cover
self.router.host(host, app=app, name=name)

def add_middleware(
self, middleware_class: type, **options: typing.Any
) -> None: # pragma: no cover
def add_middleware(self, middleware_class: type, **options: typing.Any) -> None:
if self.middleware_stack is not None: # pragma: no cover
raise RuntimeError("Cannot add middleware after an application has started")
self.user_middleware.insert(0, Middleware(middleware_class, **options))
self.middleware_stack = self.build_middleware_stack()

def add_exception_handler(
self,
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
handler: typing.Callable,
) -> None: # pragma: no cover
self.exception_handlers[exc_class_or_status_code] = handler
self.middleware_stack = self.build_middleware_stack()

def add_event_handler(
self, event_type: str, func: typing.Callable
Expand Down
45 changes: 45 additions & 0 deletions tests/test_applications.py
@@ -1,7 +1,9 @@
import os
from contextlib import asynccontextmanager
from typing import Any, Callable

import anyio
import httpx
import pytest

from starlette import status
Expand All @@ -13,6 +15,7 @@
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles
from starlette.types import ASGIApp
from starlette.websockets import WebSocket


Expand Down Expand Up @@ -486,3 +489,45 @@ async def startup():

app.on_event("startup")(startup)
assert len(record) == 1


def test_middleware_stack_init(test_client_factory: Callable[[ASGIApp], httpx.Client]):
class NoOpMiddleware:
def __init__(self, app: ASGIApp):
self.app = app

async def __call__(self, *args: Any):
await self.app(*args)

class SimpleInitializableMiddleware:
counter = 0

def __init__(self, app: ASGIApp):
self.app = app
SimpleInitializableMiddleware.counter += 1

async def __call__(self, *args: Any):
await self.app(*args)

def get_app() -> ASGIApp:
app = Starlette()
app.add_middleware(SimpleInitializableMiddleware)
app.add_middleware(NoOpMiddleware)
return app

app = get_app()

with test_client_factory(app):
pass

assert SimpleInitializableMiddleware.counter == 1

test_client_factory(app).get("/foo")

assert SimpleInitializableMiddleware.counter == 1

app = get_app()

test_client_factory(app).get("/foo")

assert SimpleInitializableMiddleware.counter == 2

0 comments on commit 3ae673b

Please sign in to comment.