Skip to content

Commit

Permalink
feat: add custom route on route level
Browse files Browse the repository at this point in the history
  • Loading branch information
arkadybag committed Jan 18, 2023
1 parent 5905c3f commit fb7ea03
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 0 deletions.
18 changes: 18 additions & 0 deletions fastapi/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,7 @@ def api_route(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -658,6 +659,7 @@ def decorator(func: DecoratedCallable) -> DecoratedCallable:
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -822,6 +824,7 @@ def get(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -850,6 +853,7 @@ def get(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -878,6 +882,7 @@ def put(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -906,6 +911,7 @@ def put(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -934,6 +940,7 @@ def post(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -962,6 +969,7 @@ def post(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -990,6 +998,7 @@ def delete(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -1018,6 +1027,7 @@ def delete(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -1046,6 +1056,7 @@ def options(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -1074,6 +1085,7 @@ def options(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -1102,6 +1114,7 @@ def head(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -1130,6 +1143,7 @@ def head(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -1158,6 +1172,7 @@ def patch(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -1186,6 +1201,7 @@ def patch(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down Expand Up @@ -1214,6 +1230,7 @@ def trace(
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
Expand Down Expand Up @@ -1243,6 +1260,7 @@ def trace(
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
route_class_override=route_class_override,
callbacks=callbacks,
openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function,
Expand Down
59 changes: 59 additions & 0 deletions tests/test_custom_route_class_for_endpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from typing import Callable
from urllib.request import Request

import pytest
from fastapi import APIRouter, FastAPI, HTTPException, status
from fastapi.openapi.models import Response
from fastapi.routing import APIRoute
from fastapi.testclient import TestClient

app = FastAPI()
router = APIRouter()


class CustomRoute(APIRoute):
def get_route_handler(self) -> Callable:
original_route_handler = super().get_route_handler()

async def custom_route_handler(request: Request) -> Response:
if "test_header" not in request.headers:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
return await original_route_handler(request)

return custom_route_handler


@router.get("/a")
def get_a():
return {"msg": "A"}


@router.get("/b", route_class_override=CustomRoute)
def get_b():
return {"msg": "B"}


app.include_router(router=router, prefix="")


client = TestClient(app)


@pytest.mark.parametrize(
"path,expected_status,headers",
[
("/a", 200, {"test_header": "value"}),
("/a", 200, None),
("/b", 200, {"test_header": "value"}),
("/b", 400, None),
],
ids=[
"/a with test_header header",
"/a without test_header headers",
"/b with test_header headers",
"/b without test_header headers",
],
)
def test_get_path(path, expected_status, headers):
response = client.get(path, headers=headers)
assert response.status_code == expected_status

0 comments on commit fb7ea03

Please sign in to comment.