diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 8c110ca5d..ed0734bd3 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1,4 +1,5 @@ import contextvars +from contextlib import AsyncExitStack import anyio import pytest @@ -261,6 +262,69 @@ async def send(message): assert background_task_run.is_set() +@pytest.mark.anyio +async def test_run_context_manager_exit_even_if_client_disconnects(): + # test for https://github.com/encode/starlette/issues/1678#issuecomment-1172916042 + request_body_sent = False + response_complete = anyio.Event() + context_manager_exited = anyio.Event() + + async def sleep_and_set(): + # small delay to give BaseHTTPMiddleware a chance to cancel us + # this is required to make the test fail prior to fixing the issue + # so do not be surprised if you remove it and the test still passes + await anyio.sleep(0.1) + context_manager_exited.set() + + class ContextManagerMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send): + async with AsyncExitStack() as stack: + stack.push_async_callback(sleep_and_set) + await self.app(scope, receive, send) + + async def simple_endpoint(_): + return PlainTextResponse(background=BackgroundTask(sleep_and_set)) + + async def passthrough(request, call_next): + return await call_next(request) + + app = Starlette( + middleware=[ + Middleware(BaseHTTPMiddleware, dispatch=passthrough), + Middleware(ContextManagerMiddleware), + ], + routes=[Route("/", simple_endpoint)], + ) + + scope = { + "type": "http", + "version": "3", + "method": "GET", + "path": "/", + } + + async def receive(): + nonlocal request_body_sent + if not request_body_sent: + request_body_sent = True + return {"type": "http.request", "body": b"", "more_body": False} + # We simulate a client that disconnects immediately after receiving the response + await response_complete.wait() + return {"type": "http.disconnect"} + + async def send(message): + if message["type"] == "http.response.body": + if not message.get("more_body", False): + response_complete.set() + + await app(scope, receive, send) + + assert context_manager_exited.is_set() + + def test_app_receives_http_disconnect_while_sending_if_discarded(test_client_factory): class DiscardingMiddleware(BaseHTTPMiddleware): async def dispatch(self, request, call_next):