Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RequestSizeLimitMiddleware and RequestTimeoutMiddleware #2328

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
183 changes: 183 additions & 0 deletions starlette/middleware/limits.py
@@ -0,0 +1,183 @@
from __future__ import annotations

from typing import ClassVar

import anyio

from starlette.exceptions import HTTPException
from starlette.responses import PlainTextResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send

DEFAULT_MAX_REQUEST_SIZE = 2_621_440 # 2.5MB, same as Django (https://docs.djangoproject.com/en/1.11/ref/settings/#data-upload-max-memory-size)


class _TooLarge(HTTPException):
msg: ClassVar[str]

def __init__(self, limit_bytes: int | None) -> None:
self.limit = limit_bytes
super().__init__(
status_code=413,
detail=(
self.msg + f" Max allowed size is {limit_bytes} bytes."
if limit_bytes
else self.msg
),
)


class RequestTooLarge(_TooLarge):
"""The request body exceeded the configured limit."""

msg = "Request body is too large."


class ChunkTooLarge(_TooLarge):
"""A chunk exceeded the configured limit."""

msg = "Chunk size is too large."


class RequestSizeLimitMiddleware:
def __init__(
self,
app: ASGIApp,
*,
max_request_size: int | None = DEFAULT_MAX_REQUEST_SIZE,
max_chunk_size: int | None = None,
include_limits_in_error_responses: bool = True,
) -> None:
self.app = app
self.max_request_size = max_request_size
self.max_chunk_size = max_chunk_size
self.include_limits_in_error_responses = include_limits_in_error_responses

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return

total_size = 0

async def rcv() -> Message:
nonlocal total_size
message = await receive()
chunk_size = len(message.get("body", b""))
if self.max_chunk_size is not None and chunk_size > self.max_chunk_size:
raise ChunkTooLarge(
self.max_chunk_size
if self.include_limits_in_error_responses
else None
)
total_size += chunk_size
if self.max_request_size is not None and total_size > self.max_request_size:
raise RequestTooLarge(
self.max_request_size
if self.include_limits_in_error_responses
else None
)
return message
Comment on lines +62 to +79
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then I’m confused as to what you are suggesting: are you saying that we should also enforce that limit, that that limit is already enforced elsewhere, and this PR (or #2174) is not needed, or something else?

Would it be more concise and efficient to change here to judge the request content-length?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of the user-defined limits? Or in addition to?

Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What @abersheeran meant is to check the Content-Length instead of having the logic of adding up chunk_sizes - which is actually what Django does: https://github.com/django/django/blob/6daf86058bb6fb922eb3fe3abae6f5c0e645020c/django/http/request.py#L323-L347.

Django also has this logic on the multipart parser: https://github.com/django/django/blob/594873befbbec13a2d9a048a361757dd3cf178da/django/http/multipartparser.py#L241-L248.


await self.app(scope, rcv, send)


class _Timeout(HTTPException):
msg: ClassVar[str]

def __init__(self, limit_seconds: float | None) -> None:
self.limit = limit_seconds
super().__init__(
status_code=408,
detail=(
self.msg + f" Max allowed time is {limit_seconds} seconds."
if limit_seconds
else self.msg
),
)


class ReceiveTimeout(_Timeout):
"""The receive exceeded the configured limit."""

msg = "Client was too slow sending data."


class SendTimeout(_Timeout):
"""The send exceeded the configured limit."""

msg = "Client was too slow receiving data."


class RequestTimeoutMiddleware:
def __init__(
self,
app: ASGIApp,
request_timeout_seconds: float | None = None,
receive_timeout_seconds: float | None = None,
send_timeout_seconds: float | None = None,
include_limits_in_error_responses: bool = True,
) -> None:
self.app = app
self.timeout_seconds = request_timeout_seconds
self.receive_timeout_seconds = receive_timeout_seconds
self.send_timeout_seconds = send_timeout_seconds
self.include_limits_in_error_responses = include_limits_in_error_responses

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return

send_with_timeout: Send
if self.send_timeout_seconds:

async def snd(message: Message) -> None:
try:
with anyio.fail_after(self.send_timeout_seconds):
await send(message)
except TimeoutError:
raise SendTimeout(
self.send_timeout_seconds
if self.include_limits_in_error_responses
else None
)

send_with_timeout = snd
else:
send_with_timeout = send

receive_with_timeout: Receive
if self.receive_timeout_seconds:

async def rcv() -> Message:
try:
with anyio.fail_after(self.receive_timeout_seconds):
return await receive()
except TimeoutError:
raise ReceiveTimeout(
self.receive_timeout_seconds
if self.include_limits_in_error_responses
else None
)

receive_with_timeout = rcv
else:
receive_with_timeout = receive

if self.timeout_seconds is not None:
try:
with anyio.fail_after(self.timeout_seconds):
await self.app(scope, receive_with_timeout, send_with_timeout)
except TimeoutError:
if self.include_limits_in_error_responses:
await PlainTextResponse(
content=f"Request exceeded the timeout of {self.timeout_seconds} seconds.", # noqa: E501
status_code=408,
)(scope, receive, send)
else:
await PlainTextResponse(
content="Request exceeded the timeout.",
status_code=408,
)(scope, receive, send)
else:
await self.app(scope, receive_with_timeout, send_with_timeout)