Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid unexpected background task cancellation #1699

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions starlette/middleware/base.py
Expand Up @@ -33,6 +33,7 @@ async def coro() -> None:

async with send_stream:
try:
# `send_stream.send` blocks until `body_stream` is called
await self.app(scope, request.receive, send_stream.send)
except Exception as exc:
app_exc = exc
Expand Down
16 changes: 10 additions & 6 deletions starlette/responses.py
Expand Up @@ -166,8 +166,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
)
await send({"type": "http.response.body", "body": self.body})

if self.background is not None:
await self.background()
with anyio.CancelScope(shield=True):
if self.background is not None:
await self.background()


class HTMLResponse(Response):
Expand Down Expand Up @@ -264,8 +265,9 @@ async def wrap(func: typing.Callable[[], typing.Coroutine]) -> None:
task_group.start_soon(wrap, partial(self.stream_response, send))
await wrap(partial(self.listen_for_disconnect, receive))

if self.background is not None:
await self.background()
with anyio.CancelScope(shield=True):
if self.background is not None:
await self.background()


class FileResponse(Response):
Expand Down Expand Up @@ -350,5 +352,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"more_body": more_body,
}
)
if self.background is not None:
await self.background()

with anyio.CancelScope(shield=True):
if self.background is not None:
await self.background()
49 changes: 46 additions & 3 deletions tests/middleware/test_base.py
@@ -1,11 +1,13 @@
import contextvars

import anyio
import pytest

from starlette.applications import Starlette
from starlette.background import BackgroundTasks
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse, StreamingResponse
from starlette.responses import PlainTextResponse, Response, StreamingResponse
from starlette.routing import Route, WebSocketRoute
from starlette.types import ASGIApp, Receive, Scope, Send

Expand Down Expand Up @@ -84,6 +86,34 @@ def test_custom_middleware(test_client_factory):
assert text == "Hello, world!"


def test_background_tasks(test_client_factory):
async def _sleep(identifier, delay):
print(identifier, "started")
await anyio.sleep(delay)
print(identifier, "completed")

def _sleep_sync(identifier, delay):
import time

print(identifier, "started")
time.sleep(delay)
print(identifier, "completed")

async def bg_task(request):
background_tasks = BackgroundTasks()
background_tasks.add_task(_sleep, "background task 1", 2)
background_tasks.add_task(_sleep, "background task 2", 2)
background_tasks.add_task(_sleep_sync, "background task sync", 2)
return Response(background=background_tasks)

app = Starlette(
routes=[Route("/bg-task", bg_task)], middleware=[Middleware(CustomMiddleware)]
)
client = test_client_factory(app)
response = client.get("/bg-task")
assert response.text == ""


def test_state_data_across_multiple_middlewares(test_client_factory):
expected_value1 = "foo"
expected_value2 = "bar"
Expand Down Expand Up @@ -141,12 +171,25 @@ def homepage(request):

def test_fully_evaluated_response(test_client_factory):
# Test for https://github.com/encode/starlette/issues/1022
class CustomMiddleware(BaseHTTPMiddleware):
class ConsumeMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
resp = await call_next(request)

async def _send(m):
pass

await resp.stream_response(_send) # type: ignore

return PlainTextResponse("Custom")

class DiscardMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
await call_next(request)
return PlainTextResponse("Custom")

app = Starlette(middleware=[Middleware(CustomMiddleware)])
app = Starlette(
middleware=[Middleware(ConsumeMiddleware), Middleware(DiscardMiddleware)]
)

client = test_client_factory(app)
response = client.get("/does_not_exist")
Expand Down