Skip to content

Commit

Permalink
Document interaction of BaseHTTPMiddleware and contextvars (#1525)
Browse files Browse the repository at this point in the history
* test: document behavior of ContextVars with BaseHTTPMiddleware

* lint & fix

* add pragma

* Update test_base.py

* Update tests/middleware/test_base.py

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>

* fix typo

* try to make comment clearer

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
  • Loading branch information
adriangb and Kludex committed Apr 24, 2022
1 parent ce0709d commit 5a9b414
Showing 1 changed file with 58 additions and 0 deletions.
58 changes: 58 additions & 0 deletions tests/middleware/test_base.py
@@ -1,10 +1,13 @@
import contextvars

import pytest

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse, StreamingResponse
from starlette.routing import Mount, Route, WebSocketRoute
from starlette.types import ASGIApp, Receive, Scope, Send


class CustomMiddleware(BaseHTTPMiddleware):
Expand Down Expand Up @@ -163,3 +166,58 @@ def test_exception_on_mounted_apps(test_client_factory):
with pytest.raises(Exception) as ctx:
client.get("/sub/")
assert str(ctx.value) == "Exc"


ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar")


class CustomMiddlewareWithoutBaseHTTPMiddleware:
def __init__(self, app: ASGIApp) -> None:
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
ctxvar.set("set by middleware")
await self.app(scope, receive, send)
assert ctxvar.get() == "set by endpoint"


class CustomMiddlewareUsingBaseHTTPMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
ctxvar.set("set by middleware")
resp = await call_next(request)
assert ctxvar.get() == "set by endpoint"
return resp # pragma: no cover


@pytest.mark.parametrize(
"middleware_cls",
[
CustomMiddlewareWithoutBaseHTTPMiddleware,
pytest.param(
CustomMiddlewareUsingBaseHTTPMiddleware,
marks=pytest.mark.xfail(
reason=(
"BaseHTTPMiddleware creates a TaskGroup which copies the context"
"and erases any changes to it made within the TaskGroup"
),
raises=AssertionError,
),
),
],
)
def test_contextvars(test_client_factory, middleware_cls: type):
# this has to be an async endpoint because Starlette calls run_in_threadpool
# on sync endpoints which has it's own set of peculiarities w.r.t propagating
# contextvars (it propagates them forwards but not backwards)
async def homepage(request):
assert ctxvar.get() == "set by middleware"
ctxvar.set("set by endpoint")
return PlainTextResponse("Homepage")

app = Starlette(
middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)]
)

client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 200, response.content

0 comments on commit 5a9b414

Please sign in to comment.