diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 49a5e3e2d..721a2eb58 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -4,7 +4,7 @@ from starlette.requests import Request from starlette.responses import Response, StreamingResponse -from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.types import ASGIApp, Message, Receive, Scope, Send RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] DispatchFunction = typing.Callable[ @@ -28,12 +28,26 @@ async def call_next(request: Request) -> Response: app_exc: typing.Optional[Exception] = None send_stream, recv_stream = anyio.create_memory_object_stream() + async def send(msg: Message) -> None: + # Shield send "http.response.start" from cancellation. + # Otherwise, `await recv_stream.receive()` will raise + # `anyio.EndOfStream` if the connection is disconnected, + # due to `task_group.cancel_scope.cancel()` in + # `StreamingResponse.__call__..wrap` + # and cancellation check during `await checkpoint()` in + # `MemoryObjectSendStream.send`. + # This would trigger the check we have in this middleware resulting in + # `RuntimeError: No response returned.` being raised below. + shield = msg["type"] == "http.response.start" + with anyio.CancelScope(shield=shield): + await send_stream.send(msg) + async def coro() -> None: nonlocal app_exc async with send_stream: try: - await self.app(scope, request.receive, send_stream.send) + await self.app(scope, request.receive, send) except Exception as exc: app_exc = exc diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 976d77b86..80087686d 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1,13 +1,16 @@ import contextvars +from contextlib import AsyncExitStack +from typing import AsyncGenerator, Awaitable, Callable import pytest from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware -from starlette.responses import PlainTextResponse, StreamingResponse +from starlette.requests import Request +from starlette.responses import PlainTextResponse, Response, StreamingResponse from starlette.routing import Route, WebSocketRoute -from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.types import ASGIApp, Message, Receive, Scope, Send class CustomMiddleware(BaseHTTPMiddleware): @@ -206,3 +209,41 @@ async def homepage(request): client = test_client_factory(app) response = client.get("/") assert response.status_code == 200, response.content + + +@pytest.mark.anyio +async def test_client_disconnects_before_response_is_sent() -> None: + app: ASGIApp + + async def homepage(request: Request): + # await anyio.sleep(5) + return PlainTextResponse("hi!") + + async def dispatch( + request: Request, call_next: Callable[[Request], Awaitable[Response]] + ) -> Response: + return await call_next(request) + + app = BaseHTTPMiddleware(Route("/", homepage), dispatch=dispatch) + app = BaseHTTPMiddleware(app, dispatch=dispatch) + + async def recv_gen() -> AsyncGenerator[Message, None]: + yield {"type": "http.request"} + yield {"type": "http.disconnect"} + yield {"type": "http.disconnect"} + + async def send_gen() -> AsyncGenerator[None, Message]: + msg = yield + assert msg["type"] == "http.response.start" + msg = yield + raise AssertionError("Should not be called") # pragma: no cover + + scope = {"type": "http", "method": "GET", "path": "/"} + + async with AsyncExitStack() as stack: + recv = recv_gen() + stack.push_async_callback(recv.aclose) + send = send_gen() + stack.push_async_callback(send.aclose) + await send.__anext__() + await app(scope, recv.__aiter__().__anext__, send.asend)