From 5a9b41475ae1f54942ee67f90154f5da8f36e117 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sat, 23 Apr 2022 23:48:50 -0500 Subject: [PATCH] Document interaction of BaseHTTPMiddleware and contextvars (#1525) * test: document behavior of ContextVars with BaseHTTPMiddleware * lint & fix * add pragma * Update test_base.py * Update tests/middleware/test_base.py Co-authored-by: Marcelo Trylesinski * fix typo * try to make comment clearer Co-authored-by: Marcelo Trylesinski --- tests/middleware/test_base.py | 58 +++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 04da3a961..0d023ddd1 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1,3 +1,5 @@ +import contextvars + import pytest from starlette.applications import Starlette @@ -5,6 +7,7 @@ from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import PlainTextResponse, StreamingResponse from starlette.routing import Mount, Route, WebSocketRoute +from starlette.types import ASGIApp, Receive, Scope, Send class CustomMiddleware(BaseHTTPMiddleware): @@ -163,3 +166,58 @@ def test_exception_on_mounted_apps(test_client_factory): with pytest.raises(Exception) as ctx: client.get("/sub/") assert str(ctx.value) == "Exc" + + +ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar") + + +class CustomMiddlewareWithoutBaseHTTPMiddleware: + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + ctxvar.set("set by middleware") + await self.app(scope, receive, send) + assert ctxvar.get() == "set by endpoint" + + +class CustomMiddlewareUsingBaseHTTPMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + ctxvar.set("set by middleware") + resp = await call_next(request) + assert ctxvar.get() == "set by endpoint" + return resp # pragma: no cover + + +@pytest.mark.parametrize( + "middleware_cls", + [ + CustomMiddlewareWithoutBaseHTTPMiddleware, + pytest.param( + CustomMiddlewareUsingBaseHTTPMiddleware, + marks=pytest.mark.xfail( + reason=( + "BaseHTTPMiddleware creates a TaskGroup which copies the context" + "and erases any changes to it made within the TaskGroup" + ), + raises=AssertionError, + ), + ), + ], +) +def test_contextvars(test_client_factory, middleware_cls: type): + # this has to be an async endpoint because Starlette calls run_in_threadpool + # on sync endpoints which has it's own set of peculiarities w.r.t propagating + # contextvars (it propagates them forwards but not backwards) + async def homepage(request): + assert ctxvar.get() == "set by middleware" + ctxvar.set("set by endpoint") + return PlainTextResponse("Homepage") + + app = Starlette( + middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)] + ) + + client = test_client_factory(app) + response = client.get("/") + assert response.status_code == 200, response.content