Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shield send "http.response.start" from cancellation (BaseHTTPMiddleware) #1710

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 12 additions & 2 deletions starlette/middleware/base.py
Expand Up @@ -4,7 +4,7 @@

from starlette.requests import Request
from starlette.responses import Response, StreamingResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.types import ASGIApp, Message, Receive, Scope, Send

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[
Expand All @@ -28,12 +28,22 @@ async def call_next(request: Request) -> Response:
app_exc: typing.Optional[Exception] = None
send_stream, recv_stream = anyio.create_memory_object_stream()

async def send(msg: Message) -> None:
# Shield send "http.response.start" from cancellation.
# Otherwise, `await recv_stream.receive()` will raise `anyio.EndOfStream` if request is disconnected,
# due to `task_group.cancel_scope.cancel()` in `StreamingResponse.__call__.<locals>.wrap`
# and cancellation check in `await checkpoint()` of `MemoryObjectSendStream.send`,
# and then `RuntimeError: No response returned.` will be raised below.
acjh marked this conversation as resolved.
Show resolved Hide resolved
shield = msg["type"] == "http.response.start"
with anyio.CancelScope(shield=shield):
await send_stream.send(msg)

async def coro() -> None:
nonlocal app_exc

async with send_stream:
try:
await self.app(scope, request.receive, send_stream.send)
await self.app(scope, request.receive, send)
except Exception as exc:
app_exc = exc

Expand Down
45 changes: 43 additions & 2 deletions tests/middleware/test_base.py
@@ -1,13 +1,16 @@
import contextvars
from contextlib import AsyncExitStack
from typing import AsyncGenerator, Awaitable, Callable

import pytest

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse, StreamingResponse
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response, StreamingResponse
from starlette.routing import Route, WebSocketRoute
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.types import ASGIApp, Message, Receive, Scope, Send


class CustomMiddleware(BaseHTTPMiddleware):
Expand Down Expand Up @@ -206,3 +209,41 @@ async def homepage(request):
client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 200, response.content


@pytest.mark.anyio
async def test_client_disconnects_before_response_is_sent() -> None:
app: ASGIApp

async def homepage(request: Request):
# await anyio.sleep(5)
return PlainTextResponse("hi!")

async def dispatch(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
return await call_next(request)

app = BaseHTTPMiddleware(Route("/", homepage), dispatch=dispatch)
app = BaseHTTPMiddleware(app, dispatch=dispatch)

async def recv_gen() -> AsyncGenerator[Message, None]:
yield {"type": "http.request"}
yield {"type": "http.disconnect"}
yield {"type": "http.disconnect"}

async def send_gen() -> AsyncGenerator[None, Message]:
msg = yield
assert msg["type"] == "http.response.start"
msg = yield
raise AssertionError("Should not be called")

scope = {"type": "http", "method": "GET", "path": "/"}

async with AsyncExitStack() as stack:
recv = recv_gen()
stack.push_async_callback(recv.aclose)
send = send_gen()
stack.push_async_callback(send.aclose)
await send.__anext__()
await app(scope, recv.__aiter__().__anext__, send.asend)