From c98203edcb4e210f630f5a39e2ed453561b537a0 Mon Sep 17 00:00:00 2001 From: Jean Hominal Date: Wed, 29 Jun 2022 23:59:41 +0200 Subject: [PATCH] prototype design where BaseHTTPMiddleware works without cancelling downstream app --- starlette/middleware/base.py | 53 ++++++++++++++++++++++++++++++++++-- 1 file changed, 50 insertions(+), 3 deletions(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 49a5e3e2d7..39828198b7 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -4,13 +4,23 @@ from starlette.requests import Request from starlette.responses import Response, StreamingResponse -from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.types import ASGIApp, Message, Receive, Scope, Send RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] DispatchFunction = typing.Callable[ [Request, RequestResponseEndpoint], typing.Awaitable[Response] ] +T = typing.TypeVar("T") + + +async def _call_and_cancel( + func: typing.Callable[[], typing.Awaitable[T]], cancel_scope: anyio.CancelScope +) -> T: + result = await func() + cancel_scope.cancel() + return result + class BaseHTTPMiddleware: def __init__( @@ -24,19 +34,56 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) return + app_disconnected = anyio.Event() + async def call_next(request: Request) -> Response: app_exc: typing.Optional[Exception] = None send_stream, recv_stream = anyio.create_memory_object_stream() + async def receive_or_disconnect() -> Message: + if app_disconnected.is_set(): + return {"type": "http.disconnect"} + + async with anyio.create_task_group() as task_group: + # app_disconnected is set, cancel parent scope to cancel + # request.receive + task_group.start_soon( + _call_and_cancel, + app_disconnected.wait, + task_group.cancel_scope, + ) + # if request.receive returns a message, cancel the task_group to + # exit its block + message = await _call_and_cancel( + request.receive, task_group.cancel_scope + ) + + if app_disconnected.is_set(): + return {"type": "http.disconnect"} + + return message + + async def close_recv_stream_on_disconnect() -> None: + await app_disconnected.wait() + recv_stream.close() + + async def send_no_error(message: Message) -> None: + try: + await send_stream.send(message) + except anyio.BrokenResourceError: + # recv_stream has been closed, i.e. app_disconnected has been set. + return + async def coro() -> None: nonlocal app_exc async with send_stream: try: - await self.app(scope, request.receive, send_stream.send) + await self.app(scope, receive_or_disconnect, send_no_error) except Exception as exc: app_exc = exc + task_group.start_soon(close_recv_stream_on_disconnect) task_group.start_soon(coro) try: @@ -71,7 +118,7 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: request = Request(scope, receive=receive) response = await self.dispatch_func(request, call_next) await response(scope, receive, send) - task_group.cancel_scope.cancel() + app_disconnected.set() async def dispatch( self, request: Request, call_next: RequestResponseEndpoint