Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add support for function return type annotations to declare the response_model #1436

Merged
merged 9 commits into from Jan 7, 2023
20 changes: 10 additions & 10 deletions fastapi/applications.py
Expand Up @@ -274,7 +274,7 @@ def add_api_route(
path: str,
endpoint: Callable[..., Coroutine[Any, Any, Response]],
*,
response_model: Any = None,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[Depends]] = None,
Expand Down Expand Up @@ -332,7 +332,7 @@ def api_route(
self,
path: str,
*,
response_model: Any = None,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[Depends]] = None,
Expand Down Expand Up @@ -435,7 +435,7 @@ def get(
self,
path: str,
*,
response_model: Any = None,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[Depends]] = None,
Expand Down Expand Up @@ -490,7 +490,7 @@ def put(
self,
path: str,
*,
response_model: Any = None,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[Depends]] = None,
Expand Down Expand Up @@ -545,7 +545,7 @@ def post(
self,
path: str,
*,
response_model: Any = None,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[Depends]] = None,
Expand Down Expand Up @@ -600,7 +600,7 @@ def delete(
self,
path: str,
*,
response_model: Any = None,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[Depends]] = None,
Expand Down Expand Up @@ -655,7 +655,7 @@ def options(
self,
path: str,
*,
response_model: Any = None,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[Depends]] = None,
Expand Down Expand Up @@ -710,7 +710,7 @@ def head(
self,
path: str,
*,
response_model: Any = None,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[Depends]] = None,
Expand Down Expand Up @@ -765,7 +765,7 @@ def patch(
self,
path: str,
*,
response_model: Any = None,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[Depends]] = None,
Expand Down Expand Up @@ -820,7 +820,7 @@ def trace(
self,
path: str,
*,
response_model: Any = None,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[Depends]] = None,
Expand Down
16 changes: 13 additions & 3 deletions fastapi/dependencies/utils.py
Expand Up @@ -253,22 +253,32 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
name=param.name,
kind=param.kind,
default=param.default,
annotation=get_typed_annotation(param, globalns),
annotation=get_typed_annotation(param.annotation, globalns),
)
for param in signature.parameters.values()
]
typed_signature = inspect.Signature(typed_params)
return typed_signature


def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) -> Any:
annotation = param.annotation
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
annotation = evaluate_forwardref(annotation, globalns, globalns)
return annotation


def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
signature = inspect.signature(call)
annotation = signature.return_annotation

if annotation is inspect.Signature.empty:
return None

globalns = getattr(call, "__globals__", {})
return get_typed_annotation(annotation, globalns)


def get_dependant(
*,
path: str,
Expand Down
25 changes: 14 additions & 11 deletions fastapi/routing.py
Expand Up @@ -26,6 +26,7 @@
get_body_field,
get_dependant,
get_parameterless_sub_dependant,
get_typed_return_annotation,
solve_dependencies,
)
from fastapi.encoders import DictIntStrAny, SetIntStr, jsonable_encoder
Expand Down Expand Up @@ -323,7 +324,7 @@ def __init__(
path: str,
endpoint: Callable[..., Any],
*,
response_model: Any = None,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[params.Depends]] = None,
Expand Down Expand Up @@ -354,6 +355,8 @@ def __init__(
) -> None:
self.path = path
self.endpoint = endpoint
if isinstance(response_model, DefaultPlaceholder):
response_model = get_typed_return_annotation(endpoint)
self.response_model = response_model
self.summary = summary
self.response_description = response_description
Expand Down Expand Up @@ -519,7 +522,7 @@ def add_api_route(
path: str,
endpoint: Callable[..., Any],
*,
response_model: Any = None,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[params.Depends]] = None,
Expand Down Expand Up @@ -600,7 +603,7 @@ def api_route(
self,
path: str,
*,
response_model: Any = None,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[params.Depends]] = None,
Expand Down Expand Up @@ -795,7 +798,7 @@ def get(
self,
path: str,
*,
response_model: Any = None,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[params.Depends]] = None,
Expand Down Expand Up @@ -851,7 +854,7 @@ def put(
self,
path: str,
*,
response_model: Any = None,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[params.Depends]] = None,
Expand Down Expand Up @@ -907,7 +910,7 @@ def post(
self,
path: str,
*,
response_model: Any = None,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[params.Depends]] = None,
Expand Down Expand Up @@ -963,7 +966,7 @@ def delete(
self,
path: str,
*,
response_model: Any = None,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[params.Depends]] = None,
Expand Down Expand Up @@ -1019,7 +1022,7 @@ def options(
self,
path: str,
*,
response_model: Any = None,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[params.Depends]] = None,
Expand Down Expand Up @@ -1075,7 +1078,7 @@ def head(
self,
path: str,
*,
response_model: Any = None,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[params.Depends]] = None,
Expand Down Expand Up @@ -1131,7 +1134,7 @@ def patch(
self,
path: str,
*,
response_model: Any = None,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[params.Depends]] = None,
Expand Down Expand Up @@ -1187,7 +1190,7 @@ def trace(
self,
path: str,
*,
response_model: Any = None,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[params.Depends]] = None,
Expand Down
1 change: 1 addition & 0 deletions tests/test_reponse_set_reponse_code_empty.py
Expand Up @@ -9,6 +9,7 @@
@app.delete(
"/{id}",
status_code=204,
response_model=None,
)
async def delete_deployment(
id: int,
Expand Down
141 changes: 141 additions & 0 deletions tests/test_response_model_as_return_annotation.py
@@ -0,0 +1,141 @@
from fastapi import FastAPI
from fastapi.testclient import TestClient
from pydantic import BaseModel


class ModelOne(BaseModel):
name: str


class ModelTwo(BaseModel):
surname: str


app = FastAPI()


@app.get("/valid1")
def valid1() -> ModelOne:
return ModelOne(name="Test")


@app.get("/valid2", response_model=ModelTwo)
def valid2():
return ModelTwo(surname="Test")


@app.get("/valid3", response_model=ModelTwo)
def valid3() -> ModelOne:
return ModelTwo(surname="Test")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@iudeen They actually different. On lines 22-44 there is no return type annotation but 27-29 has it.



@app.get("/valid4")
def valid4() -> "ModelOne":
return ModelOne(name="Test")


openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {
"/valid1": {
"get": {
"summary": "Valid1",
"operationId": "valid1_valid1_get",
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {"$ref": "#/components/schemas/ModelOne"}
}
},
}
},
}
},
"/valid2": {
"get": {
"summary": "Valid2",
"operationId": "valid2_valid2_get",
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {"$ref": "#/components/schemas/ModelTwo"}
}
},
}
},
}
},
"/valid3": {
"get": {
"summary": "Valid3",
"operationId": "valid3_valid3_get",
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {"$ref": "#/components/schemas/ModelTwo"}
}
},
}
},
}
},
"/valid4": {
"get": {
"summary": "Valid4",
"operationId": "valid4_valid4_get",
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {"$ref": "#/components/schemas/ModelOne"}
}
},
}
},
}
},
},
"components": {
"schemas": {
"ModelOne": {
"title": "ModelOne",
"required": ["name"],
"type": "object",
"properties": {"name": {"title": "Name", "type": "string"}},
},
"ModelTwo": {
"title": "ModelTwo",
"required": ["surname"],
"type": "object",
"properties": {"surname": {"title": "Surname", "type": "string"}},
},
}
},
}

client = TestClient(app)


def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == openapi_schema


def test_path_operations():
response = client.get("/valid1")
assert response.status_code == 200, response.text
response = client.get("/valid2")
assert response.status_code == 200, response.text
response = client.get("/valid3")
assert response.status_code == 200, response.text
response = client.get("/valid4")
assert response.status_code == 200, response.text