Skip to content

Commit

Permalink
Use a proxy to modify middleware chain in place
Browse files Browse the repository at this point in the history
Superseeds encode#2017 and fixes encode#2017 (comment)
  • Loading branch information
adriangb committed Mar 30, 2023
1 parent bc90057 commit 453cb8e
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 19 deletions.
75 changes: 58 additions & 17 deletions starlette/applications.py
Expand Up @@ -14,6 +14,18 @@
AppType = typing.TypeVar("AppType", bound="Starlette")


class _ASGIAppProxy:
"""A proxy into an ASGI app that we can re-assign to point to a new
app without modifying any references to the _ASGIAppProxy itself.
"""

def __init__(self, app: ASGIApp) -> None:
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)


class Starlette:
"""
Creates an application instance.
Expand Down Expand Up @@ -76,8 +88,18 @@ def __init__(
self.exception_handlers = (
{} if exception_handlers is None else dict(exception_handlers)
)
self._pending_exception_handlers = False
self.user_middleware = [] if middleware is None else list(middleware)
self.middleware_stack: typing.Optional[ASGIApp] = None
self._pending_user_middlewares = self.user_middleware.copy()
# wrap ExceptionMiddleware in a proxy so that
# we can re-build it when exception handlers get added
self._exception_middleware = _ASGIAppProxy(
ExceptionMiddleware(self.router, debug=self.debug)
)

self._user_middleware_outer: _ASGIAppProxy | None = None
self._user_middleware_inner = _ASGIAppProxy(self._exception_middleware.app)
self.middleware_stack: ASGIApp | None = None

def build_middleware_stack(self) -> ASGIApp:
debug = self.debug
Expand All @@ -92,20 +114,34 @@ def build_middleware_stack(self) -> ASGIApp:
else:
exception_handlers[key] = value

middleware = (
[Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
+ self.user_middleware
+ [
Middleware(
ExceptionMiddleware, handlers=exception_handlers, debug=debug
)
]
self._exception_middleware.app = ExceptionMiddleware(
self.router, handlers=exception_handlers, debug=debug
)
self._pending_exception_handlers = False

user_middleware = self._pending_user_middlewares.copy()
self._pending_user_middlewares.clear()

user_middleware_outer: ASGIApp
if user_middleware:
# build a new middleware chain that wraps self._exception_middleware
app: ASGIApp
app = new_user_middleware_inner = _ASGIAppProxy(self._exception_middleware)
for cls, options in reversed(user_middleware):
app = cls(app=app, **options)
# and set the .app for the previous innermost user middleware to point
# to this new chain
self._user_middleware_inner.app = app
# then replace our innermost user middleware
self._user_middleware_inner = new_user_middleware_inner

user_middleware_outer = self._user_middleware_outer or app
else:
user_middleware_outer = self._exception_middleware

return ServerErrorMiddleware(
app=user_middleware_outer, handler=error_handler, debug=debug
)

app = self.router
for cls, options in reversed(middleware):
app = cls(app=app, **options)
return app

@property
def routes(self) -> typing.List[BaseRoute]:
Expand All @@ -117,8 +153,13 @@ def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
scope["app"] = self
if self.middleware_stack is None:
if (
self.middleware_stack is None
or self._pending_exception_handlers
or self._pending_user_middlewares
):
self.middleware_stack = self.build_middleware_stack()
assert self.middleware_stack is not None
await self.middleware_stack(scope, receive, send)

def on_event(self, event_type: str) -> typing.Callable: # pragma: nocover
Expand All @@ -135,15 +176,15 @@ def host(
self.router.host(host, app=app, name=name)

def add_middleware(self, middleware_class: type, **options: typing.Any) -> None:
if self.middleware_stack is not None: # pragma: no cover
raise RuntimeError("Cannot add middleware after an application has started")
self._pending_user_middlewares.append(Middleware(middleware_class, **options))
self.user_middleware.insert(0, Middleware(middleware_class, **options))

def add_exception_handler(
self,
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
handler: typing.Callable,
) -> None: # pragma: no cover
self._pending_exception_handlers = True
self.exception_handlers[exc_class_or_status_code] = handler

def add_event_handler(
Expand Down
2 changes: 1 addition & 1 deletion starlette/middleware/errors.py
Expand Up @@ -137,7 +137,7 @@ class ServerErrorMiddleware:
def __init__(
self,
app: ASGIApp,
handler: typing.Optional[typing.Callable] = None,
handler: typing.Optional[typing.Callable[..., typing.Any]] = None,
debug: bool = False,
) -> None:
self.app = app
Expand Down
49 changes: 48 additions & 1 deletion tests/test_applications.py
@@ -1,6 +1,6 @@
import os
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Callable
from typing import Any, AsyncIterator, Callable, Set

import anyio
import httpx
Expand Down Expand Up @@ -548,3 +548,50 @@ async def lifespan(app: App) -> AsyncIterator[None]: # pragma: no cover
yield

App(lifespan=lifespan)


def test_lifespan_modifies_exc_handlers(
test_client_factory: Callable[[ASGIApp], httpx.Client]
):
class NoOpMiddleware:
def __init__(self, app: ASGIApp):
self.app = app

async def __call__(self, *args: Any):
await self.app(*args)

class SimpleInitializableMiddleware:
counter = 0

def __init__(self, app: ASGIApp):
self.app = app
instances.add(self)
SimpleInitializableMiddleware.counter += 1

async def __call__(self, *args: Any):
await self.app(*args)

instances: Set[SimpleInitializableMiddleware] = set()

@asynccontextmanager
async def lifespan(app: Starlette) -> AsyncIterator[None]:
app.add_middleware(
SimpleInitializableMiddleware
) # should cause the stack to be re-built
yield

def get_app() -> ASGIApp:
app = Starlette(lifespan=lifespan)
app.add_middleware(SimpleInitializableMiddleware)
app.add_middleware(NoOpMiddleware)
return app

app = get_app()

with test_client_factory(app) as client:
assert len(instances) == 1
assert SimpleInitializableMiddleware.counter == 1
# next request rebuilds
client.get("/does-not-matter")
assert len(instances) == 2
assert SimpleInitializableMiddleware.counter == 2

0 comments on commit 453cb8e

Please sign in to comment.