diff --git a/starlette/background.py b/starlette/background.py index db9b38af8..4aaf7ae3c 100644 --- a/starlette/background.py +++ b/starlette/background.py @@ -1,8 +1,6 @@ import sys import typing -import anyio - if sys.version_info >= (3, 10): # pragma: no cover from typing import ParamSpec else: # pragma: no cover @@ -24,11 +22,10 @@ def __init__( self.is_async = is_async_callable(func) async def __call__(self) -> None: - with anyio.CancelScope(shield=True): - if self.is_async: - await self.func(*self.args, **self.kwargs) - else: - await run_in_threadpool(self.func, *self.args, **self.kwargs) + if self.is_async: + await self.func(*self.args, **self.kwargs) + else: + await run_in_threadpool(self.func, *self.args, **self.kwargs) class BackgroundTasks(BackgroundTask): diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index b77bb41f9..2ea98a4bc 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -33,7 +33,7 @@ async def coro() -> None: async with send_stream: try: - # may block in send if body_stream not consumed + # `send_stream.send` blocks until `body_stream` is called await self.app(scope, request.receive, send_stream.send) except Exception as exc: app_exc = exc @@ -72,12 +72,7 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: request = Request(scope, receive=receive) response = await self.dispatch_func(request, call_next) await response(scope, receive, send) - - t = anyio.get_current_task() - if t.name == "anyio.from_thread.BlockingPortal._call_func": - # cancel stuck task due to discarded response - # see: https://github.com/encode/starlette/issues/1022 - task_group.cancel_scope.cancel() + task_group.cancel_scope.cancel() async def dispatch( self, request: Request, call_next: RequestResponseEndpoint diff --git a/starlette/responses.py b/starlette/responses.py index 98c8caf1b..c15fd2848 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -166,8 +166,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: ) await send({"type": "http.response.body", "body": self.body}) - if self.background is not None: - await self.background() + with anyio.CancelScope(shield=True): + if self.background is not None: + await self.background() class HTMLResponse(Response): @@ -264,8 +265,9 @@ async def wrap(func: typing.Callable[[], typing.Coroutine]) -> None: task_group.start_soon(wrap, partial(self.stream_response, send)) await wrap(partial(self.listen_for_disconnect, receive)) - if self.background is not None: - await self.background() + with anyio.CancelScope(shield=True): + if self.background is not None: + await self.background() class FileResponse(Response): @@ -350,5 +352,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: "more_body": more_body, } ) - if self.background is not None: - await self.background() + + with anyio.CancelScope(shield=True): + if self.background is not None: + await self.background() diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 98ce5e8bf..c7b65e8e4 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -92,10 +92,18 @@ async def _sleep(identifier, delay): await anyio.sleep(delay) print(identifier, "completed") + def _sleep_sync(identifier, delay): + import time + + print(identifier, "started") + time.sleep(delay) + print(identifier, "completed") + async def bg_task(request): background_tasks = BackgroundTasks() background_tasks.add_task(_sleep, "background task 1", 2) background_tasks.add_task(_sleep, "background task 2", 2) + background_tasks.add_task(_sleep_sync, "background task sync", 2) return Response(background=background_tasks) app = Starlette(