From f1780001cea23f48a0aad4ee90cf3f2ea39cd5c4 Mon Sep 17 00:00:00 2001 From: Erik Aker Date: Fri, 31 Jul 2020 20:54:07 -0700 Subject: [PATCH 1/8] Limit queue size in base http middleware for backpressure Also separate streaming response in base http middleware from background --- starlette/middleware/base.py | 29 ++++++++++++++++++++++--- tests/middleware/test_base.py | 40 ++++++++++++++++++++++++++++++++++- 2 files changed, 65 insertions(+), 4 deletions(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 6c6a43b16..1f6e7a228 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -27,7 +27,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def call_next(self, request: Request) -> Response: loop = asyncio.get_event_loop() - queue = asyncio.Queue() # type: asyncio.Queue + queue = asyncio.Queue(maxsize=1) # type: asyncio.Queue scope = request.scope receive = request.receive @@ -42,18 +42,41 @@ async def coro() -> None: task = loop.create_task(coro()) message = await queue.get() if message is None: + queue.task_done() task.result() raise RuntimeError("No response returned.") assert message["type"] == "http.response.start" async def body_stream() -> typing.AsyncGenerator[bytes, None]: - while True: + def streaming_predicate( + msg: typing.Optional[dict], more_body: bool = True + ) -> bool: + return ( + msg is not None + and msg["type"] == "http.response.body" + and "more_body" in msg + and msg["more_body"] is more_body + ) + + # In non-streaming responses, there will be one message to emit + message = await queue.get() + queue.task_done() + assert message["type"] == "http.response.body" + yield message.get("body", b"") + + while streaming_predicate(message, more_body=True): message = await queue.get() + queue.task_done() if message is None: break assert message["type"] == "http.response.body" yield message.get("body", b"") - task.result() + + try: + task.result() # check for exceptions and raise if present + except asyncio.exceptions.InvalidStateError: + # task is not completed (which could be due to background) + pass response = StreamingResponse( status_code=message["status"], content=body_stream() diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 048dd9ffb..ff3a453d4 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -3,7 +3,7 @@ from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware -from starlette.responses import PlainTextResponse +from starlette.responses import PlainTextResponse, StreamingResponse from starlette.routing import Route from starlette.testclient import TestClient @@ -143,3 +143,41 @@ def homepage(request): def test_middleware_repr(): middleware = Middleware(CustomMiddleware) assert repr(middleware) == "Middleware(CustomMiddleware)" + + +async def numbers_stream(minimum, maximum): + yield ("") + + +async def somthing_broken(minimum, maximum): + yield ("") -async def somthing_broken(minimum, maximum): +async def somthing_broken(minimum, maximum, error_at=2): + if error_at <= 0: + raise RuntimeError("This is a stream that breaks when it starts") yield ("