Skip to content

Commit

Permalink
replace BaseMiddleware cancellation after request send with closing r…
Browse files Browse the repository at this point in the history
…ecv_stream + http.disconnect in receive

fixes #1438
  • Loading branch information
jhominal committed Jul 2, 2022
1 parent 0b132ee commit f1d1e21
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 3 deletions.
40 changes: 37 additions & 3 deletions starlette/middleware/base.py
Expand Up @@ -4,12 +4,13 @@

from starlette.requests import Request
from starlette.responses import Response, StreamingResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.types import ASGIApp, Message, Receive, Scope, Send

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[
[Request, RequestResponseEndpoint], typing.Awaitable[Response]
]
T = typing.TypeVar("T")


class BaseHTTPMiddleware:
Expand All @@ -24,19 +25,52 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
return

response_sent = anyio.Event()

async def call_next(request: Request) -> Response:
app_exc: typing.Optional[Exception] = None
send_stream, recv_stream = anyio.create_memory_object_stream()

async def receive_or_disconnect() -> Message:
if response_sent.is_set():
return {"type": "http.disconnect"}

async with anyio.create_task_group() as task_group:

async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
result = await func()
task_group.cancel_scope.cancel()
return result

task_group.start_soon(wrap, response_sent.wait)
message = await wrap(request.receive)

if response_sent.is_set():
return {"type": "http.disconnect"}

return message

async def close_recv_stream_on_response_sent() -> None:
await response_sent.wait()
recv_stream.close()

async def send_no_error(message: Message) -> None:
try:
await send_stream.send(message)
except anyio.BrokenResourceError:
# recv_stream has been closed, i.e. response_sent has been set.
return

async def coro() -> None:
nonlocal app_exc

async with send_stream:
try:
await self.app(scope, request.receive, send_stream.send)
await self.app(scope, receive_or_disconnect, send_no_error)
except Exception as exc:
app_exc = exc

task_group.start_soon(close_recv_stream_on_response_sent)
task_group.start_soon(coro)

try:
Expand Down Expand Up @@ -71,7 +105,7 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
request = Request(scope, receive=receive)
response = await self.dispatch_func(request, call_next)
await response(scope, receive, send)
task_group.cancel_scope.cancel()
response_sent.set()

async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
Expand Down
145 changes: 145 additions & 0 deletions tests/middleware/test_base.py
@@ -1,8 +1,10 @@
import contextvars

import anyio
import pytest

from starlette.applications import Starlette
from starlette.background import BackgroundTask
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse, StreamingResponse
Expand Down Expand Up @@ -206,3 +208,146 @@ async def homepage(request):
client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 200, response.content


@pytest.mark.anyio
async def test_run_background_tasks_even_if_client_disconnects():
# test for https://github.com/encode/starlette/issues/1438
request_body_sent = False
response_complete = anyio.Event()
background_task_run = anyio.Event()

async def sleep_and_set():
# small delay to give BaseHTTPMiddleware a chance to cancel us
# this is required to make the test fail prior to fixing the issue
# so do not be surprised if you remove it and the test still passes
await anyio.sleep(0.1)
background_task_run.set()

async def endpoint_with_background_task(_):
return PlainTextResponse(background=BackgroundTask(sleep_and_set))

async def passthrough(request, call_next):
return await call_next(request)

app = Starlette(
middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)],
routes=[Route("/", endpoint_with_background_task)],
)

scope = {
"type": "http",
"version": "3",
"method": "GET",
"path": "/",
}

async def receive():
nonlocal request_body_sent
if not request_body_sent:
request_body_sent = True
return {"type": "http.request", "body": b"", "more_body": False}
# We simulate a client that disconnects immediately after receiving the response
await response_complete.wait()
return {"type": "http.disconnect"}

async def send(message):
if message["type"] == "http.response.body":
if not message.get("more_body", False):
response_complete.set()

await app(scope, receive, send)

assert background_task_run.is_set()


def test_app_receives_http_disconnect_while_sending_if_discarded(test_client_factory):
class DiscardingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
await call_next(request)
return PlainTextResponse("Custom")

async def downstream_app(scope, receive, send):
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/plain"),
],
}
)
async with anyio.create_task_group() as task_group:

async def cancel_on_disconnect():
while True:
message = await receive()
if message["type"] == "http.disconnect":
task_group.cancel_scope.cancel()
break

task_group.start_soon(cancel_on_disconnect)

await send(
{
"type": "http.response.body",
"body": b"first chunk, ",
"more_body": True,
}
)
await send(
{
"type": "http.response.body",
"body": b"second chunk",
"more_body": True,
}
)
pytest.fail(
"http.disconnect should have been received and canceled the scope"
)

app = DiscardingMiddleware(downstream_app)

client = test_client_factory(app)
response = client.get("/does_not_exist")
assert response.text == "Custom"


def test_app_receives_http_disconnect_after_sending_if_discarded(test_client_factory):
class DiscardingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
await call_next(request)
return PlainTextResponse("Custom")

async def downstream_app(scope, receive, send):
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/plain"),
],
}
)
await send(
{
"type": "http.response.body",
"body": b"first chunk, ",
"more_body": True,
}
)
await send(
{
"type": "http.response.body",
"body": b"second chunk",
"more_body": True,
}
)
message = await receive()
assert message["type"] == "http.disconnect"

app = DiscardingMiddleware(downstream_app)

client = test_client_factory(app)
response = client.get("/does_not_exist")
assert response.text == "Custom"

0 comments on commit f1d1e21

Please sign in to comment.