forked from tiangolo/fastapi
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
189 lines (165 loc) · 6.46 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import functools
import re
import warnings
from dataclasses import is_dataclass
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Type, Union, cast
import fastapi
from fastapi.datastructures import DefaultPlaceholder, DefaultType
from fastapi.openapi.constants import REF_PREFIX
from pydantic import BaseConfig, BaseModel, create_model
from pydantic.class_validators import Validator
from pydantic.fields import FieldInfo, ModelField, UndefinedType
from pydantic.schema import model_process_schema
from pydantic.utils import lenient_issubclass
if TYPE_CHECKING: # pragma: nocover
from .routing import APIRoute
def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool:
    """
    Return whether a response with ``status_code`` may carry a body.

    ``None`` (no explicit status code) and the OpenAPI response-key patterns
    ``"default"``, ``"1XX"``, ``"2XX"``, ``"3XX"``, ``"4XX"`` and ``"5XX"``
    are treated as allowing a body: they stand for a range of codes rather
    than one concrete code.

    Any other value is converted with ``int()``. Informational codes
    (< 200), 204 No Content, 205 Reset Content and 304 Not Modified must not
    carry a body (RFC 9110).

    Raises:
        ValueError: if ``status_code`` is a string that is neither one of the
            patterns above nor an integer literal.
    """
    if status_code is None:
        return True
    # OpenAPI allows "default" and range patterns such as "2XX" as response
    # keys; int() would raise ValueError on them.
    if status_code in {"default", "1XX", "2XX", "3XX", "4XX", "5XX"}:
        return True
    current_status_code = int(status_code)
    # RFC 9110: a server MUST NOT send content with 1xx, 204, 205 or 304.
    return not (current_status_code < 200 or current_status_code in {204, 205, 304})
def get_model_definitions(
    *,
    flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
    model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Dict[str, Any]:
    """
    Build the schema definitions mapping for a collection of models.

    Each model in ``flat_models`` is run through ``model_process_schema``
    (with references rooted at ``REF_PREFIX``). The definitions returned
    alongside the model's schema are merged into the result. The model's own
    schema is then stored under the name ``model_name_map`` assigns to it.
    """
    collected: Dict[str, Dict[str, Any]] = {}
    for schema_model in flat_models:
        own_schema, nested_definitions, _ = model_process_schema(
            schema_model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
        )
        collected.update(nested_definitions)
        collected[model_name_map[schema_model]] = own_schema
    return collected
def get_path_param_names(path: str) -> Set[str]:
    """
    Return the set of parameter names written as ``{...}`` in ``path``.

    Each name is the shortest text between a ``{`` and the next ``}`` (not
    spanning newlines), taken verbatim. For example, ``"/items/{item_id}"``
    gives ``{"item_id"}``.
    """
    return {match.group(1) for match in re.finditer(r"\{(.*?)\}", path)}
def create_response_field(
    name: str,
    type_: Type[Any],
    class_validators: Optional[Dict[str, Validator]] = None,
    default: Optional[Any] = None,
    required: Union[bool, UndefinedType] = True,
    model_config: Type[BaseConfig] = BaseConfig,
    field_info: Optional[FieldInfo] = None,
    alias: Optional[str] = None,
) -> ModelField:
    """
    Create a new response field. Raises if type_ is invalid.

    Builds a ``ModelField`` from the given arguments. A missing
    ``class_validators`` becomes an empty dict, and a missing ``field_info``
    becomes a fresh ``FieldInfo()``.

    Raises:
        fastapi.exceptions.FastAPIError: when building the ``ModelField``
            raises ``RuntimeError`` (the message hints that ``type_`` may not
            be a valid pydantic field type). The original error is chained
            as ``__cause__``.
    """
    class_validators = class_validators or {}
    field_info = field_info or FieldInfo()
    try:
        return ModelField(
            name=name,
            type_=type_,
            class_validators=class_validators,
            default=default,
            required=required,
            model_config=model_config,
            field_info=field_info,
            alias=alias,
        )
    except RuntimeError as exc:
        # Chain explicitly so the traceback shows this as a deliberate
        # translation of the pydantic error, not a failure in the handler.
        raise fastapi.exceptions.FastAPIError(
            f"Invalid args for response field! Hint: check that {type_} is a valid pydantic field type"
        ) from exc
def create_cloned_field(
    field: ModelField,
    *,
    cloned_types: Optional[Dict[Type[BaseModel], Type[BaseModel]]] = None,
) -> ModelField:
    """
    Return a copy of ``field`` whose model type is replaced by a clone.

    The type to clone is ``field.type_``. If that is a dataclass exposing
    ``__pydantic_model__``, the model it exposes is used instead. When the
    type is a ``BaseModel`` subclass, a new subclass with the same name is
    created via ``create_model(..., __base__=original)``. Each entry of the
    new subclass's ``__fields__`` is replaced by a recursively cloned field.
    Otherwise the type is used as-is.

    A new field is then built with ``create_response_field``. The following
    are copied onto it from the original field:

    - alias, class validators, default, required flag
    - model config, field info, ``allow_none``, ``validate_always``
    - validators, ``parse_json`` and ``shape``

    ``sub_fields`` and ``key_field`` are cloned recursively. Finally
    ``populate_validators()`` is called on the new field.

    Args:
        field: the field to clone.
        cloned_types: mapping of original model class -> clone, shared by all
            recursive calls of one cloning run. ``None`` starts a new, empty
            mapping.

    Returns:
        The newly created field.
    """
    # cloned_types holds the types already cloned in this run, so recursive
    # (self-referencing) models reuse their clone instead of looping forever
    if cloned_types is None:
        cloned_types = dict()
    original_type = field.type_
    # A dataclass carrying __pydantic_model__ (presumably a pydantic
    # dataclass) is cloned through that attached model.
    if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
        original_type = original_type.__pydantic_model__
    use_type = original_type
    if lenient_issubclass(original_type, BaseModel):
        original_type = cast(Type[BaseModel], original_type)
        use_type = cloned_types.get(original_type)
        if use_type is None:
            use_type = create_model(original_type.__name__, __base__=original_type)
            # Register the clone BEFORE cloning its fields, so a field that
            # refers back to this model finds the clone in cloned_types.
            cloned_types[original_type] = use_type
            for f in original_type.__fields__.values():
                use_type.__fields__[f.name] = create_cloned_field(
                    f, cloned_types=cloned_types
                )
    new_field = create_response_field(name=field.name, type_=use_type)
    # Carry over the original field's configuration onto the new field.
    new_field.has_alias = field.has_alias
    new_field.alias = field.alias
    new_field.class_validators = field.class_validators
    new_field.default = field.default
    new_field.required = field.required
    new_field.model_config = field.model_config
    new_field.field_info = field.field_info
    new_field.allow_none = field.allow_none
    new_field.validate_always = field.validate_always
    # Nested fields (e.g. container items / mapping keys) get cloned too.
    if field.sub_fields:
        new_field.sub_fields = [
            create_cloned_field(sub_field, cloned_types=cloned_types)
            for sub_field in field.sub_fields
        ]
    if field.key_field:
        new_field.key_field = create_cloned_field(
            field.key_field, cloned_types=cloned_types
        )
    new_field.validators = field.validators
    new_field.pre_validators = field.pre_validators
    new_field.post_validators = field.post_validators
    new_field.parse_json = field.parse_json
    new_field.shape = field.shape
    # Called last, after type, shape and sub-fields are in place. Presumably
    # this lets pydantic rebuild validators from the copied config — verify.
    new_field.populate_validators()
    return new_field
def generate_operation_id_for_path(
    *, name: str, path: str, method: str
) -> str:  # pragma: nocover
    """
    Deprecated: build an operation ID from ``name``, ``path`` and ``method``.

    Every character of ``name + path`` that is not an ASCII letter, digit or
    underscore is replaced by ``_``, then ``"_"`` plus the lower-cased
    ``method`` is appended. Always emits a ``DeprecationWarning`` attributed
    to the caller.
    """
    warnings.warn(
        "fastapi.utils.generate_operation_id_for_path() was deprecated, "
        "it is not used internally, and will be removed soon",
        DeprecationWarning,
        stacklevel=2,
    )
    sanitized = re.sub(r"[^0-9a-zA-Z_]", "_", name + path)
    return f"{sanitized}_{method.lower()}"
def generate_unique_id(route: "APIRoute") -> str:
operation_id = route.name + route.path_format
operation_id = re.sub("[^0-9a-zA-Z_]", "_", operation_id)
assert route.methods
operation_id = operation_id + "_" + list(route.methods)[0].lower()
return operation_id
def deep_dict_update(main_dict: Dict[Any, Any], update_dict: Dict[Any, Any]) -> None:
    """
    Merge ``update_dict`` into ``main_dict`` in place, key by key.

    - If both sides hold a dict for a key, the two dicts are merged
      recursively.
    - If both sides hold a list, ``main_dict`` gets a new list with its own
      items followed by the update's items.
    - Otherwise the value from ``update_dict`` overwrites (or adds) the
      entry. It is stored by reference, not copied.
    """
    for key, incoming in update_dict.items():
        current = main_dict.get(key)
        if isinstance(current, dict) and isinstance(incoming, dict):
            deep_dict_update(current, incoming)
        elif isinstance(current, list) and isinstance(incoming, list):
            main_dict[key] = current + incoming
        else:
            main_dict[key] = incoming
def get_value_or_default(
    first_item: Union[DefaultPlaceholder, DefaultType],
    *extra_items: Union[DefaultPlaceholder, DefaultType],
) -> Union[DefaultPlaceholder, DefaultType]:
    """
    Return the highest-priority value that was explicitly set.

    Arguments are given in descending priority. The first one that is not a
    `DefaultPlaceholder` is returned. If all of them are placeholders,
    `first_item` itself (a `DefaultPlaceholder`) is returned.
    """
    return next(
        (
            candidate
            for candidate in (first_item, *extra_items)
            if not isinstance(candidate, DefaultPlaceholder)
        ),
        first_item,
    )