diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 49a5e3e2d..2ea98a4bc 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -33,6 +33,7 @@ async def coro() -> None: async with send_stream: try: + # `send_stream.send` blocks until `body_stream` is called await self.app(scope, request.receive, send_stream.send) except Exception as exc: app_exc = exc diff --git a/starlette/responses.py b/starlette/responses.py index 98c8caf1b..c15fd2848 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -166,8 +166,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: ) await send({"type": "http.response.body", "body": self.body}) - if self.background is not None: - await self.background() + with anyio.CancelScope(shield=True): + if self.background is not None: + await self.background() class HTMLResponse(Response): @@ -264,8 +265,9 @@ async def wrap(func: typing.Callable[[], typing.Coroutine]) -> None: task_group.start_soon(wrap, partial(self.stream_response, send)) await wrap(partial(self.listen_for_disconnect, receive)) - if self.background is not None: - await self.background() + with anyio.CancelScope(shield=True): + if self.background is not None: + await self.background() class FileResponse(Response): @@ -350,5 +352,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: "more_body": more_body, } ) - if self.background is not None: - await self.background() + + with anyio.CancelScope(shield=True): + if self.background is not None: + await self.background() diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 976d77b86..c7b65e8e4 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1,11 +1,13 @@ import contextvars +import anyio import pytest from starlette.applications import Starlette +from starlette.background import BackgroundTasks from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware -from starlette.responses import PlainTextResponse, StreamingResponse +from starlette.responses import PlainTextResponse, Response, StreamingResponse from starlette.routing import Route, WebSocketRoute from starlette.types import ASGIApp, Receive, Scope, Send @@ -84,6 +86,34 @@ def test_custom_middleware(test_client_factory): assert text == "Hello, world!" +def test_background_tasks(test_client_factory): + async def _sleep(identifier, delay): + print(identifier, "started") + await anyio.sleep(delay) + print(identifier, "completed") + + def _sleep_sync(identifier, delay): + import time + + print(identifier, "started") + time.sleep(delay) + print(identifier, "completed") + + async def bg_task(request): + background_tasks = BackgroundTasks() + background_tasks.add_task(_sleep, "background task 1", 2) + background_tasks.add_task(_sleep, "background task 2", 2) + background_tasks.add_task(_sleep_sync, "background task sync", 2) + return Response(background=background_tasks) + + app = Starlette( + routes=[Route("/bg-task", bg_task)], middleware=[Middleware(CustomMiddleware)] + ) + client = test_client_factory(app) + response = client.get("/bg-task") + assert response.text == "" + + def test_state_data_across_multiple_middlewares(test_client_factory): expected_value1 = "foo" expected_value2 = "bar" @@ -141,12 +171,25 @@ def homepage(request): def test_fully_evaluated_response(test_client_factory): # Test for https://github.com/encode/starlette/issues/1022 - class CustomMiddleware(BaseHTTPMiddleware): + class ConsumeMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + resp = await call_next(request) + + async def _send(m): + pass + + await resp.stream_response(_send) # type: ignore + + return PlainTextResponse("Custom") + + class DiscardMiddleware(BaseHTTPMiddleware): async def dispatch(self, request, call_next): await call_next(request) return PlainTextResponse("Custom") - app = Starlette(middleware=[Middleware(CustomMiddleware)]) + app = Starlette( + middleware=[Middleware(ConsumeMiddleware), Middleware(DiscardMiddleware)] + ) client = test_client_factory(app) response = client.get("/does_not_exist")