Skip to content

Commit

Permalink
prototype design where BaseHTTPMiddleware works without cancelling do…
Browse files Browse the repository at this point in the history
…wnstream app
  • Loading branch information
jhominal committed Jun 29, 2022
1 parent 0b132ee commit c98203e
Showing 1 changed file with 50 additions and 3 deletions.
53 changes: 50 additions & 3 deletions starlette/middleware/base.py
Expand Up @@ -4,13 +4,23 @@

from starlette.requests import Request
from starlette.responses import Response, StreamingResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.types import ASGIApp, Message, Receive, Scope, Send

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[
[Request, RequestResponseEndpoint], typing.Awaitable[Response]
]

T = typing.TypeVar("T")


async def _call_and_cancel(
func: typing.Callable[[], typing.Awaitable[T]], cancel_scope: anyio.CancelScope
) -> T:
result = await func()
cancel_scope.cancel()
return result


class BaseHTTPMiddleware:
def __init__(
Expand All @@ -24,19 +34,56 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
return

app_disconnected = anyio.Event()

async def call_next(request: Request) -> Response:
app_exc: typing.Optional[Exception] = None
send_stream, recv_stream = anyio.create_memory_object_stream()

async def receive_or_disconnect() -> Message:
if app_disconnected.is_set():
return {"type": "http.disconnect"}

async with anyio.create_task_group() as task_group:
# app_disconnected is set, cancel parent scope to cancel
# request.receive
task_group.start_soon(
_call_and_cancel,
app_disconnected.wait,
task_group.cancel_scope,
)
# if request.receive returns a message, cancel the task_group to
# exit its block
message = await _call_and_cancel(
request.receive, task_group.cancel_scope
)

if app_disconnected.is_set():
return {"type": "http.disconnect"}

return message

async def close_recv_stream_on_disconnect() -> None:
await app_disconnected.wait()
recv_stream.close()

async def send_no_error(message: Message) -> None:
try:
await send_stream.send(message)
except anyio.BrokenResourceError:
# recv_stream has been closed, i.e. app_disconnected has been set.
return

async def coro() -> None:
nonlocal app_exc

async with send_stream:
try:
await self.app(scope, request.receive, send_stream.send)
await self.app(scope, receive_or_disconnect, send_no_error)
except Exception as exc:
app_exc = exc

task_group.start_soon(close_recv_stream_on_disconnect)
task_group.start_soon(coro)

try:
Expand Down Expand Up @@ -71,7 +118,7 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
request = Request(scope, receive=receive)
response = await self.dispatch_func(request, call_next)
await response(scope, receive, send)
task_group.cancel_scope.cancel()
app_disconnected.set()

async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
Expand Down

0 comments on commit c98203e

Please sign in to comment.