Skip to content

Commit

Permalink
Fix issue with running middleware context manager
Browse files Browse the repository at this point in the history
Reported in #1678 (comment)
  • Loading branch information
jhominal committed Sep 3, 2022
1 parent 9161fb8 commit 274bdeb
Showing 1 changed file with 64 additions and 0 deletions.
64 changes: 64 additions & 0 deletions tests/middleware/test_base.py
@@ -1,4 +1,5 @@
import contextvars
from contextlib import AsyncExitStack

import anyio
import pytest
Expand Down Expand Up @@ -261,6 +262,69 @@ async def send(message):
assert background_task_run.is_set()


@pytest.mark.anyio
async def test_run_context_manager_exit_even_if_client_disconnects():
# test for https://github.com/encode/starlette/issues/1678#issuecomment-1172916042
request_body_sent = False
response_complete = anyio.Event()
context_manager_exited = anyio.Event()

async def sleep_and_set():
# small delay to give BaseHTTPMiddleware a chance to cancel us
# this is required to make the test fail prior to fixing the issue
# so do not be surprised if you remove it and the test still passes
await anyio.sleep(0.1)
context_manager_exited.set()

class ContextManagerMiddleware:
def __init__(self, app):
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send):
async with AsyncExitStack() as stack:
stack.push_async_callback(sleep_and_set)
await self.app(scope, receive, send)

async def simple_endpoint(_):
return PlainTextResponse(background=BackgroundTask(sleep_and_set))

async def passthrough(request, call_next):
return await call_next(request)

app = Starlette(
middleware=[
Middleware(BaseHTTPMiddleware, dispatch=passthrough),
Middleware(ContextManagerMiddleware),
],
routes=[Route("/", simple_endpoint)],
)

scope = {
"type": "http",
"version": "3",
"method": "GET",
"path": "/",
}

async def receive():
nonlocal request_body_sent
if not request_body_sent:
request_body_sent = True
return {"type": "http.request", "body": b"", "more_body": False}
# We simulate a client that disconnects immediately after receiving the response
await response_complete.wait()
return {"type": "http.disconnect"}

async def send(message):
if message["type"] == "http.response.body":
if not message.get("more_body", False):
response_complete.set()

await app(scope, receive, send)

assert context_manager_exited.is_set()


def test_app_receives_http_disconnect_while_sending_if_discarded(test_client_factory):
class DiscardingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
Expand Down

0 comments on commit 274bdeb

Please sign in to comment.