
Shield send "http.response.start" from cancellation #1706

Closed

Conversation

@acjh commented Jun 25, 2022

Fixes #1634

`RuntimeError: No response returned.` is raised in `BaseHTTPMiddleware` when the request is disconnected, due to `task_group.cancel_scope.cancel()` in `StreamingResponse.__call__.<locals>.wrap` and the cancellation check in `await checkpoint()` of `MemoryObjectSendStream.send`.

Let's fix this behaviour change caused by the anyio integration in 0.15.0.
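For anyone unfamiliar with the mechanism, here is a minimal, standalone sketch (plain anyio, not Starlette code; the names and message dicts are illustrative) of what goes wrong: once the surrounding cancel scope is cancelled, a pending `MemoryObjectSendStream.send` is interrupted, so the middleware never receives `"http.response.start"`:

import anyio


async def main() -> None:
    send_stream, recv_stream = anyio.create_memory_object_stream(0)

    async def producer() -> None:
        try:
            # With no receiver attached, this send waits at a checkpoint...
            await send_stream.send({"type": "http.response.start", "status": 200})
        except anyio.get_cancelled_exc_class():
            print("send was cancelled; no response message was delivered")
            raise

    async with anyio.create_task_group() as task_group:
        task_group.start_soon(producer)
        await anyio.sleep(0.1)  # emulate the client disconnecting, after which...
        task_group.cancel_scope.cancel()  # ...StreamingResponse cancels the scope


anyio.run(main)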


I managed to make this error reproducible in 0.14.2 by partially emulating 0.15.0 logic: acjh@37dd8ac

starlette/concurrency.py:

  async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None:
+     async def run(func: typing.Callable[[], typing.Coroutine]) -> None:
+         await func()
+         # (starlette 0.15.0) starlette.concurrency.run_until_first_complete `task_group.cancel_scope.cancel()`
+         for task in tasks:
+             if not task.done() and task != asyncio.current_task():
+                 task.cancel()
+
-     tasks = [create_task(handler(**kwargs)) for handler, kwargs in args]
+     tasks = [create_task(run(functools.partial(handler, **kwargs))) for handler, kwargs in args]
      (done, pending) = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
      [task.cancel() for task in pending]
-     [task.result() for task in done]
+     for task in done:
+         try:
+             task.result()
+         except asyncio.CancelledError:
+             pass

starlette/middleware/base.py:

  class BaseHTTPMiddleware:
      ...

      async def call_next(self, request: Request) -> Response:
          ...
-         send = queue.put
+
+         async def send(item: typing.Any) -> None:
+             await asyncio.sleep(0)  # anyio.streams.memory.MemoryObjectSendStream.send `await checkpoint()`
+             await queue.put(item)

          ...

@adriangb (Member) commented Jun 25, 2022

I think this can be fixed without shielding. This test fails on master but passes with this patch. Can you also try this test on your branch to see if it's a good test / if your branch fixes it?

diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py
index 49a5e3e..5210d2d 100644
--- a/starlette/middleware/base.py
+++ b/starlette/middleware/base.py
@@ -4,7 +4,7 @@ import anyio
 
 from starlette.requests import Request
 from starlette.responses import Response, StreamingResponse
-from starlette.types import ASGIApp, Receive, Scope, Send
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
 
 RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
 DispatchFunction = typing.Callable[
@@ -12,6 +12,10 @@ DispatchFunction = typing.Callable[
 ]
 
 
+class _ClientDisconnected(Exception):
+    pass
+
+
 class BaseHTTPMiddleware:
     def __init__(
         self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
@@ -28,12 +32,18 @@ class BaseHTTPMiddleware:
             app_exc: typing.Optional[Exception] = None
             send_stream, recv_stream = anyio.create_memory_object_stream()
 
+            async def recv() -> Message:
+                message = await request.receive()
+                if message["type"] == "http.disconnect":
+                    raise _ClientDisconnected
+                return message
+
             async def coro() -> None:
                 nonlocal app_exc
 
                 async with send_stream:
                     try:
-                        await self.app(scope, request.receive, send_stream.send)
+                        await self.app(scope, recv, send_stream.send)
                     except Exception as exc:
                         app_exc = exc
 
@@ -69,7 +79,10 @@ class BaseHTTPMiddleware:
 
         async with anyio.create_task_group() as task_group:
             request = Request(scope, receive=receive)
-            response = await self.dispatch_func(request, call_next)
+            try:
+                response = await self.dispatch_func(request, call_next)
+            except _ClientDisconnected:
+                return
             await response(scope, receive, send)
             task_group.cancel_scope.cancel()
 
diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py
index 976d77b..92826bc 100644
--- a/tests/middleware/test_base.py
+++ b/tests/middleware/test_base.py
@@ -1,13 +1,17 @@
 import contextvars
+from contextlib import AsyncExitStack
+from typing import AsyncGenerator, Awaitable, Callable
 
+import anyio
 import pytest
 
 from starlette.applications import Starlette
 from starlette.middleware import Middleware
 from starlette.middleware.base import BaseHTTPMiddleware
-from starlette.responses import PlainTextResponse, StreamingResponse
+from starlette.requests import Request
+from starlette.responses import PlainTextResponse, Response, StreamingResponse
 from starlette.routing import Route, WebSocketRoute
-from starlette.types import ASGIApp, Receive, Scope, Send
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
 
 
 class CustomMiddleware(BaseHTTPMiddleware):
@@ -206,3 +210,41 @@ def test_contextvars(test_client_factory, middleware_cls: type):
     client = test_client_factory(app)
     response = client.get("/")
     assert response.status_code == 200, response.content
+
+
+@pytest.mark.anyio
+async def test_client_disconnects_before_response_is_sent() -> None:
+    # test for https://github.com/encode/starlette/issues/1527
+    app: ASGIApp
+
+    async def homepage(request: Request):
+        await anyio.sleep(5)
+        return PlainTextResponse("hi!")
+
+    async def dispatch(
+        request: Request, call_next: Callable[[Request], Awaitable[Response]]
+    ) -> Response:
+        return await call_next(request)
+
+    app = BaseHTTPMiddleware(Route("/", homepage), dispatch=dispatch)
+    app = BaseHTTPMiddleware(app, dispatch=dispatch)
+
+    async def recv_gen() -> AsyncGenerator[Message, None]:
+        yield {"type": "http.request"}
+        yield {"type": "http.disconnect"}
+
+    async def send_gen() -> AsyncGenerator[None, Message]:
+        msg = yield
+        assert msg["type"] == "http.response.start"
+        msg = yield
+        raise AssertionError("Should not be called")
+
+    scope = {"type": "http", "method": "GET", "path": "/"}
+
+    async with AsyncExitStack() as stack:
+        recv = recv_gen()
+        stack.push_async_callback(recv.aclose)
+        send = send_gen()
+        stack.push_async_callback(send.aclose)
+        await send.__anext__()
+        await app(scope, recv.__aiter__().__anext__, send.asend)

@acjh (Author) commented Jun 26, 2022

My fix addresses the behaviour change in `StreamingResponse` caused by the anyio integration.
Your fix pre-empts the behaviour of `await recv_stream.receive()` on client disconnection in `BaseHTTPMiddleware` itself.

That behaviour of `StreamingResponse` is not publicly documented; your fix is sufficient for issues arising from `BaseHTTPMiddleware` usage.

That test passes on my branch if `recv_gen()` yields `"http.disconnect"` twice; otherwise it raises `StopAsyncIteration` in `listen_for_disconnect()` for the second middleware. The uvicorn ASGI server will keep yielding `"http.disconnect"`.

 async def recv_gen() -> AsyncGenerator[Message, None]:
     yield {"type": "http.request"}
     yield {"type": "http.disconnect"}
+    yield {"type": "http.disconnect"}

@adriangb (Member)

Apologies. I've been looking at BaseHTTPMiddleware a lot lately and got hung up on that 😅.

I asked in asgiref for confirmation on the expected behavior of ASGI servers w.r.t. sending the disconnect message multiple times. I think it would be a good idea to adapt that test (or just write a new one, up to you) to the specific situation this is supposed to fix. I think a test will be required before merging this.

@acjh (Author) commented Jun 27, 2022

I've added a test for the specific situation this is supposed to fix.

Actually, I don't think `StreamingResponse` should do this, since the cancellation behaviour is an intended feature of `MemoryObjectSendStream`.
Code that uses `StreamingResponse` with a `MemoryObjectSendStream` can wrap the stream's `send` in a shielded wrapper if desired:

async def send(msg):
    with anyio.CancelScope(shield=True):
        await send_stream.send(msg)
- await self.app(scope, request.receive, send_stream.send)
+ await self.app(scope, request.receive, send)

I think it is probably preferable to do this in `BaseHTTPMiddleware` rather than in `StreamingResponse`.
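To illustrate what `shield=True` buys us, here is a minimal, standalone sketch (plain anyio, not Starlette code): the shielded send still completes even though the enclosing cancel scope has already been cancelled:

import anyio


async def main() -> None:
    send_stream, recv_stream = anyio.create_memory_object_stream(1)

    async def shielded_send(msg: dict) -> None:
        with anyio.CancelScope(shield=True):
            await send_stream.send(msg)

    async with anyio.create_task_group() as task_group:
        task_group.cancel_scope.cancel()  # emulate the disconnect-triggered cancellation
        await shielded_send({"type": "http.response.start", "status": 200})

    print(recv_stream.receive_nowait())  # the message was still delivered


anyio.run(main)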

I am also happy to close this PR in favour of the fix that you proposed in BaseHTTPMiddleware.

@adriangb (Member)

If that's all that's required in `BaseHTTPMiddleware` to fix this, I like that a lot: 1 LOC 🥳.

Also, the barrier for doing something like this in `BaseHTTPMiddleware` is lower: the fix is close to the source of the issue, and `BaseHTTPMiddleware` is already dealing with streams, tasks, cancellation and such, so adding some shielding isn't moving the needle too much on complexity.

@acjh (Author) commented Jun 27, 2022

Well, it's not actually 1 LOC 😅

I have submitted PR #1710 to shield send "http.response.start" from cancellation in BaseHTTPMiddleware.

@Kludex (Sponsor Member) closed this Oct 1, 2022
Successfully merging this pull request may close these issues:

Seemingly random error `RuntimeError: No response returned.` (#1634)
3 participants