Skip to content

Commit

Permalink
✨ Add support for specifying a default_response_class (#467)
Browse files Browse the repository at this point in the history
  • Loading branch information
toppk authored and tiangolo committed Sep 29, 2019
1 parent 0c67022 commit f803c77
Show file tree
Hide file tree
Showing 7 changed files with 469 additions and 39 deletions.
1 change: 1 addition & 0 deletions Pipfile
Expand Up @@ -29,6 +29,7 @@ starlette = "==0.12.8"
pydantic = "==0.32.2"
databases = {extras = ["sqlite"],version = "*"}
hypercorn = "*"
orjson = "*"

[requires]
python_version = "3.6"
Expand Down
45 changes: 25 additions & 20 deletions fastapi/applications.py
Expand Up @@ -33,11 +33,13 @@ def __init__(
version: str = "0.1.0",
openapi_url: Optional[str] = "/openapi.json",
openapi_prefix: str = "",
default_response_class: Type[Response] = JSONResponse,
docs_url: Optional[str] = "/docs",
redoc_url: Optional[str] = "/redoc",
swagger_ui_oauth2_redirect_url: Optional[str] = "/docs/oauth2-redirect",
**extra: Dict[str, Any],
) -> None:
self.default_response_class = default_response_class
self._debug = debug
self.router: routing.APIRouter = routing.APIRouter(
routes, dependency_overrides_provider=self
Expand Down Expand Up @@ -144,7 +146,7 @@ def add_api_route(
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> None:
self.router.add_api_route(
Expand All @@ -166,7 +168,7 @@ def add_api_route(
response_model_by_alias=response_model_by_alias,
response_model_skip_defaults=response_model_skip_defaults,
include_in_schema=include_in_schema,
response_class=response_class,
response_class=response_class or self.default_response_class,
name=name,
)

Expand All @@ -190,7 +192,7 @@ def api_route(
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
def decorator(func: Callable) -> Callable:
Expand All @@ -213,7 +215,7 @@ def decorator(func: Callable) -> Callable:
response_model_by_alias=response_model_by_alias,
response_model_skip_defaults=response_model_skip_defaults,
include_in_schema=include_in_schema,
response_class=response_class,
response_class=response_class or self.default_response_class,
name=name,
)
return func
Expand All @@ -240,13 +242,16 @@ def include_router(
tags: List[str] = None,
dependencies: Sequence[Depends] = None,
responses: Dict[Union[int, str], Dict[str, Any]] = None,
default_response_class: Optional[Type[Response]] = None,
) -> None:
self.router.include_router(
router,
prefix=prefix,
tags=tags,
dependencies=dependencies,
responses=responses or {},
default_response_class=default_response_class
or self.default_response_class,
)

def get(
Expand All @@ -268,7 +273,7 @@ def get(
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.router.get(
Expand All @@ -288,7 +293,7 @@ def get(
response_model_by_alias=response_model_by_alias,
response_model_skip_defaults=response_model_skip_defaults,
include_in_schema=include_in_schema,
response_class=response_class,
response_class=response_class or self.default_response_class,
name=name,
)

Expand All @@ -311,7 +316,7 @@ def put(
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.router.put(
Expand All @@ -331,7 +336,7 @@ def put(
response_model_by_alias=response_model_by_alias,
response_model_skip_defaults=response_model_skip_defaults,
include_in_schema=include_in_schema,
response_class=response_class,
response_class=response_class or self.default_response_class,
name=name,
)

Expand All @@ -354,7 +359,7 @@ def post(
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.router.post(
Expand All @@ -374,7 +379,7 @@ def post(
response_model_by_alias=response_model_by_alias,
response_model_skip_defaults=response_model_skip_defaults,
include_in_schema=include_in_schema,
response_class=response_class,
response_class=response_class or self.default_response_class,
name=name,
)

Expand All @@ -397,7 +402,7 @@ def delete(
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.router.delete(
Expand All @@ -417,7 +422,7 @@ def delete(
operation_id=operation_id,
response_model_skip_defaults=response_model_skip_defaults,
include_in_schema=include_in_schema,
response_class=response_class,
response_class=response_class or self.default_response_class,
name=name,
)

Expand All @@ -440,7 +445,7 @@ def options(
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.router.options(
Expand All @@ -460,7 +465,7 @@ def options(
response_model_by_alias=response_model_by_alias,
response_model_skip_defaults=response_model_skip_defaults,
include_in_schema=include_in_schema,
response_class=response_class,
response_class=response_class or self.default_response_class,
name=name,
)

Expand All @@ -483,7 +488,7 @@ def head(
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.router.head(
Expand All @@ -503,7 +508,7 @@ def head(
response_model_by_alias=response_model_by_alias,
response_model_skip_defaults=response_model_skip_defaults,
include_in_schema=include_in_schema,
response_class=response_class,
response_class=response_class or self.default_response_class,
name=name,
)

Expand All @@ -526,7 +531,7 @@ def patch(
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.router.patch(
Expand All @@ -546,7 +551,7 @@ def patch(
response_model_by_alias=response_model_by_alias,
response_model_skip_defaults=response_model_skip_defaults,
include_in_schema=include_in_schema,
response_class=response_class,
response_class=response_class or self.default_response_class,
name=name,
)

Expand All @@ -569,7 +574,7 @@ def trace(
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.router.trace(
Expand All @@ -589,6 +594,6 @@ def trace(
response_model_by_alias=response_model_by_alias,
response_model_skip_defaults=response_model_skip_defaults,
include_in_schema=include_in_schema,
response_class=response_class,
response_class=response_class or self.default_response_class,
name=name,
)
8 changes: 6 additions & 2 deletions fastapi/openapi/utils.py
Expand Up @@ -151,6 +151,10 @@ def get_openapi_path(
security_schemes: Dict[str, Any] = {}
definitions: Dict[str, Any] = {}
assert route.methods is not None, "Methods must be a list"
assert (
route.response_class and route.response_class.media_type
), "A response class with media_type is needed to generate OpenAPI"
route_response_media_type: str = route.response_class.media_type
if route.include_in_schema:
for method in route.methods:
operation = get_openapi_operation_metadata(route=route, method=method)
Expand Down Expand Up @@ -185,7 +189,7 @@ def get_openapi_path(
field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)
response.setdefault("content", {}).setdefault(
route.response_class.media_type, {}
route_response_media_type, {}
)["schema"] = response_schema
status_text: Optional[str] = status_code_ranges.get(
str(additional_status_code).upper()
Expand Down Expand Up @@ -213,7 +217,7 @@ def get_openapi_path(
] = route.response_description
operation.setdefault("responses", {}).setdefault(
status_code, {}
).setdefault("content", {}).setdefault(route.response_class.media_type, {})[
).setdefault("content", {}).setdefault(route_response_media_type, {})[
"schema"
] = response_schema

Expand Down
31 changes: 14 additions & 17 deletions fastapi/routing.py
Expand Up @@ -200,7 +200,7 @@ def __init__(
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Optional[Type[Response]] = None,
dependency_overrides_provider: Any = None,
) -> None:
self.path = path
Expand All @@ -215,9 +215,6 @@ def __init__(
)
self.response_model = response_model
if self.response_model:
assert lenient_issubclass(
response_class, JSONResponse
), "To declare a type the response must be a JSON response"
response_name = "Response_" + self.unique_id
self.response_field: Optional[Field] = Field(
name=response_name,
Expand Down Expand Up @@ -299,7 +296,7 @@ def __init__(
dependant=self.dependant,
body_field=self.body_field,
status_code=self.status_code,
response_class=self.response_class,
response_class=self.response_class or JSONResponse,
response_field=self.secure_cloned_response_field,
response_model_include=self.response_model_include,
response_model_exclude=self.response_model_exclude,
Expand Down Expand Up @@ -346,7 +343,7 @@ def add_api_route(
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> None:
route = self.route_class(
Expand Down Expand Up @@ -394,7 +391,7 @@ def api_route(
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
def decorator(func: Callable) -> Callable:
Expand Down Expand Up @@ -445,6 +442,7 @@ def include_router(
tags: List[str] = None,
dependencies: Sequence[params.Depends] = None,
responses: Dict[Union[int, str], Dict[str, Any]] = None,
default_response_class: Optional[Type[Response]] = None,
) -> None:
if prefix:
assert prefix.startswith("/"), "A path prefix must start with '/'"
Expand Down Expand Up @@ -484,7 +482,7 @@ def include_router(
response_model_by_alias=route.response_model_by_alias,
response_model_skip_defaults=route.response_model_skip_defaults,
include_in_schema=route.include_in_schema,
response_class=route.response_class,
response_class=route.response_class or default_response_class,
name=route.name,
)
elif isinstance(route, routing.Route):
Expand Down Expand Up @@ -523,10 +521,9 @@ def get(
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:

return self.api_route(
path=path,
response_model=response_model,
Expand Down Expand Up @@ -568,7 +565,7 @@ def put(
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.api_route(
Expand Down Expand Up @@ -612,7 +609,7 @@ def post(
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.api_route(
Expand Down Expand Up @@ -656,7 +653,7 @@ def delete(
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.api_route(
Expand Down Expand Up @@ -700,7 +697,7 @@ def options(
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.api_route(
Expand Down Expand Up @@ -744,7 +741,7 @@ def head(
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.api_route(
Expand Down Expand Up @@ -788,7 +785,7 @@ def patch(
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.api_route(
Expand Down Expand Up @@ -832,7 +829,7 @@ def trace(
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
response_class: Type[Response] = None,
name: str = None,
) -> Callable:
return self.api_route(
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Expand Up @@ -39,6 +39,7 @@ test = [
"email_validator",
"sqlalchemy",
"databases[sqlite]",
"orjson"
]
doc = [
"mkdocs",
Expand Down

0 comments on commit f803c77

Please sign in to comment.