From add958c16f62f1ef9babc8bd4bd6a093fd0cdcf4 Mon Sep 17 00:00:00 2001 From: Jaakko Moisio Date: Sun, 21 Nov 2021 21:52:13 +0200 Subject: [PATCH] Add optional `field` argument to `__modify_schema__()` --- changes/3434-jasujm.md | 2 ++ docs/usage/schema.md | 5 +++++ pydantic/schema.py | 28 ++++++++++++++++++++++------ tests/test_schema.py | 40 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 69 insertions(+), 6 deletions(-) create mode 100644 changes/3434-jasujm.md diff --git a/changes/3434-jasujm.md b/changes/3434-jasujm.md new file mode 100644 index 00000000000..bd93d666f04 --- /dev/null +++ b/changes/3434-jasujm.md @@ -0,0 +1,2 @@ +When generating field schema, pass optional `field` argument (of type +`pydantic.fields.ModelField`) to `__modify_schema__()` if present diff --git a/docs/usage/schema.md b/docs/usage/schema.md index 9c341961109..567e37d94ba 100644 --- a/docs/usage/schema.md +++ b/docs/usage/schema.md @@ -131,6 +131,11 @@ For versions of Python prior to 3.9, `typing_extensions.Annotated` can be used. Custom field types can customise the schema generated for them using the `__modify_schema__` class method; see [Custom Data Types](types.md#custom-data-types) for more details. +You can also add any of the following arguments to the signature to use them in the implementation: + +* `field`: the field whose schema is customized. Type of object is `pydantic.fields.ModelField`. +* `**kwargs`: if provided, the above argument will be provided in they `kwargs` dictionary + ## JSON Schema Types Types, custom field types, and constraints (like `max_length`) are mapped to the corresponding spec formats in the diff --git a/pydantic/schema.py b/pydantic/schema.py index 581b248d370..51bee5028f8 100644 --- a/pydantic/schema.py +++ b/pydantic/schema.py @@ -87,6 +87,19 @@ TypeModelSet = Set[TypeModelOrEnum] +def _apply_modify_schema( + modify_schema: Callable[..., Any], field: Optional[ModelField], field_schema: Dict[str, Any] +) -> None: + from inspect import signature + + sig = signature(modify_schema) + args = set(sig.parameters.keys()) + if 'field' in args or 'kwargs' in args: + modify_schema(field_schema, field=field) + else: + modify_schema(field_schema) + + def schema( models: Sequence[Union[Type['BaseModel'], Type['Dataclass']]], *, @@ -302,7 +315,7 @@ def get_field_schema_validations(field: ModelField) -> Dict[str, Any]: f_schema.update(field.field_info.extra) modify_schema = getattr(field.outer_type_, '__modify_schema__', None) if modify_schema: - modify_schema(f_schema) + _apply_modify_schema(modify_schema, field, f_schema) return f_schema @@ -530,7 +543,7 @@ def field_type_schema( field_type = field.outer_type_ modify_schema = getattr(field_type, '__modify_schema__', None) if modify_schema: - modify_schema(f_schema) + _apply_modify_schema(modify_schema, field, f_schema) return f_schema, definitions, nested_models @@ -542,6 +555,7 @@ def model_process_schema( ref_prefix: Optional[str] = None, ref_template: str = default_ref_template, known_models: TypeModelSet = None, + field: ModelField = None, ) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: """ Used by ``model_schema()``, you probably should be using that function. @@ -555,7 +569,7 @@ def model_process_schema( known_models = known_models or set() if lenient_issubclass(model, Enum): model = cast(Type[Enum], model) - s = enum_process_schema(model) + s = enum_process_schema(model, field=field) return s, {}, set() model = cast(Type['BaseModel'], model) s = {'title': model.__config__.title or model.__name__} @@ -637,7 +651,7 @@ def model_type_schema( return out_schema, definitions, nested_models -def enum_process_schema(enum: Type[Enum]) -> Dict[str, Any]: +def enum_process_schema(enum: Type[Enum], *, field: Optional[ModelField] = None) -> Dict[str, Any]: """ Take a single `enum` and generate its schema. @@ -658,7 +672,7 @@ def enum_process_schema(enum: Type[Enum]) -> Dict[str, Any]: modify_schema = getattr(enum, '__modify_schema__', None) if modify_schema: - modify_schema(schema_) + _apply_modify_schema(modify_schema, field, schema_) return schema_ @@ -834,7 +848,7 @@ def field_singleton_schema( # noqa: C901 (ignore complexity) enum_name = model_name_map[field_type] f_schema, schema_overrides = get_field_info_schema(field) f_schema.update(get_schema_ref(enum_name, ref_prefix, ref_template, schema_overrides)) - definitions[enum_name] = enum_process_schema(field_type) + definitions[enum_name] = enum_process_schema(field_type, field=field) elif is_namedtuple(field_type): sub_schema, *_ = model_process_schema( field_type.__pydantic_model__, @@ -843,6 +857,7 @@ def field_singleton_schema( # noqa: C901 (ignore complexity) ref_prefix=ref_prefix, ref_template=ref_template, known_models=known_models, + field=field, ) f_schema.update({'type': 'array', 'items': list(sub_schema['properties'].values())}) elif not hasattr(field_type, '__pydantic_model__'): @@ -869,6 +884,7 @@ def field_singleton_schema( # noqa: C901 (ignore complexity) ref_prefix=ref_prefix, ref_template=ref_template, known_models=known_models, + field=field, ) definitions.update(sub_definitions) definitions[model_name] = sub_schema diff --git a/tests/test_schema.py b/tests/test_schema.py index cd4e2e4f8d7..f87a40a6917 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -2562,3 +2562,43 @@ def resolve(self) -> 'Model': # noqa }, '$ref': '#/definitions/Model', } + + +@pytest.mark.skipif( + sys.version_info < (3, 7), reason='schema generation for generic fields is not available in python < 3.7' +) +def test_schema_for_generic_field_with_field_parameter(): + T = TypeVar('T') + + class GenModel(Generic[T]): + @classmethod + def __modify_schema__(cls, field_schema, field): + field_schema['title'] = f'GenModel with {field.sub_fields[0].type_.__name__}' + + @classmethod + def __get_validators__(cls): + yield cls.validate + + @classmethod + def validate(cls, v: Any): + return cls(*v) + + class Model1(BaseModel): + data: GenModel[str] + + assert Model1.schema() == { + 'properties': {'data': {'allOf': [{'type': 'string'}], 'title': 'GenModel with str'}}, + 'required': ['data'], + 'title': 'Model1', + 'type': 'object', + } + + class Model2(BaseModel): + data: GenModel[int] + + assert Model2.schema() == { + 'properties': {'data': {'allOf': [{'type': 'integer'}], 'title': 'GenModel with int'}}, + 'required': ['data'], + 'title': 'Model2', + 'type': 'object', + }