From 89fab50623a3ca7db81ff8ddd0d482f88e2024df Mon Sep 17 00:00:00 2001 From: Weiliang Li Date: Mon, 20 Jun 2022 18:01:09 +0900 Subject: [PATCH 1/7] Remove unexpected task cancellation --- starlette/middleware/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 49a5e3e2d..6c64421ec 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -71,7 +71,6 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: request = Request(scope, receive=receive) response = await self.dispatch_func(request, call_next) await response(scope, receive, send) - task_group.cancel_scope.cancel() async def dispatch( self, request: Request, call_next: RequestResponseEndpoint From d7762c2303d861c2aac84beb789fba90573f3329 Mon Sep 17 00:00:00 2001 From: Weiliang Li Date: Mon, 20 Jun 2022 19:54:43 +0900 Subject: [PATCH 2/7] Update base.py --- starlette/middleware/base.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 6c64421ec..0356a7290 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -33,12 +33,12 @@ async def coro() -> None: async with send_stream: try: + # may block in send await self.app(scope, request.receive, send_stream.send) except Exception as exc: app_exc = exc - task_group.start_soon(coro) - + task_group.start_soon(coro, name="__call_next") try: message = await recv_stream.receive() except anyio.EndOfStream: @@ -72,6 +72,12 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: response = await self.dispatch_func(request, call_next) await response(scope, receive, send) + t = anyio.get_current_task() + if t.name == "anyio.from_thread.BlockingPortal._call_func": + # cancel stuck task due to discarded response + # see: https://github.com/encode/starlette/issues/1022 + task_group.cancel_scope.cancel() + async def dispatch( self, request: Request, call_next: RequestResponseEndpoint ) -> Response: From 584d5819831c89621fe6ba9a898828a554a3db01 Mon Sep 17 00:00:00 2001 From: Weiliang Li Date: Tue, 21 Jun 2022 15:46:15 +0900 Subject: [PATCH 3/7] Shield background task --- starlette/background.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/starlette/background.py b/starlette/background.py index 4aaf7ae3c..db9b38af8 100644 --- a/starlette/background.py +++ b/starlette/background.py @@ -1,6 +1,8 @@ import sys import typing +import anyio + if sys.version_info >= (3, 10): # pragma: no cover from typing import ParamSpec else: # pragma: no cover @@ -22,10 +24,11 @@ def __init__( self.is_async = is_async_callable(func) async def __call__(self) -> None: - if self.is_async: - await self.func(*self.args, **self.kwargs) - else: - await run_in_threadpool(self.func, *self.args, **self.kwargs) + with anyio.CancelScope(shield=True): + if self.is_async: + await self.func(*self.args, **self.kwargs) + else: + await run_in_threadpool(self.func, *self.args, **self.kwargs) class BackgroundTasks(BackgroundTask): From 84a91b856c3eb486a1fae0a2380372804cb1c664 Mon Sep 17 00:00:00 2001 From: Weiliang Li Date: Tue, 21 Jun 2022 15:48:08 +0900 Subject: [PATCH 4/7] Update base.py --- starlette/middleware/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 0356a7290..b77bb41f9 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -33,12 +33,13 @@ async def coro() -> None: async with send_stream: try: - # may block in send + # may block in send if body_stream not consumed await self.app(scope, request.receive, send_stream.send) except Exception as exc: app_exc = exc - task_group.start_soon(coro, name="__call_next") + task_group.start_soon(coro) + try: message = await recv_stream.receive() except anyio.EndOfStream: From 0a4938238de0a48c24107622c5f2b42ae36404dc Mon Sep 17 00:00:00 2001 From: Weiliang Li Date: Tue, 21 Jun 2022 15:50:58 +0900 Subject: [PATCH 5/7] Add background task test --- tests/middleware/test_base.py | 47 ++++++++++++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 3 deletions(-) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 976d77b86..1c1de93c4 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1,11 +1,14 @@ import contextvars +import traceback +import anyio import pytest from starlette.applications import Starlette +from starlette.background import BackgroundTasks from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware -from starlette.responses import PlainTextResponse, StreamingResponse +from starlette.responses import PlainTextResponse, Response, StreamingResponse from starlette.routing import Route, WebSocketRoute from starlette.types import ASGIApp, Receive, Scope, Send @@ -84,6 +87,31 @@ def test_custom_middleware(test_client_factory): assert text == "Hello, world!" +def test_background_tasks(test_client_factory): + async def _sleep(identifier, delay): + print(identifier, "started") + try: + await anyio.sleep(delay) + print(identifier, "completed") + except BaseException: + print(identifier, "error") + traceback.print_exc() + raise + + async def bg_task(request): + background_tasks = BackgroundTasks() + background_tasks.add_task(_sleep, "background task 1", 2) + background_tasks.add_task(_sleep, "background task 2", 2) + return Response(background=background_tasks) + + app = Starlette( + routes=[Route("/bg-task", bg_task)], middleware=[Middleware(CustomMiddleware)] + ) + client = test_client_factory(app) + response = client.get("/bg-task") + assert response.text == "" + + def test_state_data_across_multiple_middlewares(test_client_factory): expected_value1 = "foo" expected_value2 = "bar" @@ -141,12 +169,25 @@ def homepage(request): def test_fully_evaluated_response(test_client_factory): # Test for https://github.com/encode/starlette/issues/1022 - class CustomMiddleware(BaseHTTPMiddleware): + class ConsumeMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + resp = await call_next(request) + + async def _send(m): + pass + + await resp.stream_response(_send) # type: ignore + + return PlainTextResponse("Custom") + + class DiscardMiddleware(BaseHTTPMiddleware): async def dispatch(self, request, call_next): await call_next(request) return PlainTextResponse("Custom") - app = Starlette(middleware=[Middleware(CustomMiddleware)]) + app = Starlette( + middleware=[Middleware(ConsumeMiddleware), Middleware(DiscardMiddleware)] + ) client = test_client_factory(app) response = client.get("/does_not_exist") From 37852175d9016da37ccaa10eeafef7095dc28270 Mon Sep 17 00:00:00 2001 From: Weiliang Li Date: Tue, 21 Jun 2022 16:03:15 +0900 Subject: [PATCH 6/7] Update test_base.py Fix coverage --- tests/middleware/test_base.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 1c1de93c4..98ce5e8bf 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1,5 +1,4 @@ import contextvars -import traceback import anyio import pytest @@ -90,13 +89,8 @@ def test_custom_middleware(test_client_factory): def test_background_tasks(test_client_factory): async def _sleep(identifier, delay): print(identifier, "started") - try: - await anyio.sleep(delay) - print(identifier, "completed") - except BaseException: - print(identifier, "error") - traceback.print_exc() - raise + await anyio.sleep(delay) + print(identifier, "completed") async def bg_task(request): background_tasks = BackgroundTasks() From 02aa6eb9588f64c0be331ea118598fdc7d13b6d5 Mon Sep 17 00:00:00 2001 From: Weiliang Li Date: Wed, 22 Jun 2022 13:40:16 +0900 Subject: [PATCH 7/7] Shield bg task in Response --- starlette/background.py | 11 ++++------- starlette/middleware/base.py | 9 ++------- starlette/responses.py | 16 ++++++++++------ tests/middleware/test_base.py | 8 ++++++++ 4 files changed, 24 insertions(+), 20 deletions(-) diff --git a/starlette/background.py b/starlette/background.py index db9b38af8..4aaf7ae3c 100644 --- a/starlette/background.py +++ b/starlette/background.py @@ -1,8 +1,6 @@ import sys import typing -import anyio - if sys.version_info >= (3, 10): # pragma: no cover from typing import ParamSpec else: # pragma: no cover @@ -24,11 +22,10 @@ def __init__( self.is_async = is_async_callable(func) async def __call__(self) -> None: - with anyio.CancelScope(shield=True): - if self.is_async: - await self.func(*self.args, **self.kwargs) - else: - await run_in_threadpool(self.func, *self.args, **self.kwargs) + if self.is_async: + await self.func(*self.args, **self.kwargs) + else: + await run_in_threadpool(self.func, *self.args, **self.kwargs) class BackgroundTasks(BackgroundTask): diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index b77bb41f9..2ea98a4bc 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -33,7 +33,7 @@ async def coro() -> None: async with send_stream: try: - # may block in send if body_stream not consumed + # `send_stream.send` blocks until `body_stream` is called await self.app(scope, request.receive, send_stream.send) except Exception as exc: app_exc = exc @@ -72,12 +72,7 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: request = Request(scope, receive=receive) response = await self.dispatch_func(request, call_next) await response(scope, receive, send) - - t = anyio.get_current_task() - if t.name == "anyio.from_thread.BlockingPortal._call_func": - # cancel stuck task due to discarded response - # see: https://github.com/encode/starlette/issues/1022 - task_group.cancel_scope.cancel() + task_group.cancel_scope.cancel() async def dispatch( self, request: Request, call_next: RequestResponseEndpoint diff --git a/starlette/responses.py b/starlette/responses.py index 98c8caf1b..c15fd2848 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -166,8 +166,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: ) await send({"type": "http.response.body", "body": self.body}) - if self.background is not None: - await self.background() + with anyio.CancelScope(shield=True): + if self.background is not None: + await self.background() class HTMLResponse(Response): @@ -264,8 +265,9 @@ async def wrap(func: typing.Callable[[], typing.Coroutine]) -> None: task_group.start_soon(wrap, partial(self.stream_response, send)) await wrap(partial(self.listen_for_disconnect, receive)) - if self.background is not None: - await self.background() + with anyio.CancelScope(shield=True): + if self.background is not None: + await self.background() class FileResponse(Response): @@ -350,5 +352,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: "more_body": more_body, } ) - if self.background is not None: - await self.background() + + with anyio.CancelScope(shield=True): + if self.background is not None: + await self.background() diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 98ce5e8bf..c7b65e8e4 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -92,10 +92,18 @@ async def _sleep(identifier, delay): await anyio.sleep(delay) print(identifier, "completed") + def _sleep_sync(identifier, delay): + import time + + print(identifier, "started") + time.sleep(delay) + print(identifier, "completed") + async def bg_task(request): background_tasks = BackgroundTasks() background_tasks.add_task(_sleep, "background task 1", 2) background_tasks.add_task(_sleep, "background task 2", 2) + background_tasks.add_task(_sleep_sync, "background task sync", 2) return Response(background=background_tasks) app = Starlette(