diff --git a/Pipfile b/Pipfile index c7cdb944e0ddc..f59b751a907d8 100644 --- a/Pipfile +++ b/Pipfile @@ -29,6 +29,7 @@ starlette = "==0.12.8" pydantic = "==0.32.2" databases = {extras = ["sqlite"],version = "*"} hypercorn = "*" +orjson = "*" [requires] python_version = "3.6" diff --git a/fastapi/applications.py b/fastapi/applications.py index 7836801415e9b..28cbd319b79e9 100644 --- a/fastapi/applications.py +++ b/fastapi/applications.py @@ -33,11 +33,13 @@ def __init__( version: str = "0.1.0", openapi_url: Optional[str] = "/openapi.json", openapi_prefix: str = "", + default_response_class: Type[Response] = JSONResponse, docs_url: Optional[str] = "/docs", redoc_url: Optional[str] = "/redoc", swagger_ui_oauth2_redirect_url: Optional[str] = "/docs/oauth2-redirect", **extra: Dict[str, Any], ) -> None: + self.default_response_class = default_response_class self._debug = debug self.router: routing.APIRouter = routing.APIRouter( routes, dependency_overrides_provider=self @@ -144,7 +146,7 @@ def add_api_route( response_model_by_alias: bool = True, response_model_skip_defaults: bool = False, include_in_schema: bool = True, - response_class: Type[Response] = JSONResponse, + response_class: Type[Response] = None, name: str = None, ) -> None: self.router.add_api_route( @@ -166,7 +168,7 @@ def add_api_route( response_model_by_alias=response_model_by_alias, response_model_skip_defaults=response_model_skip_defaults, include_in_schema=include_in_schema, - response_class=response_class, + response_class=response_class or self.default_response_class, name=name, ) @@ -190,7 +192,7 @@ def api_route( response_model_by_alias: bool = True, response_model_skip_defaults: bool = False, include_in_schema: bool = True, - response_class: Type[Response] = JSONResponse, + response_class: Type[Response] = None, name: str = None, ) -> Callable: def decorator(func: Callable) -> Callable: @@ -213,7 +215,7 @@ def decorator(func: Callable) -> Callable: response_model_by_alias=response_model_by_alias, response_model_skip_defaults=response_model_skip_defaults, include_in_schema=include_in_schema, - response_class=response_class, + response_class=response_class or self.default_response_class, name=name, ) return func @@ -240,6 +242,7 @@ def include_router( tags: List[str] = None, dependencies: Sequence[Depends] = None, responses: Dict[Union[int, str], Dict[str, Any]] = None, + default_response_class: Optional[Type[Response]] = None, ) -> None: self.router.include_router( router, @@ -247,6 +250,8 @@ def include_router( tags=tags, dependencies=dependencies, responses=responses or {}, + default_response_class=default_response_class + or self.default_response_class, ) def get( @@ -268,7 +273,7 @@ def get( response_model_by_alias: bool = True, response_model_skip_defaults: bool = False, include_in_schema: bool = True, - response_class: Type[Response] = JSONResponse, + response_class: Type[Response] = None, name: str = None, ) -> Callable: return self.router.get( @@ -288,7 +293,7 @@ def get( response_model_by_alias=response_model_by_alias, response_model_skip_defaults=response_model_skip_defaults, include_in_schema=include_in_schema, - response_class=response_class, + response_class=response_class or self.default_response_class, name=name, ) @@ -311,7 +316,7 @@ def put( response_model_by_alias: bool = True, response_model_skip_defaults: bool = False, include_in_schema: bool = True, - response_class: Type[Response] = JSONResponse, + response_class: Type[Response] = None, name: str = None, ) -> Callable: return self.router.put( @@ -331,7 +336,7 @@ def put( response_model_by_alias=response_model_by_alias, response_model_skip_defaults=response_model_skip_defaults, include_in_schema=include_in_schema, - response_class=response_class, + response_class=response_class or self.default_response_class, name=name, ) @@ -354,7 +359,7 @@ def post( response_model_by_alias: bool = True, response_model_skip_defaults: bool = False, include_in_schema: bool = True, - response_class: Type[Response] = JSONResponse, + response_class: Type[Response] = None, name: str = None, ) -> Callable: return self.router.post( @@ -374,7 +379,7 @@ def post( response_model_by_alias=response_model_by_alias, response_model_skip_defaults=response_model_skip_defaults, include_in_schema=include_in_schema, - response_class=response_class, + response_class=response_class or self.default_response_class, name=name, ) @@ -397,7 +402,7 @@ def delete( response_model_by_alias: bool = True, response_model_skip_defaults: bool = False, include_in_schema: bool = True, - response_class: Type[Response] = JSONResponse, + response_class: Type[Response] = None, name: str = None, ) -> Callable: return self.router.delete( @@ -417,7 +422,7 @@ def delete( operation_id=operation_id, response_model_skip_defaults=response_model_skip_defaults, include_in_schema=include_in_schema, - response_class=response_class, + response_class=response_class or self.default_response_class, name=name, ) @@ -440,7 +445,7 @@ def options( response_model_by_alias: bool = True, response_model_skip_defaults: bool = False, include_in_schema: bool = True, - response_class: Type[Response] = JSONResponse, + response_class: Type[Response] = None, name: str = None, ) -> Callable: return self.router.options( @@ -460,7 +465,7 @@ def options( response_model_by_alias=response_model_by_alias, response_model_skip_defaults=response_model_skip_defaults, include_in_schema=include_in_schema, - response_class=response_class, + response_class=response_class or self.default_response_class, name=name, ) @@ -483,7 +488,7 @@ def head( response_model_by_alias: bool = True, response_model_skip_defaults: bool = False, include_in_schema: bool = True, - response_class: Type[Response] = JSONResponse, + response_class: Type[Response] = None, name: str = None, ) -> Callable: return self.router.head( @@ -503,7 +508,7 @@ def head( response_model_by_alias=response_model_by_alias, response_model_skip_defaults=response_model_skip_defaults, include_in_schema=include_in_schema, - response_class=response_class, + response_class=response_class or self.default_response_class, name=name, ) @@ -526,7 +531,7 @@ def patch( response_model_by_alias: bool = True, response_model_skip_defaults: bool = False, include_in_schema: bool = True, - response_class: Type[Response] = JSONResponse, + response_class: Type[Response] = None, name: str = None, ) -> Callable: return self.router.patch( @@ -546,7 +551,7 @@ def patch( response_model_by_alias=response_model_by_alias, response_model_skip_defaults=response_model_skip_defaults, include_in_schema=include_in_schema, - response_class=response_class, + response_class=response_class or self.default_response_class, name=name, ) @@ -569,7 +574,7 @@ def trace( response_model_by_alias: bool = True, response_model_skip_defaults: bool = False, include_in_schema: bool = True, - response_class: Type[Response] = JSONResponse, + response_class: Type[Response] = None, name: str = None, ) -> Callable: return self.router.trace( @@ -589,6 +594,6 @@ def trace( response_model_by_alias=response_model_by_alias, response_model_skip_defaults=response_model_skip_defaults, include_in_schema=include_in_schema, - response_class=response_class, + response_class=response_class or self.default_response_class, name=name, ) diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index 6c987a29fa915..d9ec7bd6e3c57 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -151,6 +151,10 @@ def get_openapi_path( security_schemes: Dict[str, Any] = {} definitions: Dict[str, Any] = {} assert route.methods is not None, "Methods must be a list" + assert ( + route.response_class and route.response_class.media_type + ), "A response class with media_type is needed to generate OpenAPI" + route_response_media_type: str = route.response_class.media_type if route.include_in_schema: for method in route.methods: operation = get_openapi_operation_metadata(route=route, method=method) @@ -185,7 +189,7 @@ def get_openapi_path( field, model_name_map=model_name_map, ref_prefix=REF_PREFIX ) response.setdefault("content", {}).setdefault( - route.response_class.media_type, {} + route_response_media_type, {} )["schema"] = response_schema status_text: Optional[str] = status_code_ranges.get( str(additional_status_code).upper() @@ -213,7 +217,7 @@ def get_openapi_path( ] = route.response_description operation.setdefault("responses", {}).setdefault( status_code, {} - ).setdefault("content", {}).setdefault(route.response_class.media_type, {})[ + ).setdefault("content", {}).setdefault(route_response_media_type, {})[ "schema" ] = response_schema diff --git a/fastapi/routing.py b/fastapi/routing.py index f7cec9e45a480..aeafd07187e81 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -200,7 +200,7 @@ def __init__( response_model_by_alias: bool = True, response_model_skip_defaults: bool = False, include_in_schema: bool = True, - response_class: Type[Response] = JSONResponse, + response_class: Optional[Type[Response]] = None, dependency_overrides_provider: Any = None, ) -> None: self.path = path @@ -215,9 +215,6 @@ def __init__( ) self.response_model = response_model if self.response_model: - assert lenient_issubclass( - response_class, JSONResponse - ), "To declare a type the response must be a JSON response" response_name = "Response_" + self.unique_id self.response_field: Optional[Field] = Field( name=response_name, @@ -299,7 +296,7 @@ def __init__( dependant=self.dependant, body_field=self.body_field, status_code=self.status_code, - response_class=self.response_class, + response_class=self.response_class or JSONResponse, response_field=self.secure_cloned_response_field, response_model_include=self.response_model_include, response_model_exclude=self.response_model_exclude, @@ -346,7 +343,7 @@ def add_api_route( response_model_by_alias: bool = True, response_model_skip_defaults: bool = False, include_in_schema: bool = True, - response_class: Type[Response] = JSONResponse, + response_class: Type[Response] = None, name: str = None, ) -> None: route = self.route_class( @@ -394,7 +391,7 @@ def api_route( response_model_by_alias: bool = True, response_model_skip_defaults: bool = False, include_in_schema: bool = True, - response_class: Type[Response] = JSONResponse, + response_class: Type[Response] = None, name: str = None, ) -> Callable: def decorator(func: Callable) -> Callable: @@ -445,6 +442,7 @@ def include_router( tags: List[str] = None, dependencies: Sequence[params.Depends] = None, responses: Dict[Union[int, str], Dict[str, Any]] = None, + default_response_class: Optional[Type[Response]] = None, ) -> None: if prefix: assert prefix.startswith("/"), "A path prefix must start with '/'" @@ -484,7 +482,7 @@ def include_router( response_model_by_alias=route.response_model_by_alias, response_model_skip_defaults=route.response_model_skip_defaults, include_in_schema=route.include_in_schema, - response_class=route.response_class, + response_class=route.response_class or default_response_class, name=route.name, ) elif isinstance(route, routing.Route): @@ -523,10 +521,9 @@ def get( response_model_by_alias: bool = True, response_model_skip_defaults: bool = False, include_in_schema: bool = True, - response_class: Type[Response] = JSONResponse, + response_class: Type[Response] = None, name: str = None, ) -> Callable: - return self.api_route( path=path, response_model=response_model, @@ -568,7 +565,7 @@ def put( response_model_by_alias: bool = True, response_model_skip_defaults: bool = False, include_in_schema: bool = True, - response_class: Type[Response] = JSONResponse, + response_class: Type[Response] = None, name: str = None, ) -> Callable: return self.api_route( @@ -612,7 +609,7 @@ def post( response_model_by_alias: bool = True, response_model_skip_defaults: bool = False, include_in_schema: bool = True, - response_class: Type[Response] = JSONResponse, + response_class: Type[Response] = None, name: str = None, ) -> Callable: return self.api_route( @@ -656,7 +653,7 @@ def delete( response_model_by_alias: bool = True, response_model_skip_defaults: bool = False, include_in_schema: bool = True, - response_class: Type[Response] = JSONResponse, + response_class: Type[Response] = None, name: str = None, ) -> Callable: return self.api_route( @@ -700,7 +697,7 @@ def options( response_model_by_alias: bool = True, response_model_skip_defaults: bool = False, include_in_schema: bool = True, - response_class: Type[Response] = JSONResponse, + response_class: Type[Response] = None, name: str = None, ) -> Callable: return self.api_route( @@ -744,7 +741,7 @@ def head( response_model_by_alias: bool = True, response_model_skip_defaults: bool = False, include_in_schema: bool = True, - response_class: Type[Response] = JSONResponse, + response_class: Type[Response] = None, name: str = None, ) -> Callable: return self.api_route( @@ -788,7 +785,7 @@ def patch( response_model_by_alias: bool = True, response_model_skip_defaults: bool = False, include_in_schema: bool = True, - response_class: Type[Response] = JSONResponse, + response_class: Type[Response] = None, name: str = None, ) -> Callable: return self.api_route( @@ -832,7 +829,7 @@ def trace( response_model_by_alias: bool = True, response_model_skip_defaults: bool = False, include_in_schema: bool = True, - response_class: Type[Response] = JSONResponse, + response_class: Type[Response] = None, name: str = None, ) -> Callable: return self.api_route( diff --git a/pyproject.toml b/pyproject.toml index 942b0726864d8..06ca453d82f7d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ test = [ "email_validator", "sqlalchemy", "databases[sqlite]", + "orjson" ] doc = [ "mkdocs", diff --git a/tests/test_default_response_class.py b/tests/test_default_response_class.py new file mode 100644 index 0000000000000..4905945b55315 --- /dev/null +++ b/tests/test_default_response_class.py @@ -0,0 +1,216 @@ +from typing import Any + +import orjson +from fastapi import APIRouter, FastAPI +from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse +from starlette.testclient import TestClient + + +class ORJSONResponse(JSONResponse): + media_type = "application/x-orjson" + + def render(self, content: Any) -> bytes: + return orjson.dumps(content) + + +class OverrideResponse(JSONResponse): + media_type = "application/x-override" + + +app = FastAPI(default_response_class=ORJSONResponse) +router_a = APIRouter() +router_a_a = APIRouter() +router_a_b_override = APIRouter() # Overrides default class +router_b_override = APIRouter() # Overrides default class +router_b_a = APIRouter() +router_b_a_c_override = APIRouter() # Overrides default class again + + +@app.get("/") +def get_root(): + return {"msg": "Hello World"} + + +@app.get("/override", response_class=PlainTextResponse) +def get_path_override(): + return "Hello World" + + +@router_a.get("/") +def get_a(): + return {"msg": "Hello A"} + + +@router_a.get("/override", response_class=PlainTextResponse) +def get_a_path_override(): + return "Hello A" + + +@router_a_a.get("/") +def get_a_a(): + return {"msg": "Hello A A"} + + +@router_a_a.get("/override", response_class=PlainTextResponse) +def get_a_a_path_override(): + return "Hello A A" + + +@router_a_b_override.get("/") +def get_a_b(): + return "Hello A B" + + +@router_a_b_override.get("/override", response_class=HTMLResponse) +def get_a_b_path_override(): + return "Hello A B" + + +@router_b_override.get("/") +def get_b(): + return "Hello B" + + +@router_b_override.get("/override", response_class=HTMLResponse) +def get_b_path_override(): + return "Hello B" + + +@router_b_a.get("/") +def get_b_a(): + return "Hello B A" + + +@router_b_a.get("/override", response_class=HTMLResponse) +def get_b_a_path_override(): + return "Hello B A" + + +@router_b_a_c_override.get("/") +def get_b_a_c(): + return "Hello B A C" + + +@router_b_a_c_override.get("/override", response_class=OverrideResponse) +def get_b_a_c_path_override(): + return {"msg": "Hello B A C"} + + +router_b_a.include_router( + router_b_a_c_override, prefix="/c", default_response_class=HTMLResponse +) +router_b_override.include_router(router_b_a, prefix="/a") +router_a.include_router(router_a_a, prefix="/a") +router_a.include_router( + router_a_b_override, prefix="/b", default_response_class=PlainTextResponse +) +app.include_router(router_a, prefix="/a") +app.include_router( + router_b_override, prefix="/b", default_response_class=PlainTextResponse +) + + +client = TestClient(app) + +orjson_type = "application/x-orjson" +text_type = "text/plain; charset=utf-8" +html_type = "text/html; charset=utf-8" +override_type = "application/x-override" + + +def test_app(): + with client: + response = client.get("/") + assert response.json() == {"msg": "Hello World"} + assert response.headers["content-type"] == orjson_type + + +def test_app_override(): + with client: + response = client.get("/override") + assert response.content == b"Hello World" + assert response.headers["content-type"] == text_type + + +def test_router_a(): + with client: + response = client.get("/a") + assert response.json() == {"msg": "Hello A"} + assert response.headers["content-type"] == orjson_type + + +def test_router_a_override(): + with client: + response = client.get("/a/override") + assert response.content == b"Hello A" + assert response.headers["content-type"] == text_type + + +def test_router_a_a(): + with client: + response = client.get("/a/a") + assert response.json() == {"msg": "Hello A A"} + assert response.headers["content-type"] == orjson_type + + +def test_router_a_a_override(): + with client: + response = client.get("/a/a/override") + assert response.content == b"Hello A A" + assert response.headers["content-type"] == text_type + + +def test_router_a_b(): + with client: + response = client.get("/a/b") + assert response.content == b"Hello A B" + assert response.headers["content-type"] == text_type + + +def test_router_a_b_override(): + with client: + response = client.get("/a/b/override") + assert response.content == b"Hello A B" + assert response.headers["content-type"] == html_type + + +def test_router_b(): + with client: + response = client.get("/b") + assert response.content == b"Hello B" + assert response.headers["content-type"] == text_type + + +def test_router_b_override(): + with client: + response = client.get("/b/override") + assert response.content == b"Hello B" + assert response.headers["content-type"] == html_type + + +def test_router_b_a(): + with client: + response = client.get("/b/a") + assert response.content == b"Hello B A" + assert response.headers["content-type"] == text_type + + +def test_router_b_a_override(): + with client: + response = client.get("/b/a/override") + assert response.content == b"Hello B A" + assert response.headers["content-type"] == html_type + + +def test_router_b_a_c(): + with client: + response = client.get("/b/a/c") + assert response.content == b"Hello B A C" + assert response.headers["content-type"] == html_type + + +def test_router_b_a_c_override(): + with client: + response = client.get("/b/a/c/override") + assert response.json() == {"msg": "Hello B A C"} + assert response.headers["content-type"] == override_type diff --git a/tests/test_default_response_class_router.py b/tests/test_default_response_class_router.py new file mode 100644 index 0000000000000..95aada4097df5 --- /dev/null +++ b/tests/test_default_response_class_router.py @@ -0,0 +1,206 @@ +from fastapi import APIRouter, FastAPI +from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse +from starlette.testclient import TestClient + + +class OverrideResponse(JSONResponse): + media_type = "application/x-override" + + +app = FastAPI() +router_a = APIRouter() +router_a_a = APIRouter() +router_a_b_override = APIRouter() # Overrides default class +router_b_override = APIRouter() # Overrides default class +router_b_a = APIRouter() +router_b_a_c_override = APIRouter() # Overrides default class again + + +@app.get("/") +def get_root(): + return {"msg": "Hello World"} + + +@app.get("/override", response_class=PlainTextResponse) +def get_path_override(): + return "Hello World" + + +@router_a.get("/") +def get_a(): + return {"msg": "Hello A"} + + +@router_a.get("/override", response_class=PlainTextResponse) +def get_a_path_override(): + return "Hello A" + + +@router_a_a.get("/") +def get_a_a(): + return {"msg": "Hello A A"} + + +@router_a_a.get("/override", response_class=PlainTextResponse) +def get_a_a_path_override(): + return "Hello A A" + + +@router_a_b_override.get("/") +def get_a_b(): + return "Hello A B" + + +@router_a_b_override.get("/override", response_class=HTMLResponse) +def get_a_b_path_override(): + return "Hello A B" + + +@router_b_override.get("/") +def get_b(): + return "Hello B" + + +@router_b_override.get("/override", response_class=HTMLResponse) +def get_b_path_override(): + return "Hello B" + + +@router_b_a.get("/") +def get_b_a(): + return "Hello B A" + + +@router_b_a.get("/override", response_class=HTMLResponse) +def get_b_a_path_override(): + return "Hello B A" + + +@router_b_a_c_override.get("/") +def get_b_a_c(): + return "Hello B A C" + + +@router_b_a_c_override.get("/override", response_class=OverrideResponse) +def get_b_a_c_path_override(): + return {"msg": "Hello B A C"} + + +router_b_a.include_router( + router_b_a_c_override, prefix="/c", default_response_class=HTMLResponse +) +router_b_override.include_router(router_b_a, prefix="/a") +router_a.include_router(router_a_a, prefix="/a") +router_a.include_router( + router_a_b_override, prefix="/b", default_response_class=PlainTextResponse +) +app.include_router(router_a, prefix="/a") +app.include_router( + router_b_override, prefix="/b", default_response_class=PlainTextResponse +) + + +client = TestClient(app) + +json_type = "application/json" +text_type = "text/plain; charset=utf-8" +html_type = "text/html; charset=utf-8" +override_type = "application/x-override" + + +def test_app(): + with client: + response = client.get("/") + assert response.json() == {"msg": "Hello World"} + assert response.headers["content-type"] == json_type + + +def test_app_override(): + with client: + response = client.get("/override") + assert response.content == b"Hello World" + assert response.headers["content-type"] == text_type + + +def test_router_a(): + with client: + response = client.get("/a") + assert response.json() == {"msg": "Hello A"} + assert response.headers["content-type"] == json_type + + +def test_router_a_override(): + with client: + response = client.get("/a/override") + assert response.content == b"Hello A" + assert response.headers["content-type"] == text_type + + +def test_router_a_a(): + with client: + response = client.get("/a/a") + assert response.json() == {"msg": "Hello A A"} + assert response.headers["content-type"] == json_type + + +def test_router_a_a_override(): + with client: + response = client.get("/a/a/override") + assert response.content == b"Hello A A" + assert response.headers["content-type"] == text_type + + +def test_router_a_b(): + with client: + response = client.get("/a/b") + assert response.content == b"Hello A B" + assert response.headers["content-type"] == text_type + + +def test_router_a_b_override(): + with client: + response = client.get("/a/b/override") + assert response.content == b"Hello A B" + assert response.headers["content-type"] == html_type + + +def test_router_b(): + with client: + response = client.get("/b") + assert response.content == b"Hello B" + assert response.headers["content-type"] == text_type + + +def test_router_b_override(): + with client: + response = client.get("/b/override") + assert response.content == b"Hello B" + assert response.headers["content-type"] == html_type + + +def test_router_b_a(): + with client: + response = client.get("/b/a") + assert response.content == b"Hello B A" + assert response.headers["content-type"] == text_type + + +def test_router_b_a_override(): + with client: + response = client.get("/b/a/override") + assert response.content == b"Hello B A" + assert response.headers["content-type"] == html_type + + +def test_router_b_a_c(): + with client: + response = client.get("/b/a/c") + assert response.content == b"Hello B A C" + assert response.headers["content-type"] == html_type + + +def test_router_b_a_c_override(): + with client: + response = client.get("/b/a/c/override") + assert response.json() == {"msg": "Hello B A C"} + assert response.headers["content-type"] == override_type