Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move BackgroundTask execution outside of request/response cycle #2176

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions starlette/applications.py
Expand Up @@ -3,6 +3,7 @@

from starlette.datastructures import State, URLPath
from starlette.middleware import Middleware
from starlette.middleware.background import BackgroundTaskMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.middleware.exceptions import ExceptionMiddleware
Expand Down Expand Up @@ -94,6 +95,7 @@ def build_middleware_stack(self) -> ASGIApp:

middleware = (
[Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
+ [Middleware(BackgroundTaskMiddleware)]
+ self.user_middleware
+ [
Middleware(
Expand Down
37 changes: 37 additions & 0 deletions starlette/middleware/background.py
@@ -0,0 +1,37 @@
from typing import List, cast

from starlette.background import BackgroundTask
from starlette.types import ASGIApp, Receive, Scope, Send

# consider this a private implementation detail subject to change
# do not rely on this key
_SCOPE_KEY = "starlette._background"


_BackgroundTaskList = List[BackgroundTask]


def is_background_task_middleware_installed(scope: Scope) -> bool:
return _SCOPE_KEY in scope


def add_tasks(scope: Scope, __task: BackgroundTask) -> None:
if _SCOPE_KEY not in scope: # pragma: no cover
raise RuntimeError(
"`add_tasks` can only be used if `BackgroundTaskMIddleware is installed"
)
cast(_BackgroundTaskList, scope[_SCOPE_KEY]).append(__task)


class BackgroundTaskMiddleware:
def __init__(self, app: ASGIApp) -> None:
self._app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
tasks: _BackgroundTaskList
scope[_SCOPE_KEY] = tasks = []
try:
await self._app(scope, receive, send)
finally:
for task in tasks:
await task()
22 changes: 22 additions & 0 deletions starlette/responses.py
Expand Up @@ -16,6 +16,7 @@
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import URL, MutableHeaders
from starlette.middleware import background
from starlette.types import Receive, Scope, Send

if sys.version_info >= (3, 8): # pragma: no cover
Expand Down Expand Up @@ -161,6 +162,13 @@ def delete_cookie(
)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if (
self.background is not None
and background.is_background_task_middleware_installed(scope)
):
background.add_tasks(scope, self.background)
self.background = None

await send(
{
"type": "http.response.start",
Expand Down Expand Up @@ -267,6 +275,13 @@ async def stream_response(self, send: Send) -> None:
await send({"type": "http.response.body", "body": b"", "more_body": False})

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if (
self.background is not None
and background.is_background_task_middleware_installed(scope)
):
background.add_tasks(scope, self.background)
self.background = None

async with anyio.create_task_group() as task_group:

async def wrap(func: "typing.Callable[[], typing.Awaitable[None]]") -> None:
Expand Down Expand Up @@ -330,6 +345,13 @@ def set_stat_headers(self, stat_result: os.stat_result) -> None:
self.headers.setdefault("etag", etag)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if (
self.background is not None
and background.is_background_task_middleware_installed(scope)
):
background.add_tasks(scope, self.background)
self.background = None

if self.stat_result is None:
try:
stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
Expand Down
52 changes: 52 additions & 0 deletions tests/middleware/test_base.py
Expand Up @@ -8,6 +8,7 @@
from starlette.applications import Starlette
from starlette.background import BackgroundTask
from starlette.middleware import Middleware
from starlette.middleware.background import BackgroundTaskMiddleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response, StreamingResponse
Expand Down Expand Up @@ -793,6 +794,57 @@ async def wrapped_receive() -> Message:
assert resp.status_code == 200


@pytest.mark.anyio
async def test_background_tasks_client_disconnect() -> None:
# test for https://github.com/encode/starlette/issues/1438
container: List[str] = []

disconnected = anyio.Event()

async def slow_background() -> None:
# small delay to give BaseHTTPMiddleware a chance to cancel us
# this is required to make the test fail prior to fixing the issue
# so do not be surprised if you remove it and the test still passes
await anyio.sleep(0.1)
container.append("called")

app: ASGIApp
app = PlainTextResponse("hi!", background=BackgroundTask(slow_background))

async def dispatch(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
return await call_next(request)

app = BaseHTTPMiddleware(app, dispatch=dispatch)

app = BackgroundTaskMiddleware(app)

async def recv_gen() -> AsyncGenerator[Message, None]:
yield {"type": "http.request"}
await disconnected.wait()
while True:
yield {"type": "http.disconnect"}

async def send_gen() -> AsyncGenerator[None, Message]:
while True:
msg = yield
if msg["type"] == "http.response.body" and not msg.get("more_body", False):
disconnected.set()

scope = {"type": "http", "method": "GET", "path": "/"}

async with AsyncExitStack() as stack:
recv = recv_gen()
stack.push_async_callback(recv.aclose)
send = send_gen()
stack.push_async_callback(send.aclose)
await send.__anext__()
await app(scope, recv.__aiter__().__anext__, send.asend)

assert container == ["called"]


CallNext = Callable[[Request], Awaitable[Response]]


Expand Down
81 changes: 71 additions & 10 deletions tests/test_background.py
@@ -1,13 +1,76 @@
from typing import Callable
from tempfile import NamedTemporaryFile
from typing import Any, AsyncIterable, Callable, List

import pytest

from starlette.background import BackgroundTask, BackgroundTasks
from starlette.responses import Response
from starlette.middleware.background import BackgroundTaskMiddleware
from starlette.responses import FileResponse, Response, StreamingResponse
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Receive, Scope, Send

TestClientFactory = Callable[[ASGIApp], TestClient]

def test_async_task(test_client_factory):

@pytest.fixture(
params=[[], [BackgroundTaskMiddleware]],
ids=["without BackgroundTaskMiddleware", "with BackgroundTaskMiddleware"],
)
def test_client_factory_mw(
test_client_factory: TestClientFactory, request: Any
) -> TestClientFactory:
mw_stack: List[Callable[[ASGIApp], ASGIApp]] = request.param

def client_factory(app: ASGIApp) -> TestClient:
for mw in mw_stack:
app = mw(app)
return test_client_factory(app)

return client_factory


def response_app_factory(task: BackgroundTask) -> ASGIApp:
async def app(scope: Scope, receive: Receive, send: Send):
response = Response(b"task initiated", media_type="text/plain", background=task)
await response(scope, receive, send)

return app


def file_response_app_factory(task: BackgroundTask) -> ASGIApp:
async def app(scope: Scope, receive: Receive, send: Send):
with NamedTemporaryFile("wb+") as f:
f.write(b"task initiated")
f.seek(0)
response = FileResponse(f.name, media_type="text/plain", background=task)
await response(scope, receive, send)

return app


def streaming_response_app_factory(task: BackgroundTask) -> ASGIApp:
async def app(scope: Scope, receive: Receive, send: Send):
async def stream() -> AsyncIterable[bytes]:
yield b"task initiated"

response = StreamingResponse(stream(), media_type="text/plain", background=task)
await response(scope, receive, send)

return app


@pytest.mark.parametrize(
"app_factory",
[
response_app_factory,
streaming_response_app_factory,
file_response_app_factory,
],
)
def test_async_task(
test_client_factory_mw: TestClientFactory,
app_factory: Callable[[BackgroundTask], ASGIApp],
):
TASK_COMPLETE = False

async def async_task():
Expand All @@ -16,17 +79,15 @@ async def async_task():

task = BackgroundTask(async_task)

async def app(scope, receive, send):
response = Response("task initiated", media_type="text/plain", background=task)
await response(scope, receive, send)
app = app_factory(task)

client = test_client_factory(app)
client = test_client_factory_mw(app)
response = client.get("/")
assert response.text == "task initiated"
assert TASK_COMPLETE


def test_sync_task(test_client_factory):
def test_sync_task(test_client_factory: TestClientFactory):
TASK_COMPLETE = False

def sync_task():
Expand All @@ -45,7 +106,7 @@ async def app(scope, receive, send):
assert TASK_COMPLETE


def test_multiple_tasks(test_client_factory: Callable[..., TestClient]):
def test_multiple_tasks(test_client_factory: TestClientFactory):
TASK_COUNTER = 0

def increment(amount):
Expand All @@ -69,7 +130,7 @@ async def app(scope, receive, send):


def test_multi_tasks_failure_avoids_next_execution(
test_client_factory: Callable[..., TestClient]
test_client_factory: TestClientFactory,
) -> None:
TASK_COUNTER = 0

Expand Down
7 changes: 4 additions & 3 deletions tests/test_responses.py
Expand Up @@ -8,6 +8,7 @@

from starlette import status
from starlette.background import BackgroundTask
from starlette.middleware.background import BackgroundTaskMiddleware
from starlette.requests import Request
from starlette.responses import (
FileResponse,
Expand Down Expand Up @@ -116,7 +117,7 @@ async def numbers_for_cleanup(start=1, stop=5):
await response(scope, receive, send)

assert filled_by_bg_task == ""
client = test_client_factory(app)
client = test_client_factory(BackgroundTaskMiddleware(app))
response = client.get("/")
assert response.text == "1, 2, 3, 4, 5"
assert filled_by_bg_task == "6, 7, 8, 9"
Expand All @@ -140,7 +141,7 @@ async def __anext__(self):
response = StreamingResponse(CustomAsyncIterator(), media_type="text/plain")
await response(scope, receive, send)

client = test_client_factory(app)
client = test_client_factory(BackgroundTaskMiddleware(app))
response = client.get("/")
assert response.text == "12345"

Expand Down Expand Up @@ -231,7 +232,7 @@ async def app(scope, receive, send):
await response(scope, receive, send)

assert filled_by_bg_task == ""
client = test_client_factory(app)
client = test_client_factory(BackgroundTaskMiddleware(app))
response = client.get("/")
expected_disposition = 'attachment; filename="example.png"'
assert response.status_code == status.HTTP_200_OK
Expand Down