diff --git a/changes/1173-calvinwyoung.md b/changes/1173-calvinwyoung.md
new file mode 100644
index 0000000000..c39a5f3359
--- /dev/null
+++ b/changes/1173-calvinwyoung.md
@@ -0,0 +1,3 @@
+Updates OpenAPI schema generation to output all enums as separate models.
+Instead of inlining the enum values in the model schema, models now use a `$ref`
+property to point to the enum definition.
\ No newline at end of file
diff --git a/pydantic/schema.py b/pydantic/schema.py
index 8157bacff7..3ca516a906 100644
--- a/pydantic/schema.py
+++ b/pydantic/schema.py
@@ -11,6 +11,7 @@
     Callable,
     Dict,
     FrozenSet,
+    Iterable,
     List,
     Optional,
     Sequence,
@@ -62,6 +63,10 @@
 default_prefix = '#/definitions/'
 
 
+TypeModelOrEnum = Union[Type['BaseModel'], Type[Enum]]
+TypeModelSet = Set[TypeModelOrEnum]
+
+
 def schema(
     models: Sequence[Union[Type['BaseModel'], Type['DataclassType']]],
     *,
@@ -145,9 +150,9 @@ def field_schema(
     field: ModelField,
     *,
     by_alias: bool = True,
-    model_name_map: Dict[Type['BaseModel'], str],
+    model_name_map: Dict[TypeModelOrEnum, str],
     ref_prefix: Optional[str] = None,
-    known_models: Set[Type['BaseModel']] = None,
+    known_models: TypeModelSet = None,
 ) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]:
     """
     Process a Pydantic field and return a tuple with a JSON Schema for it as the first item.
@@ -240,7 +245,7 @@ def get_field_schema_validations(field: ModelField) -> Dict[str, Any]:
     return f_schema
 
 
-def get_model_name_map(unique_models: Set[Type['BaseModel']]) -> Dict[Type['BaseModel'], str]:
+def get_model_name_map(unique_models: TypeModelSet) -> Dict[TypeModelOrEnum, str]:
     """
     Process a set of models and generate unique names for them to be used as keys in the JSON Schema
     definitions. By default the names are the same as the class name. But if two models in different Python
@@ -253,8 +258,7 @@ def get_model_name_map(unique_models: Set[Type['Base
     name_model_map = {}
     conflicting_names: Set[str] = set()
     for model in unique_models:
-        model_name = model.__name__
-        model_name = re.sub(r'[^a-zA-Z0-9.\-_]', '_', model_name)
+        model_name = normalize_name(model.__name__)
         if model_name in conflicting_names:
             model_name = get_long_model_name(model)
             name_model_map[model_name] = model
@@ -268,9 +272,7 @@ def get_model_name_map(unique_models: Set[Type['Base
     return {v: k for k, v in name_model_map.items()}
 
 
-def get_flat_models_from_model(
-    model: Type['BaseModel'], known_models: Set[Type['BaseModel']] = None
-) -> Set[Type['BaseModel']]:
+def get_flat_models_from_model(model: Type['BaseModel'], known_models: TypeModelSet = None) -> TypeModelSet:
     """
     Take a single ``model`` and generate a set with itself and all the sub-models in the tree. I.e.
 if you pass model ``Foo`` (subclass of Pydantic ``BaseModel``) as ``model``, and it has a field of type ``Bar`` (also
@@ -282,7 +284,7 @@ def get_flat_models_from_model(
     :return: a set with the initial model and all its sub-models
     """
     known_models = known_models or set()
-    flat_models: Set[Type['BaseModel']] = set()
+    flat_models: TypeModelSet = set()
     flat_models.add(model)
     known_models |= flat_models
     fields = cast(Sequence[ModelField], model.__fields__.values())
@@ -290,7 +292,7 @@ def get_flat_models_from_model(
     return flat_models
 
 
-def get_flat_models_from_field(field: ModelField, known_models: Set[Type['BaseModel']]) -> Set[Type['BaseModel']]:
+def get_flat_models_from_field(field: ModelField, known_models: TypeModelSet) -> TypeModelSet:
     """
     Take a single Pydantic ``ModelField`` (from a model) that could have been declared as a subclass of BaseModel
     (so, it could be a submodel), and generate a set with its model and all the sub-models in the tree.
@@ -304,7 +306,7 @@ def get_flat_models_from_field(field: ModelField, known_models: Set[Type['BaseMo
     """
     from .main import BaseModel  # noqa: F811
 
-    flat_models: Set[Type[BaseModel]] = set()
+    flat_models: TypeModelSet = set()
     # Handle dataclass-based models
     field_type = field.type_
     if lenient_issubclass(getattr(field_type, '__pydantic_model__', None), BaseModel):
@@ -313,12 +315,12 @@ def get_flat_models_from_field(field: ModelField, known_models: Set[Type['BaseMo
         flat_models |= get_flat_models_from_fields(field.sub_fields, known_models=known_models)
     elif lenient_issubclass(field_type, BaseModel) and field_type not in known_models:
         flat_models |= get_flat_models_from_model(field_type, known_models=known_models)
+    elif lenient_issubclass(field_type, Enum):
+        flat_models.add(field_type)
     return flat_models
 
 
-def get_flat_models_from_fields(
-    fields: Sequence[ModelField], known_models: Set[Type['BaseModel']]
-) -> Set[Type['BaseModel']]:
+def get_flat_models_from_fields(fields: Sequence[ModelField], known_models: TypeModelSet) -> TypeModelSet:
     """
     Take a list of Pydantic ``ModelField``s (from a model) that could have been declared as subclasses of ``BaseModel``
     (so, any of them could be a submodel), and generate a set with their models and all the sub-models in the tree.
@@ -330,25 +332,25 @@ def get_flat_models_from_fields(
     :param known_models: used to solve circular references
     :return: a set with any model declared in the fields, and all their sub-models
     """
-    flat_models: Set[Type['BaseModel']] = set()
+    flat_models: TypeModelSet = set()
     for field in fields:
         flat_models |= get_flat_models_from_field(field, known_models=known_models)
     return flat_models
 
 
-def get_flat_models_from_models(models: Sequence[Type['BaseModel']]) -> Set[Type['BaseModel']]:
+def get_flat_models_from_models(models: Sequence[Type['BaseModel']]) -> TypeModelSet:
     """
     Take a list of ``models`` and generate a set with them and all their sub-models in their trees. I.e. if you pass
     a list of two models, ``Foo`` and ``Bar``, both subclasses of Pydantic ``BaseModel`` as models, and ``Bar`` has
     a field of type ``Baz`` (also subclass of ``BaseModel``), the return value will be ``set([Foo, Bar, Baz])``.
""" - flat_models: Set[Type['BaseModel']] = set() + flat_models: TypeModelSet = set() for model in models: flat_models |= get_flat_models_from_model(model) return flat_models -def get_long_model_name(model: Type['BaseModel']) -> str: +def get_long_model_name(model: TypeModelOrEnum) -> str: return f'{model.__module__}__{model.__name__}'.replace('.', '__') @@ -356,10 +358,10 @@ def field_type_schema( field: ModelField, *, by_alias: bool, - model_name_map: Dict[Type['BaseModel'], str], + model_name_map: Dict[TypeModelOrEnum, str], schema_overrides: bool = False, ref_prefix: Optional[str] = None, - known_models: Set[Type['BaseModel']], + known_models: TypeModelSet, ) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: """ Used by ``field_schema()``, you probably should be using that function. @@ -432,12 +434,12 @@ def field_type_schema( def model_process_schema( - model: Type['BaseModel'], + model: TypeModelOrEnum, *, by_alias: bool = True, - model_name_map: Dict[Type['BaseModel'], str], + model_name_map: Dict[TypeModelOrEnum, str], ref_prefix: Optional[str] = None, - known_models: Set[Type['BaseModel']] = None, + known_models: TypeModelSet = None, ) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: """ Used by ``model_schema()``, you probably should be using that function. @@ -450,6 +452,11 @@ def model_process_schema( ref_prefix = ref_prefix or default_prefix known_models = known_models or set() + if lenient_issubclass(model, Enum): + model = cast(Type[Enum], model) + s = enum_process_schema(model) + return s, {}, set() + model = cast(Type['BaseModel'], model) s = {'title': model.__config__.title or model.__name__} doc = getdoc(model) if doc: @@ -474,9 +481,9 @@ def model_type_schema( model: Type['BaseModel'], *, by_alias: bool, - model_name_map: Dict[Type['BaseModel'], str], + model_name_map: Dict[TypeModelOrEnum, str], ref_prefix: Optional[str] = None, - known_models: Set[Type['BaseModel']], + known_models: TypeModelSet, ) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: """ You probably should be using ``model_schema()``, this function is indirectly used by that function. @@ -519,14 +526,36 @@ def model_type_schema( return out_schema, definitions, nested_models +def enum_process_schema(enum: Type[Enum]) -> Dict[str, Any]: + """ + Take a single `enum` and generate its schema. + + This is similar to the `model_process_schema` function, but applies to ``Enum`` objects. + """ + from inspect import getdoc + + schema: Dict[str, Any] = { + 'title': enum.__name__, + # Python assigns all enums a default docstring value of 'An enumeration', so + # all enums will have a description field even if not explicitly provided. + 'description': getdoc(enum), + # Add enum values and the enum field type to the schema. + 'enum': [item.value for item in cast(Iterable[Enum], enum)], + } + + add_field_type_to_schema(enum, schema) + + return schema + + def field_singleton_sub_fields_schema( sub_fields: Sequence[ModelField], *, by_alias: bool, - model_name_map: Dict[Type['BaseModel'], str], + model_name_map: Dict[TypeModelOrEnum, str], schema_overrides: bool = False, ref_prefix: Optional[str] = None, - known_models: Set[Type['BaseModel']], + known_models: TypeModelSet, ) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: """ This function is indirectly used by ``field_schema()``, you probably should be using that function. 
@@ -599,14 +628,27 @@ def field_singleton_sub_fields_schema(
 json_scheme = {'type': 'string', 'format': 'json-string'}
 
 
+def add_field_type_to_schema(field_type: Any, schema: Dict[str, Any]) -> None:
+    """
+    Update the given `schema` with the type-specific metadata for the given `field_type`.
+
+    This function looks through `field_class_to_schema` for a class that matches the given `field_type`,
+    and then modifies the given `schema` with the information from that type.
+    """
+    for type_, t_schema in field_class_to_schema:
+        if issubclass(field_type, type_):
+            schema.update(t_schema)
+            break
+
+
 def field_singleton_schema(  # noqa: C901 (ignore complexity)
     field: ModelField,
     *,
     by_alias: bool,
-    model_name_map: Dict[Type['BaseModel'], str],
+    model_name_map: Dict[TypeModelOrEnum, str],
     schema_overrides: bool = False,
     ref_prefix: Optional[str] = None,
-    known_models: Set[Type['BaseModel']],
+    known_models: TypeModelSet,
 ) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]:
     """
     This function is indirectly used by ``field_schema()``, you should probably be using that function.
@@ -649,14 +691,12 @@ def field_singleton_schema(  # noqa: C901 (ignore complexity)
             field_type = literal_value.__class__
             f_schema['const'] = literal_value
 
-    if issubclass(field_type, Enum):
-        f_schema.update({'enum': [item.value for item in field_type]})
-        # Don't return immediately, to allow adding specific types
-
-    for type_, t_schema in field_class_to_schema:
-        if issubclass(field_type, type_):
-            f_schema.update(t_schema)
-            break
+    if lenient_issubclass(field_type, Enum):
+        enum_name = normalize_name(field_type.__name__)
+        f_schema = {'$ref': ref_prefix + enum_name}
+        definitions[enum_name] = enum_process_schema(field_type)
+    else:
+        add_field_type_to_schema(field_type, f_schema)
 
     modify_schema = getattr(field_type, '__modify_schema__', None)
     if modify_schema:
@@ -815,6 +855,13 @@ def go(type_: Any) -> Type[Any]:
     return ans
 
 
+def normalize_name(name: str) -> str:
+    """
+    Normalizes the given name. This can be applied to either a model *or* enum.
+    """
+    return re.sub(r'[^a-zA-Z0-9.\-_]', '_', name)
+
+
 class SkipField(Exception):
     """
     Utility exception used to exclude fields from schema.
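# NOTE (illustration, not part of the patch): a minimal sketch of what the new
# enum_process_schema() helper above returns once this change is applied. The
# `Tool` enum and its docstring are invented purely for the example.
from enum import Enum

from pydantic.schema import enum_process_schema


class Tool(str, Enum):
    """Supported tools."""

    spanner = 'sp'
    wrench = 'wr'


print(enum_process_schema(Tool))
# -> {'title': 'Tool', 'description': 'Supported tools.', 'enum': ['sp', 'wr'], 'type': 'string'}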
diff --git a/tests/test_schema.py b/tests/test_schema.py
index 559844209e..ce447d5ec7 100644
--- a/tests/test_schema.py
+++ b/tests/test_schema.py
@@ -20,6 +20,7 @@
     get_flat_models_from_model,
     get_flat_models_from_models,
     get_model_name_map,
+    model_process_schema,
     model_schema,
     schema,
 )
@@ -205,14 +206,19 @@ class Model(BaseModel):
         spam: SpamEnum = Field(None)
 
     assert Model.schema() == {
-        'type': 'object',
         'title': 'Model',
+        'type': 'object',
         'properties': {
-            'foo': {'title': 'Foo', 'enum': ['f', 'b']},
-            'bar': {'type': 'integer', 'title': 'Bar', 'enum': [1, 2]},
-            'spam': {'type': 'string', 'title': 'Spam', 'enum': ['f', 'b']},
+            'foo': {'$ref': '#/definitions/FooEnum'},
+            'bar': {'$ref': '#/definitions/BarEnum'},
+            'spam': {'$ref': '#/definitions/SpamEnum'},
         },
         'required': ['foo', 'bar'],
+        'definitions': {
+            'FooEnum': {'title': 'FooEnum', 'description': 'An enumeration.', 'enum': ['f', 'b']},
+            'BarEnum': {'title': 'BarEnum', 'description': 'An enumeration.', 'type': 'integer', 'enum': [1, 2]},
+            'SpamEnum': {'title': 'SpamEnum', 'description': 'An enumeration.', 'type': 'string', 'enum': ['f', 'b']},
+        },
     }
 
 
@@ -1769,6 +1775,8 @@ class Model:
 
 def test_schema_attributes():
     class ExampleEnum(Enum):
+        """This is a test description."""
+
         gt = 'GT'
         lt = 'LT'
         ge = 'GE'
@@ -1783,11 +1791,27 @@ class Example(BaseModel):
     assert Example.schema() == {
         'title': 'Example',
         'type': 'object',
-        'properties': {'example': {'title': 'Example', 'enum': ['GT', 'LT', 'GE', 'LE', 'ML', 'MO', 'RE']}},
+        'properties': {'example': {'$ref': '#/definitions/ExampleEnum'}},
         'required': ['example'],
+        'definitions': {
+            'ExampleEnum': {
+                'title': 'ExampleEnum',
+                'description': 'This is a test description.',
+                'enum': ['GT', 'LT', 'GE', 'LE', 'ML', 'MO', 'RE'],
+            }
+        },
     }
 
 
+def test_model_process_schema_enum():
+    class SpamEnum(str, Enum):
+        foo = 'f'
+        bar = 'b'
+
+    model_schema, _, _ = model_process_schema(SpamEnum, model_name_map={})
+    assert model_schema == {'title': 'SpamEnum', 'description': 'An enumeration.', 'type': 'string', 'enum': ['f', 'b']}
+
+
 def test_path_modify_schema():
     class MyPath(Path):
         @classmethod
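# NOTE (illustration, not part of the patch): end-to-end behaviour mirroring the
# updated tests above. With this change a model field typed as an enum is emitted
# as a $ref into `definitions` instead of an inlined enum list. The `Priority`
# and `Ticket` names below are invented for the example.
from enum import Enum

from pydantic import BaseModel


class Priority(Enum):
    low = 1
    high = 2


class Ticket(BaseModel):
    priority: Priority


print(Ticket.schema())
# -> {
#     'title': 'Ticket',
#     'type': 'object',
#     'properties': {'priority': {'$ref': '#/definitions/Priority'}},
#     'required': ['priority'],
#     'definitions': {
#         'Priority': {'title': 'Priority', 'description': 'An enumeration.', 'enum': [1, 2]}
#     },
# }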