Skip to content

Commit

Permalink
✨ Update internal AsyncExitStack to fix context for dependencies wi…
Browse files Browse the repository at this point in the history
…th `yield` (tiangolo#4575)
  • Loading branch information
tiangolo authored and JeanArhancet committed Aug 20, 2022
1 parent b176778 commit c0a5e28
Show file tree
Hide file tree
Showing 7 changed files with 272 additions and 16 deletions.
10 changes: 6 additions & 4 deletions docs/en/docs/tutorial/dependencies/dependencies-with-yield.md
Expand Up @@ -99,7 +99,7 @@ You saw that you can use dependencies with `yield` and have `try` blocks that ca

It might be tempting to raise an `HTTPException` or similar in the exit code, after the `yield`. But **it won't work**.

The exit code in dependencies with `yield` is executed *after* [Exception Handlers](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank}. There's nothing catching exceptions thrown by your dependencies in the exit code (after the `yield`).
The exit code in dependencies with `yield` is executed *after* the response is sent, so [Exception Handlers](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank} will have already run. There's nothing catching exceptions thrown by your dependencies in the exit code (after the `yield`).

So, if you raise an `HTTPException` after the `yield`, the default (or any custom) exception handler that catches `HTTPException`s and returns an HTTP 400 response won't be there to catch that exception anymore.

Expand Down Expand Up @@ -138,9 +138,11 @@ participant tasks as Background tasks
end
dep ->> operation: Run dependency, e.g. DB session
opt raise
operation -->> handler: Raise HTTPException
operation -->> dep: Raise HTTPException
dep -->> handler: Auto forward exception
handler -->> client: HTTP error response
operation -->> dep: Raise other exception
dep -->> handler: Auto forward exception
end
operation ->> client: Return response to client
Note over client,operation: Response is already sent, can't change it anymore
Expand All @@ -162,9 +164,9 @@ participant tasks as Background tasks
After one of those responses is sent, no other response can be sent.

!!! tip
This diagram shows `HTTPException`, but you could also raise any other exception for which you create a [Custom Exception Handler](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank}. And that exception would be handled by that custom exception handler instead of the dependency exit code.
This diagram shows `HTTPException`, but you could also raise any other exception for which you create a [Custom Exception Handler](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank}.

But if you raise an exception that is not handled by the exception handlers, it will be handled by the exit code of the dependency.
If you raise any exception, it will be passed to the dependencies with yield, including `HTTPException`, and then **again** to the exception handlers. If there's no exception handler for that exception, it will then be handled by the default internal `ServerErrorMiddleware`, returning a 500 HTTP status code, to let the client know that there was an error in the server.

## Context Managers

Expand Down
61 changes: 53 additions & 8 deletions fastapi/applications.py
Expand Up @@ -2,7 +2,6 @@
from typing import Any, Callable, Coroutine, Dict, List, Optional, Sequence, Type, Union

from fastapi import routing
from fastapi.concurrency import AsyncExitStack
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.encoders import DictIntStrAny, SetIntStr
from fastapi.exception_handlers import (
Expand All @@ -11,6 +10,7 @@
)
from fastapi.exceptions import RequestValidationError
from fastapi.logger import logger
from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware
from fastapi.openapi.docs import (
get_redoc_html,
get_swagger_ui_html,
Expand All @@ -21,8 +21,9 @@
from fastapi.types import DecoratedCallable
from starlette.applications import Starlette
from starlette.datastructures import State
from starlette.exceptions import HTTPException
from starlette.exceptions import ExceptionMiddleware, HTTPException
from starlette.middleware import Middleware
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse, Response
from starlette.routing import BaseRoute
Expand Down Expand Up @@ -134,6 +135,55 @@ def __init__(
self.openapi_schema: Optional[Dict[str, Any]] = None
self.setup()

def build_middleware_stack(self) -> ASGIApp:
# Duplicate/override from Starlette to add AsyncExitStackMiddleware
# inside of ExceptionMiddleware, inside of custom user middlewares
debug = self.debug
error_handler = None
exception_handlers = {}

for key, value in self.exception_handlers.items():
if key in (500, Exception):
error_handler = value
else:
exception_handlers[key] = value

middleware = (
[Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
+ self.user_middleware
+ [
Middleware(
ExceptionMiddleware, handlers=exception_handlers, debug=debug
),
# Add FastAPI-specific AsyncExitStackMiddleware for dependencies with
# contextvars.
# This needs to happen after user middlewares because those create a
# new contextvars context copy by using a new AnyIO task group.
# The initial part of dependencies with yield is executed in the
# FastAPI code, inside all the middlewares, but the teardown part
# (after yield) is executed in the AsyncExitStack in this middleware,
# if the AsyncExitStack lived outside of the custom middlewares and
# contextvars were set in a dependency with yield in that internal
# contextvars context, the values would not be available in the
# outside context of the AsyncExitStack.
# By putting the middleware and the AsyncExitStack here, inside all
# user middlewares, the code before and after yield in dependencies
# with yield is executed in the same contextvars context, so all values
# set in contextvars before yield is still available after yield as
# would be expected.
# Additionally, by having this AsyncExitStack here, after the
# ExceptionMiddleware, now dependencies can catch handled exceptions,
# e.g. HTTPException, to customize the teardown code (e.g. DB session
# rollback).
Middleware(AsyncExitStackMiddleware),
]
)

app = self.router
for cls, options in reversed(middleware):
app = cls(app=app, **options)
return app

def openapi(self) -> Dict[str, Any]:
if not self.openapi_schema:
self.openapi_schema = get_openapi(
Expand Down Expand Up @@ -206,12 +256,7 @@ async def redoc_html(req: Request) -> HTMLResponse:
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.root_path:
scope["root_path"] = self.root_path
if AsyncExitStack:
async with AsyncExitStack() as stack:
scope["fastapi_astack"] = stack
await super().__call__(scope, receive, send)
else:
await super().__call__(scope, receive, send) # pragma: no cover
await super().__call__(scope, receive, send)

def add_api_route(
self,
Expand Down
28 changes: 28 additions & 0 deletions fastapi/middleware/asyncexitstack.py
@@ -0,0 +1,28 @@
from typing import Optional

from fastapi.concurrency import AsyncExitStack
from starlette.types import ASGIApp, Receive, Scope, Send


class AsyncExitStackMiddleware:
def __init__(self, app: ASGIApp, context_name: str = "fastapi_astack") -> None:
self.app = app
self.context_name = context_name

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if AsyncExitStack:
dependency_exception: Optional[Exception] = None
async with AsyncExitStack() as stack:
scope[self.context_name] = stack
try:
await self.app(scope, receive, send)
except Exception as e:
dependency_exception = e
raise e
if dependency_exception:
# This exception was possibly handled by the dependency but it should
# still bubble up so that the ServerErrorMiddleware can return a 500
# or the ExceptionMiddleware can catch and handle any other exceptions
raise dependency_exception
else:
await self.app(scope, receive, send) # pragma: no cover
44 changes: 40 additions & 4 deletions tests/test_dependency_contextmanager.py
Expand Up @@ -235,7 +235,16 @@ def test_sync_raise_other():
assert "/sync_raise" not in errors


def test_async_raise():
def test_async_raise_raises():
with pytest.raises(AsyncDependencyError):
client.get("/async_raise")
assert state["/async_raise"] == "asyncgen raise finalized"
assert "/async_raise" in errors
errors.clear()


def test_async_raise_server_error():
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/async_raise")
assert response.status_code == 500, response.text
assert state["/async_raise"] == "asyncgen raise finalized"
Expand Down Expand Up @@ -270,7 +279,16 @@ def test_background_tasks():
assert state["bg"] == "bg set - b: started b - a: started a"


def test_sync_raise():
def test_sync_raise_raises():
with pytest.raises(SyncDependencyError):
client.get("/sync_raise")
assert state["/sync_raise"] == "generator raise finalized"
assert "/sync_raise" in errors
errors.clear()


def test_sync_raise_server_error():
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/sync_raise")
assert response.status_code == 500, response.text
assert state["/sync_raise"] == "generator raise finalized"
Expand Down Expand Up @@ -306,15 +324,33 @@ def test_sync_sync_raise_other():
assert "/sync_raise" not in errors


def test_sync_async_raise():
def test_sync_async_raise_raises():
with pytest.raises(AsyncDependencyError):
client.get("/sync_async_raise")
assert state["/async_raise"] == "asyncgen raise finalized"
assert "/async_raise" in errors
errors.clear()


def test_sync_async_raise_server_error():
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/sync_async_raise")
assert response.status_code == 500, response.text
assert state["/async_raise"] == "asyncgen raise finalized"
assert "/async_raise" in errors
errors.clear()


def test_sync_sync_raise():
def test_sync_sync_raise_raises():
with pytest.raises(SyncDependencyError):
client.get("/sync_sync_raise")
assert state["/sync_raise"] == "generator raise finalized"
assert "/sync_raise" in errors
errors.clear()


def test_sync_sync_raise_server_error():
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/sync_sync_raise")
assert response.status_code == 500, response.text
assert state["/sync_raise"] == "generator raise finalized"
Expand Down
51 changes: 51 additions & 0 deletions tests/test_dependency_contextvars.py
@@ -0,0 +1,51 @@
from contextvars import ContextVar
from typing import Any, Awaitable, Callable, Dict, Optional

from fastapi import Depends, FastAPI, Request, Response
from fastapi.testclient import TestClient

legacy_request_state_context_var: ContextVar[Optional[Dict[str, Any]]] = ContextVar(
"legacy_request_state_context_var", default=None
)

app = FastAPI()


async def set_up_request_state_dependency():
request_state = {"user": "deadpond"}
contextvar_token = legacy_request_state_context_var.set(request_state)
yield request_state
legacy_request_state_context_var.reset(contextvar_token)


@app.middleware("http")
async def custom_middleware(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
):
response = await call_next(request)
response.headers["custom"] = "foo"
return response


@app.get("/user", dependencies=[Depends(set_up_request_state_dependency)])
def get_user():
request_state = legacy_request_state_context_var.get()
assert request_state
return request_state["user"]


client = TestClient(app)


def test_dependency_contextvars():
"""
Check that custom middlewares don't affect the contextvar context for dependencies.
The code before yield and the code after yield should be run in the same contextvar
context, so that request_state_context_var.reset(contextvar_token).
If they are run in a different context, that raises an error.
"""
response = client.get("/user")
assert response.json() == "deadpond"
assert response.headers["custom"] == "foo"
71 changes: 71 additions & 0 deletions tests/test_dependency_normal_exceptions.py
@@ -0,0 +1,71 @@
import pytest
from fastapi import Body, Depends, FastAPI, HTTPException
from fastapi.testclient import TestClient

initial_fake_database = {"rick": "Rick Sanchez"}

fake_database = initial_fake_database.copy()

initial_state = {"except": False, "finally": False}

state = initial_state.copy()

app = FastAPI()


async def get_database():
temp_database = fake_database.copy()
try:
yield temp_database
fake_database.update(temp_database)
except HTTPException:
state["except"] = True
finally:
state["finally"] = True


@app.put("/invalid-user/{user_id}")
def put_invalid_user(
user_id: str, name: str = Body(...), db: dict = Depends(get_database)
):
db[user_id] = name
raise HTTPException(status_code=400, detail="Invalid user")


@app.put("/user/{user_id}")
def put_user(user_id: str, name: str = Body(...), db: dict = Depends(get_database)):
db[user_id] = name
return {"message": "OK"}


@pytest.fixture(autouse=True)
def reset_state_and_db():
global fake_database
global state
fake_database = initial_fake_database.copy()
state = initial_state.copy()


client = TestClient(app)


def test_dependency_gets_exception():
assert state["except"] is False
assert state["finally"] is False
response = client.put("/invalid-user/rick", json="Morty")
assert response.status_code == 400, response.text
assert response.json() == {"detail": "Invalid user"}
assert state["except"] is True
assert state["finally"] is True
assert fake_database["rick"] == "Rick Sanchez"


def test_dependency_no_exception():
assert state["except"] is False
assert state["finally"] is False
response = client.put("/user/rick", json="Morty")
assert response.status_code == 200, response.text
assert response.json() == {"message": "OK"}
assert state["except"] is False
assert state["finally"] is True
assert fake_database["rick"] == "Morty"

0 comments on commit c0a5e28

Please sign in to comment.