Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Update internal AsyncExitStack to fix context for dependencies with yield #4575

Merged
merged 12 commits into from Feb 17, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
42 changes: 34 additions & 8 deletions fastapi/applications.py
Expand Up @@ -2,7 +2,6 @@
from typing import Any, Callable, Coroutine, Dict, List, Optional, Sequence, Type, Union

from fastapi import routing
from fastapi.concurrency import AsyncExitStack
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.encoders import DictIntStrAny, SetIntStr
from fastapi.exception_handlers import (
Expand All @@ -11,6 +10,7 @@
)
from fastapi.exceptions import RequestValidationError
from fastapi.logger import logger
from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware
from fastapi.openapi.docs import (
get_redoc_html,
get_swagger_ui_html,
Expand All @@ -21,8 +21,9 @@
from fastapi.types import DecoratedCallable
from starlette.applications import Starlette
from starlette.datastructures import State
from starlette.exceptions import HTTPException
from starlette.exceptions import ExceptionMiddleware, HTTPException
from starlette.middleware import Middleware
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse, Response
from starlette.routing import BaseRoute
Expand Down Expand Up @@ -134,6 +135,36 @@ def __init__(
self.openapi_schema: Optional[Dict[str, Any]] = None
self.setup()

def build_middleware_stack(self) -> ASGIApp:
# Duplicate/override from Starlette to add AsyncExitStackMiddleware
# inside of ExceptionMiddleware, inside of custom user middlewares
debug = self.debug
error_handler = None
exception_handlers = {}

for key, value in self.exception_handlers.items():
if key in (500, Exception):
error_handler = value
else:
exception_handlers[key] = value

middleware = (
[Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
+ self.user_middleware
+ [
Middleware(
ExceptionMiddleware, handlers=exception_handlers, debug=debug
),
# Add FastAPI-specific AsyncExitStackMiddleware for dependencies
Middleware(AsyncExitStackMiddleware),
]
)

app = self.router
for cls, options in reversed(middleware):
tiangolo marked this conversation as resolved.
Show resolved Hide resolved
app = cls(app=app, **options)
return app

def openapi(self) -> Dict[str, Any]:
if not self.openapi_schema:
self.openapi_schema = get_openapi(
Expand Down Expand Up @@ -206,12 +237,7 @@ async def redoc_html(req: Request) -> HTMLResponse:
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.root_path:
scope["root_path"] = self.root_path
if AsyncExitStack:
async with AsyncExitStack() as stack:
scope["fastapi_astack"] = stack
await super().__call__(scope, receive, send)
else:
await super().__call__(scope, receive, send) # pragma: no cover
await super().__call__(scope, receive, send)

def add_api_route(
self,
Expand Down
16 changes: 16 additions & 0 deletions fastapi/middleware/asyncexitstack.py
@@ -0,0 +1,16 @@
from fastapi.concurrency import AsyncExitStack
from starlette.types import ASGIApp, Receive, Scope, Send


class AsyncExitStackMiddleware:
def __init__(self, app: ASGIApp, context_name: str = "fastapi_astack") -> None:
self.app = app
self.context_name = context_name

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if AsyncExitStack:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to me that AsyncExitStack is always true, it’s imported and concurrency provides it from either of two locations, but it’s always provided.

Copy link
Sponsor Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah... 🤔

async with AsyncExitStack() as stack:
scope[self.context_name] = stack
await self.app(scope, receive, send)
else:
await self.app(scope, receive, send) # pragma: no cover
51 changes: 51 additions & 0 deletions tests/test_dependency_contextvars.py
@@ -0,0 +1,51 @@
from contextvars import ContextVar
from typing import Any, Awaitable, Callable, Dict, Optional

from fastapi import Depends, FastAPI, Request, Response
from fastapi.testclient import TestClient

legacy_request_state_context_var: ContextVar[Optional[Dict[str, Any]]] = ContextVar(
"legacy_request_state_context_var", default=None
)

app = FastAPI()


async def set_up_request_state_dependency():
request_state = {"user": "deadpond"}
contextvar_token = legacy_request_state_context_var.set(request_state)
yield request_state
legacy_request_state_context_var.reset(contextvar_token)


@app.middleware("http")
async def custom_middleware(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
):
response = await call_next(request)
response.headers["custom"] = "foo"
return response


@app.get("/user", dependencies=[Depends(set_up_request_state_dependency)])
def get_user():
request_state = legacy_request_state_context_var.get()
assert request_state
return request_state["user"]


client = TestClient(app)


def test_dependency_contextvars():
"""
Check that custom middlewares don't affect the contextvar context for dependencies.

The code before yield and the code after yield should be run in the same contextvar
context, so that request_state_context_var.reset(contextvar_token).

If they are run in a different context, that raises an error.
"""
response = client.get("/user")
assert response.json() == "deadpond"
assert response.headers["custom"] == "foo"