Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix WSGI middleware not to explode quadratically in the case of a larger body #1329

Merged
merged 6 commits into from Feb 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 16 additions & 2 deletions tests/middleware/test_wsgi.py
@@ -1,5 +1,6 @@
import io
import sys
from typing import List
from typing import AsyncGenerator, List

import httpx
import pytest
Expand Down Expand Up @@ -67,6 +68,19 @@ async def test_wsgi_post() -> None:
assert response.text == '{"example": 123}'


@pytest.mark.asyncio
async def test_wsgi_put_more_body() -> None:
async def generate_body() -> AsyncGenerator[bytes, None]:
for _ in range(1024):
yield b"123456789abcdef\n" * 64

app = WSGIMiddleware(echo_body)
async with httpx.AsyncClient(app=app, base_url="http://testserver") as client:
response = await client.put("/", content=generate_body())
assert response.status_code == 200
assert response.text == "123456789abcdef\n" * 64 * 1024


@pytest.mark.asyncio
async def test_wsgi_exception() -> None:
# Note that we're testing the WSGI app directly here.
Expand Down Expand Up @@ -120,6 +134,6 @@ def test_build_environ_encoding() -> None:
"body": b"",
"more_body": False,
}
environ = build_environ(scope, message, b"")
environ = build_environ(scope, message, io.BytesIO(b""))
assert environ["PATH_INFO"] == "/文".encode("utf8").decode("latin-1")
assert environ["HTTP_KEY"] == "value1,value2"
21 changes: 14 additions & 7 deletions uvicorn/middleware/wsgi.py
Expand Up @@ -19,7 +19,9 @@
from uvicorn._types import Environ, ExcInfo, StartResponse, WSGIApp


def build_environ(scope: HTTPScope, message: ASGIReceiveEvent, body: bytes) -> Environ:
def build_environ(
scope: HTTPScope, message: ASGIReceiveEvent, body: io.BytesIO
) -> Environ:
"""
Builds a scope and request message into a WSGI environ object.
"""
Expand All @@ -31,7 +33,7 @@ def build_environ(scope: HTTPScope, message: ASGIReceiveEvent, body: bytes) -> E
"SERVER_PROTOCOL": "HTTP/%s" % scope["http_version"],
"wsgi.version": (1, 0),
"wsgi.url_scheme": scope.get("scheme", "http"),
"wsgi.input": io.BytesIO(body),
"wsgi.input": body,
"wsgi.errors": sys.stdout,
"wsgi.multithread": True,
"wsgi.multiprocess": True,
Expand Down Expand Up @@ -105,12 +107,17 @@ async def __call__(
self, receive: ASGIReceiveCallable, send: ASGISendCallable
) -> None:
message: HTTPRequestEvent = await receive() # type: ignore[assignment]
body = message.get("body", b"")
body = io.BytesIO(message.get("body", b""))
more_body = message.get("more_body", False)
while more_body:
body_message: HTTPRequestEvent = await receive() # type: ignore[assignment]
body += body_message.get("body", b"")
more_body = body_message.get("more_body", False)
if more_body:
body.seek(0, io.SEEK_END)
while more_body:
body_message: HTTPRequestEvent = (
await receive() # type: ignore[assignment]
)
body.write(body_message.get("body", b""))
more_body = body_message.get("more_body", False)
body.seek(0)
environ = build_environ(self.scope, message, body)
self.loop = asyncio.get_event_loop()
wsgi = self.loop.run_in_executor(
Expand Down