Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Allow custom middlewares to raise HTTPExceptions and propagate them #2036

Merged
merged 5 commits into from Aug 25, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions fastapi/routing.py
Expand Up @@ -165,6 +165,8 @@ async def app(request: Request) -> Response:
body = await request.json()
except json.JSONDecodeError as e:
raise RequestValidationError([ErrorWrapper(e, ("body", e.pos))], body=e.doc)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=400, detail="There was an error parsing the body"
Expand Down
95 changes: 95 additions & 0 deletions tests/test_custom_middleware_exception.py
@@ -0,0 +1,95 @@
import os
from typing import Optional

from fastapi import APIRouter, FastAPI, File, UploadFile
from fastapi.exceptions import HTTPException
from fastapi.testclient import TestClient

app = FastAPI()

router = APIRouter()


class ContentSizeLimitMiddleware:
""" Content size limiting middleware for ASGI applications
Args:
app (ASGI application): ASGI application
max_content_size (optional): the maximum content size allowed in bytes, None for no limit
"""

def __init__(self, app: APIRouter, max_content_size: Optional[int] = None):
self.app = app
self.max_content_size = max_content_size

def receive_wrapper(self, receive):
received = 0

async def inner():
nonlocal received
message = await receive()
if message["type"] != "http.request":
return message # pragma: no cover

body_len = len(message.get("body", b""))
received += body_len
if received > self.max_content_size:
raise HTTPException(
422,
detail={
"name": "ContentSizeLimitExceeded",
"code": 999,
"message": "File limit exceeded",
},
)
return message

return inner

async def __call__(self, scope, receive, send):
if scope["type"] != "http" or self.max_content_size is None:
await self.app(scope, receive, send)
return

wrapper = self.receive_wrapper(receive)
await self.app(scope, wrapper, send)


@router.post("/middleware")
def run_middleware(file: UploadFile = File(..., description="Big File")):
return {"message": "OK"}


app.include_router(router)
app.add_middleware(ContentSizeLimitMiddleware, max_content_size=2 ** 8)


client = TestClient(app)


def test_custom_middleware_exception(tmpdir):
default_pydantic_max_size = 2 ** 16
path = os.path.join(tmpdir, "test.txt")
with open(path, "wb") as file:
file.write(b"x" * (default_pydantic_max_size + 1))

with client:
response = client.post("/middleware", files={"file": open(path, "rb")})
assert response.status_code == 422, response.text
assert response.json() == {
"detail": {
"name": "ContentSizeLimitExceeded",
"code": 999,
"message": "File limit exceeded",
}
}


def test_custom_middleware_exception_not_raised(tmpdir):
path = os.path.join(tmpdir, "test.txt")
with open(path, "wb") as file:
file.write(b"<file content>")

with client:
response = client.post("/middleware", files={"file": open(path, "rb")})
assert response.status_code == 200, response.text
assert response.json() == {"message": "OK"}