Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a proxy object to modify middleware chain in place #2100

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
72 changes: 56 additions & 16 deletions starlette/applications.py
Expand Up @@ -14,6 +14,18 @@
AppType = typing.TypeVar("AppType", bound="Starlette")


class _ASGIAppProxy:
"""A proxy into an ASGI app that we can re-assign to point to a new
app without modifying any references to the _ASGIAppProxy itself.
"""

def __init__(self, app: ASGIApp) -> None:
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)


class Starlette:
"""
Creates an application instance.
Expand Down Expand Up @@ -78,6 +90,16 @@ def __init__(
)
self.user_middleware = [] if middleware is None else list(middleware)
self.middleware_stack: typing.Optional[ASGIApp] = None
self._pending_exception_handlers = False
self._pending_user_middlewares = self.user_middleware.copy()
# wrap ExceptionMiddleware in a proxy so that
# we can re-build it when exception handlers get added
self._exception_middleware = _ASGIAppProxy(
ExceptionMiddleware(self.router, debug=self.debug)
)

self._user_middleware_outer: typing.Optional[_ASGIAppProxy] = None
self._user_middleware_inner = _ASGIAppProxy(self._exception_middleware.app)

def build_middleware_stack(self) -> ASGIApp:
debug = self.debug
Expand All @@ -92,20 +114,34 @@ def build_middleware_stack(self) -> ASGIApp:
else:
exception_handlers[key] = value

middleware = (
[Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
+ self.user_middleware
+ [
Middleware(
ExceptionMiddleware, handlers=exception_handlers, debug=debug
)
]
self._exception_middleware.app = ExceptionMiddleware(
self.router, handlers=exception_handlers, debug=debug
)
self._pending_exception_handlers = False

user_middleware = self._pending_user_middlewares.copy()
self._pending_user_middlewares.clear()

user_middleware_outer: ASGIApp
if user_middleware:
# build a new middleware chain that wraps self._exception_middleware
app: ASGIApp
app = new_user_middleware_inner = _ASGIAppProxy(self._exception_middleware)
for cls, options in reversed(user_middleware):
app = cls(app=app, **options)
# and set the .app for the previous innermost user middleware to point
# to this new chain
self._user_middleware_inner.app = app
# then replace our innermost user middleware
self._user_middleware_inner = new_user_middleware_inner

user_middleware_outer = self._user_middleware_outer or app
else:
user_middleware_outer = self._exception_middleware

return ServerErrorMiddleware(
app=user_middleware_outer, handler=error_handler, debug=debug
)

app = self.router
for cls, options in reversed(middleware):
app = cls(app=app, **options)
return app

@property
def routes(self) -> typing.List[BaseRoute]:
Expand All @@ -117,7 +153,11 @@ def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
scope["app"] = self
if self.middleware_stack is None:
if (
self.middleware_stack is None
or self._pending_exception_handlers
or self._pending_user_middlewares
):
self.middleware_stack = self.build_middleware_stack()
await self.middleware_stack(scope, receive, send)

Expand All @@ -135,15 +175,15 @@ def host(
self.router.host(host, app=app, name=name)

def add_middleware(self, middleware_class: type, **options: typing.Any) -> None:
if self.middleware_stack is not None: # pragma: no cover
raise RuntimeError("Cannot add middleware after an application has started")
self._pending_user_middlewares.append(Middleware(middleware_class, **options))
self.user_middleware.insert(0, Middleware(middleware_class, **options))

def add_exception_handler(
self,
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
handler: typing.Callable,
) -> None: # pragma: no cover
self._pending_exception_handlers = True
self.exception_handlers[exc_class_or_status_code] = handler

def add_event_handler(
Expand Down
59 changes: 58 additions & 1 deletion tests/test_applications.py
@@ -1,6 +1,6 @@
import os
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Callable
from typing import Any, AsyncIterator, Callable, List, Set

import anyio
import httpx
Expand Down Expand Up @@ -548,3 +548,60 @@ async def lifespan(app: App) -> AsyncIterator[None]: # pragma: no cover
yield

App(lifespan=lifespan)


def test_lifespan_modifies_exc_handlers(
test_client_factory: Callable[[ASGIApp], httpx.Client]
):
calls: List[str] = []

class NoOpMiddleware:
def __init__(self, app: ASGIApp):
self.app = app

async def __call__(self, *args: Any):
calls.append("NoOpMiddleware")
await self.app(*args)

class SimpleInitializableMiddleware:
counter = 0

def __init__(self, app: ASGIApp, value: int):
self.app = app
self.value = value
instances.add(self)
SimpleInitializableMiddleware.counter += 1

async def __call__(self, *args: Any):
calls.append(f"SimpleInitializableMiddleware({self.value})")
await self.app(*args)

instances: Set[SimpleInitializableMiddleware] = set()

@asynccontextmanager
async def lifespan(app: Starlette) -> AsyncIterator[None]:
app.add_middleware(
SimpleInitializableMiddleware, value=2
) # should cause the stack to be re-built
yield

def get_app() -> ASGIApp:
app = Starlette(lifespan=lifespan)
app.add_middleware(SimpleInitializableMiddleware, value=1)
app.add_middleware(NoOpMiddleware)
return app

app = get_app()

with test_client_factory(app) as client:
assert len(instances) == 1
assert SimpleInitializableMiddleware.counter == 1
# next request rebuilds
client.get("/does-not-matter")
assert len(instances) == 2
assert SimpleInitializableMiddleware.counter == 2
assert calls == [
"SimpleInitializableMiddleware(1)",
"NoOpMiddleware",
"SimpleInitializableMiddleware(2)",
]