Skip to content

Commit

Permalink
Enums as separate models (pydantic#1432)
Browse files Browse the repository at this point in the history
* Updates schema() to generate enums as separate models

* Fixes mypy annotations

* Adds changes file

* Fixes comment

* Removes unused import

* Fixes test case

* Fixes missing partial branch in test coverage

* Resolves PR comments

* 🐛 Include enums in flat model schema handling

As they now have independent schemas, they behave somewhat like top-level models, and should be taken into account for top-level definitions

* ✅ Add test for coverage

* 🐛 Use Type[Enum] as type for consistency

Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
  • Loading branch information
calvinwyoung and tiangolo committed May 23, 2020
1 parent 913025a commit 5195e55
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 42 deletions.
3 changes: 3 additions & 0 deletions changes/1173-calvinwyoung.md
@@ -0,0 +1,3 @@
Updates OpenAPI schema generation to output all enums as separate models.
Instead of inlining the enum values in the model schema, models now use a `$ref`
property to point to the enum definition.
121 changes: 84 additions & 37 deletions pydantic/schema.py
Expand Up @@ -11,6 +11,7 @@
Callable,
Dict,
FrozenSet,
Iterable,
List,
Optional,
Sequence,
Expand Down Expand Up @@ -62,6 +63,10 @@
default_prefix = '#/definitions/'


TypeModelOrEnum = Union[Type['BaseModel'], Type[Enum]]
TypeModelSet = Set[TypeModelOrEnum]


def schema(
models: Sequence[Union[Type['BaseModel'], Type['DataclassType']]],
*,
Expand Down Expand Up @@ -145,9 +150,9 @@ def field_schema(
field: ModelField,
*,
by_alias: bool = True,
model_name_map: Dict[Type['BaseModel'], str],
model_name_map: Dict[TypeModelOrEnum, str],
ref_prefix: Optional[str] = None,
known_models: Set[Type['BaseModel']] = None,
known_models: TypeModelSet = None,
) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]:
"""
Process a Pydantic field and return a tuple with a JSON Schema for it as the first item.
Expand Down Expand Up @@ -240,7 +245,7 @@ def get_field_schema_validations(field: ModelField) -> Dict[str, Any]:
return f_schema


def get_model_name_map(unique_models: Set[Type['BaseModel']]) -> Dict[Type['BaseModel'], str]:
def get_model_name_map(unique_models: TypeModelSet) -> Dict[TypeModelOrEnum, str]:
"""
Process a set of models and generate unique names for them to be used as keys in the JSON Schema
definitions. By default the names are the same as the class name. But if two models in different Python
Expand All @@ -253,8 +258,7 @@ def get_model_name_map(unique_models: Set[Type['BaseModel']]) -> Dict[Type['Base
name_model_map = {}
conflicting_names: Set[str] = set()
for model in unique_models:
model_name = model.__name__
model_name = re.sub(r'[^a-zA-Z0-9.\-_]', '_', model_name)
model_name = normalize_name(model.__name__)
if model_name in conflicting_names:
model_name = get_long_model_name(model)
name_model_map[model_name] = model
Expand All @@ -268,9 +272,7 @@ def get_model_name_map(unique_models: Set[Type['BaseModel']]) -> Dict[Type['Base
return {v: k for k, v in name_model_map.items()}


def get_flat_models_from_model(
model: Type['BaseModel'], known_models: Set[Type['BaseModel']] = None
) -> Set[Type['BaseModel']]:
def get_flat_models_from_model(model: Type['BaseModel'], known_models: TypeModelSet = None) -> TypeModelSet:
"""
Take a single ``model`` and generate a set with itself and all the sub-models in the tree. I.e. if you pass
model ``Foo`` (subclass of Pydantic ``BaseModel``) as ``model``, and it has a field of type ``Bar`` (also
Expand All @@ -282,15 +284,15 @@ def get_flat_models_from_model(
:return: a set with the initial model and all its sub-models
"""
known_models = known_models or set()
flat_models: Set[Type['BaseModel']] = set()
flat_models: TypeModelSet = set()
flat_models.add(model)
known_models |= flat_models
fields = cast(Sequence[ModelField], model.__fields__.values())
flat_models |= get_flat_models_from_fields(fields, known_models=known_models)
return flat_models


def get_flat_models_from_field(field: ModelField, known_models: Set[Type['BaseModel']]) -> Set[Type['BaseModel']]:
def get_flat_models_from_field(field: ModelField, known_models: TypeModelSet) -> TypeModelSet:
"""
Take a single Pydantic ``ModelField`` (from a model) that could have been declared as a sublcass of BaseModel
(so, it could be a submodel), and generate a set with its model and all the sub-models in the tree.
Expand All @@ -304,7 +306,7 @@ def get_flat_models_from_field(field: ModelField, known_models: Set[Type['BaseMo
"""
from .main import BaseModel # noqa: F811

flat_models: Set[Type[BaseModel]] = set()
flat_models: TypeModelSet = set()
# Handle dataclass-based models
field_type = field.type_
if lenient_issubclass(getattr(field_type, '__pydantic_model__', None), BaseModel):
Expand All @@ -313,12 +315,12 @@ def get_flat_models_from_field(field: ModelField, known_models: Set[Type['BaseMo
flat_models |= get_flat_models_from_fields(field.sub_fields, known_models=known_models)
elif lenient_issubclass(field_type, BaseModel) and field_type not in known_models:
flat_models |= get_flat_models_from_model(field_type, known_models=known_models)
elif lenient_issubclass(field_type, Enum):
flat_models.add(field_type)
return flat_models


def get_flat_models_from_fields(
fields: Sequence[ModelField], known_models: Set[Type['BaseModel']]
) -> Set[Type['BaseModel']]:
def get_flat_models_from_fields(fields: Sequence[ModelField], known_models: TypeModelSet) -> TypeModelSet:
"""
Take a list of Pydantic ``ModelField``s (from a model) that could have been declared as sublcasses of ``BaseModel``
(so, any of them could be a submodel), and generate a set with their models and all the sub-models in the tree.
Expand All @@ -330,36 +332,36 @@ def get_flat_models_from_fields(
:param known_models: used to solve circular references
:return: a set with any model declared in the fields, and all their sub-models
"""
flat_models: Set[Type['BaseModel']] = set()
flat_models: TypeModelSet = set()
for field in fields:
flat_models |= get_flat_models_from_field(field, known_models=known_models)
return flat_models


def get_flat_models_from_models(models: Sequence[Type['BaseModel']]) -> Set[Type['BaseModel']]:
def get_flat_models_from_models(models: Sequence[Type['BaseModel']]) -> TypeModelSet:
"""
Take a list of ``models`` and generate a set with them and all their sub-models in their trees. I.e. if you pass
a list of two models, ``Foo`` and ``Bar``, both subclasses of Pydantic ``BaseModel`` as models, and ``Bar`` has
a field of type ``Baz`` (also subclass of ``BaseModel``), the return value will be ``set([Foo, Bar, Baz])``.
"""
flat_models: Set[Type['BaseModel']] = set()
flat_models: TypeModelSet = set()
for model in models:
flat_models |= get_flat_models_from_model(model)
return flat_models


def get_long_model_name(model: Type['BaseModel']) -> str:
def get_long_model_name(model: TypeModelOrEnum) -> str:
    """Build a collision-proof name for ``model`` from its full module path plus class name."""
    qualified = f'{model.__module__}.{model.__name__}'
    # Dots are not valid in JSON Schema definition keys, so flatten them to double underscores.
    return qualified.replace('.', '__')


def field_type_schema(
field: ModelField,
*,
by_alias: bool,
model_name_map: Dict[Type['BaseModel'], str],
model_name_map: Dict[TypeModelOrEnum, str],
schema_overrides: bool = False,
ref_prefix: Optional[str] = None,
known_models: Set[Type['BaseModel']],
known_models: TypeModelSet,
) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]:
"""
Used by ``field_schema()``, you probably should be using that function.
Expand Down Expand Up @@ -432,12 +434,12 @@ def field_type_schema(


def model_process_schema(
model: Type['BaseModel'],
model: TypeModelOrEnum,
*,
by_alias: bool = True,
model_name_map: Dict[Type['BaseModel'], str],
model_name_map: Dict[TypeModelOrEnum, str],
ref_prefix: Optional[str] = None,
known_models: Set[Type['BaseModel']] = None,
known_models: TypeModelSet = None,
) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]:
"""
Used by ``model_schema()``, you probably should be using that function.
Expand All @@ -450,6 +452,11 @@ def model_process_schema(

ref_prefix = ref_prefix or default_prefix
known_models = known_models or set()
if lenient_issubclass(model, Enum):
model = cast(Type[Enum], model)
s = enum_process_schema(model)
return s, {}, set()
model = cast(Type['BaseModel'], model)
s = {'title': model.__config__.title or model.__name__}
doc = getdoc(model)
if doc:
Expand All @@ -474,9 +481,9 @@ def model_type_schema(
model: Type['BaseModel'],
*,
by_alias: bool,
model_name_map: Dict[Type['BaseModel'], str],
model_name_map: Dict[TypeModelOrEnum, str],
ref_prefix: Optional[str] = None,
known_models: Set[Type['BaseModel']],
known_models: TypeModelSet,
) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]:
"""
You probably should be using ``model_schema()``, this function is indirectly used by that function.
Expand Down Expand Up @@ -519,14 +526,36 @@ def model_type_schema(
return out_schema, definitions, nested_models


def enum_process_schema(enum: Type[Enum]) -> Dict[str, Any]:
    """
    Take a single `enum` and generate its schema.

    This is similar to the `model_process_schema` function, but applies to ``Enum`` objects.
    """
    from inspect import getdoc

    schema: Dict[str, Any] = {'title': enum.__name__}
    # Python assigns all enums a default docstring value of 'An enumeration', so
    # all enums will have a description field even if not explicitly provided.
    schema['description'] = getdoc(enum)
    # Record every member's value, then let the shared helper attach the field type.
    schema['enum'] = [member.value for member in cast(Iterable[Enum], enum)]

    add_field_type_to_schema(enum, schema)

    return schema


def field_singleton_sub_fields_schema(
sub_fields: Sequence[ModelField],
*,
by_alias: bool,
model_name_map: Dict[Type['BaseModel'], str],
model_name_map: Dict[TypeModelOrEnum, str],
schema_overrides: bool = False,
ref_prefix: Optional[str] = None,
known_models: Set[Type['BaseModel']],
known_models: TypeModelSet,
) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]:
"""
This function is indirectly used by ``field_schema()``, you probably should be using that function.
Expand Down Expand Up @@ -599,14 +628,27 @@ def field_singleton_sub_fields_schema(
json_scheme = {'type': 'string', 'format': 'json-string'}


def add_field_type_to_schema(field_type: Any, schema: Dict[str, Any]) -> None:
    """
    Update the given `schema` in place with type-specific metadata for `field_type`.

    Scans ``field_class_to_schema`` for the first class that ``field_type`` subclasses
    and merges that entry's schema fragment into ``schema``. Returns ``None``.
    """
    matched = next(
        (t_schema for type_, t_schema in field_class_to_schema if issubclass(field_type, type_)),
        None,
    )
    if matched is not None:
        schema.update(matched)


def field_singleton_schema( # noqa: C901 (ignore complexity)
field: ModelField,
*,
by_alias: bool,
model_name_map: Dict[Type['BaseModel'], str],
model_name_map: Dict[TypeModelOrEnum, str],
schema_overrides: bool = False,
ref_prefix: Optional[str] = None,
known_models: Set[Type['BaseModel']],
known_models: TypeModelSet,
) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]:
"""
This function is indirectly used by ``field_schema()``, you should probably be using that function.
Expand Down Expand Up @@ -649,14 +691,12 @@ def field_singleton_schema( # noqa: C901 (ignore complexity)
field_type = literal_value.__class__
f_schema['const'] = literal_value

if issubclass(field_type, Enum):
f_schema.update({'enum': [item.value for item in field_type]})
# Don't return immediately, to allow adding specific types

for type_, t_schema in field_class_to_schema:
if issubclass(field_type, type_):
f_schema.update(t_schema)
break
if lenient_issubclass(field_type, Enum):
enum_name = normalize_name(field_type.__name__)
f_schema = {'$ref': ref_prefix + enum_name}
definitions[enum_name] = enum_process_schema(field_type)
else:
add_field_type_to_schema(field_type, f_schema)

modify_schema = getattr(field_type, '__modify_schema__', None)
if modify_schema:
Expand Down Expand Up @@ -815,6 +855,13 @@ def go(type_: Any) -> Type[Any]:
return ans


def normalize_name(name: str) -> str:
    """
    Normalizes the given name. This can be applied to either a model *or* enum.

    Every character outside ``[a-zA-Z0-9.\\-_]`` is replaced with an underscore.
    """
    invalid_chars = re.compile(r'[^a-zA-Z0-9.\-_]')
    return invalid_chars.sub('_', name)


class SkipField(Exception):
"""
Utility exception used to exclude fields from schema.
Expand Down
35 changes: 30 additions & 5 deletions tests/test_schema.py
Expand Up @@ -20,6 +20,7 @@
get_flat_models_from_model,
get_flat_models_from_models,
get_model_name_map,
model_process_schema,
model_schema,
schema,
)
Expand Down Expand Up @@ -205,14 +206,19 @@ class Model(BaseModel):
spam: SpamEnum = Field(None)

assert Model.schema() == {
'type': 'object',
'title': 'Model',
'type': 'object',
'properties': {
'foo': {'title': 'Foo', 'enum': ['f', 'b']},
'bar': {'type': 'integer', 'title': 'Bar', 'enum': [1, 2]},
'spam': {'type': 'string', 'title': 'Spam', 'enum': ['f', 'b']},
'foo': {'$ref': '#/definitions/FooEnum'},
'bar': {'$ref': '#/definitions/BarEnum'},
'spam': {'$ref': '#/definitions/SpamEnum'},
},
'required': ['foo', 'bar'],
'definitions': {
'FooEnum': {'title': 'FooEnum', 'description': 'An enumeration.', 'enum': ['f', 'b']},
'BarEnum': {'title': 'BarEnum', 'description': 'An enumeration.', 'type': 'integer', 'enum': [1, 2]},
'SpamEnum': {'title': 'SpamEnum', 'description': 'An enumeration.', 'type': 'string', 'enum': ['f', 'b']},
},
}


Expand Down Expand Up @@ -1769,6 +1775,8 @@ class Model:

def test_schema_attributes():
class ExampleEnum(Enum):
"""This is a test description."""

gt = 'GT'
lt = 'LT'
ge = 'GE'
Expand All @@ -1783,11 +1791,28 @@ class Example(BaseModel):
assert Example.schema() == {
'title': 'Example',
'type': 'object',
'properties': {'example': {'title': 'Example', 'enum': ['GT', 'LT', 'GE', 'LE', 'ML', 'MO', 'RE']}},
'properties': {'example': {'$ref': '#/definitions/ExampleEnum'}},
'required': ['example'],
'definitions': {
'ExampleEnum': {
'title': 'ExampleEnum',
'description': 'This is a test description.',
'enum': ['GT', 'LT', 'GE', 'LE', 'ML', 'MO', 'RE'],
}
},
}


def test_model_process_schema_enum():
    # Enums are handled as standalone models: model_process_schema should produce
    # the enum's own schema directly, with empty definitions and nested-model sets.
    class SpamEnum(str, Enum):
        foo = 'f'
        bar = 'b'

    model_schema, _, _ = model_process_schema(SpamEnum, model_name_map={})
    assert model_schema == {'title': 'SpamEnum', 'description': 'An enumeration.', 'type': 'string', 'enum': ['f', 'b']}


def test_path_modify_schema():
class MyPath(Path):
@classmethod
Expand Down

0 comments on commit 5195e55

Please sign in to comment.