Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add custom route on route level #5897

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 17 additions & 0 deletions fastapi/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from fastapi.openapi.utils import get_openapi
from fastapi.params import Depends
from fastapi.routing import APIRoute
from fastapi.types import DecoratedCallable
from fastapi.utils import generate_unique_id
from starlette.applications import Starlette
Expand Down Expand Up @@ -454,6 +455,7 @@ def get(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
Expand Down Expand Up @@ -481,6 +483,7 @@ def get(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -509,6 +512,7 @@ def put(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
Expand Down Expand Up @@ -536,6 +540,7 @@ def put(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -564,6 +569,7 @@ def post(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
Expand Down Expand Up @@ -591,6 +597,7 @@ def post(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -619,6 +626,7 @@ def delete(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
Expand Down Expand Up @@ -646,6 +654,7 @@ def delete(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -674,6 +683,7 @@ def options(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
Expand Down Expand Up @@ -701,6 +711,7 @@ def options(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -729,6 +740,7 @@ def head(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
Expand Down Expand Up @@ -756,6 +768,7 @@ def head(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -784,6 +797,7 @@ def patch(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
Expand Down Expand Up @@ -811,6 +825,7 @@ def patch(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -839,6 +854,7 @@ def trace(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
Expand Down Expand Up @@ -866,6 +882,7 @@ def trace(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down
18 changes: 18 additions & 0 deletions fastapi/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,7 @@ def api_route(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -658,6 +659,7 @@ def decorator(func: DecoratedCallable) -> DecoratedCallable:
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -822,6 +824,7 @@ def get(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -850,6 +853,7 @@ def get(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -878,6 +882,7 @@ def put(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -906,6 +911,7 @@ def put(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -934,6 +940,7 @@ def post(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -962,6 +969,7 @@ def post(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -990,6 +998,7 @@ def delete(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -1018,6 +1027,7 @@ def delete(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -1046,6 +1056,7 @@ def options(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -1074,6 +1085,7 @@ def options(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -1102,6 +1114,7 @@ def head(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -1130,6 +1143,7 @@ def head(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -1158,6 +1172,7 @@ def patch(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -1186,6 +1201,7 @@ def patch(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -1214,6 +1230,7 @@ def trace(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -1243,6 +1260,7 @@ def trace(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down
59 changes: 59 additions & 0 deletions tests/test_custom_route_class_for_endpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from typing import Callable
from urllib.request import Request

import pytest
from fastapi import APIRouter, FastAPI, HTTPException, status
from fastapi.openapi.models import Response
from fastapi.routing import APIRoute
from fastapi.testclient import TestClient

app = FastAPI()
router = APIRouter()


class CustomRoute(APIRoute):
def get_route_handler(self) -> Callable:
original_route_handler = super().get_route_handler()

async def custom_route_handler(request: Request) -> Response:
if "test_header" not in request.headers:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
return await original_route_handler(request)

return custom_route_handler


@router.get("/a")
def get_a():
return {"msg": "A"}


@router.get("/b", route_class_override=CustomRoute)
def get_b():
return {"msg": "B"}


app.include_router(router=router, prefix="")


client = TestClient(app)


@pytest.mark.parametrize(
"path,expected_status,headers",
[
("/a", 200, {"test_header": "value"}),
("/a", 200, None),
("/b", 200, {"test_header": "value"}),
("/b", 400, None),
],
ids=[
"/a with test_header header",
"/a without test_header headers",
"/b with test_header headers",
"/b without test_header headers",
],
)
def test_get_path(path, expected_status, headers):
response = client.get(path, headers=headers)
assert response.status_code == expected_status