Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Document interaction of BaseHTTPMiddleware and contextvars #1525

Merged
Merged
58 changes: 58 additions & 0 deletions tests/middleware/test_base.py
@@ -1,10 +1,13 @@
import contextvars

import pytest

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse, StreamingResponse
from starlette.routing import Mount, Route, WebSocketRoute
from starlette.types import ASGIApp, Receive, Scope, Send


class CustomMiddleware(BaseHTTPMiddleware):
Expand Down Expand Up @@ -163,3 +166,58 @@ def test_exception_on_mounted_apps(test_client_factory):
with pytest.raises(Exception) as ctx:
client.get("/sub/")
assert str(ctx.value) == "Exc"


ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar")


class CustomMiddlewareWithoutBaseHTTPMiddleware:
def __init__(self, app: ASGIApp) -> None:
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
ctxvar.set("set by middleware")
await self.app(scope, receive, send)
assert ctxvar.get() == "set by endpoint"


class CustomMiddlewareUsingBaseHTTPMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
ctxvar.set("set by middleware")
resp = await call_next(request)
assert ctxvar.get() == "set by endpoint"
return resp # pragma: no cover


@pytest.mark.parametrize(
"middleware_cls",
[
CustomMiddlewareWithoutBaseHTTPMiddleware,
pytest.param(
CustomMiddlewareUsingBaseHTTPMiddleware,
marks=pytest.mark.xfail(
reason=(
"BaseHTTPMiddleware creates a TaskGroup which copies the context"
"and erases any changes to it made within the TaskGroup"
),
raises=AssertionError,
),
),
],
)
def test_contextvars(test_client_factory, middleware_cls: type):
# this has to be an async endpoint because Starlette calls run_in_thredpool
adriangb marked this conversation as resolved.
Show resolved Hide resolved
# on sync endpoints which suffers from the same problem of erasing changes
# to the context
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What "same problem"? You mean the same as caused by the TaskGroup?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Although I think this has been fixed in the meantime. Let me re-confirm.

Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep I think that should fix the issue. It doesn't change the rest of this PR, but we can just remove my comment.

Copy link
Member Author

@adriangb adriangb Apr 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah okay, no the issue is different: to_thread.run_sync() now propagates the context _ forwards_ but not backwards:

from contextvars import ContextVar
import anyio

ctx = ContextVar[str]("ctx")

async def async_func() -> None:
    assert ctx.set("bar")

def sync_func() -> None:
    assert ctx.set("foo")

async def main() -> None:
    await async_func()
    assert ctx.get() == "bar"
    await anyio.to_thread.run_sync(sync_func)
    assert ctx.get() == "foo"  # fails

anyio.run(main)

So there is still a subtle difference between a async and async endpoint. Side note: I think at some point before a 1.0 support for sync endpoints should be dropped, I think @tomchristie suggested this as well in the past via Gitter (Tom if you recall any of this conversation, or don't think it happened, please correct me if I am wrong).

I'm also happy to add a similar test in another PR documenting that edge case.

async def homepage(request):
assert ctxvar.get() == "set by middleware"
ctxvar.set("set by endpoint")
return PlainTextResponse("Homepage")

app = Starlette(
middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)]
)

client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 200, response.content