Skip to content

Commit

Permalink
Add optional field argument to __modify_schema__()
Browse files Browse the repository at this point in the history
  • Loading branch information
jasujm committed Nov 27, 2021
1 parent cc1cb48 commit add958c
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 6 deletions.
2 changes: 2 additions & 0 deletions changes/3434-jasujm.md
@@ -0,0 +1,2 @@
When generating field schema, pass optional `field` argument (of type
`pydantic.fields.ModelField`) to `__modify_schema__()` if present
5 changes: 5 additions & 0 deletions docs/usage/schema.md
Expand Up @@ -131,6 +131,11 @@ For versions of Python prior to 3.9, `typing_extensions.Annotated` can be used.
Custom field types can customise the schema generated for them using the `__modify_schema__` class method;
see [Custom Data Types](types.md#custom-data-types) for more details.

You can also add any of the following arguments to the signature to use them in the implementation:

* `field`: the field whose schema is customized. Type of object is `pydantic.fields.ModelField`.
* `**kwargs`: if provided, the above argument will be provided in they `kwargs` dictionary

## JSON Schema Types

Types, custom field types, and constraints (like `max_length`) are mapped to the corresponding spec formats in the
Expand Down
28 changes: 22 additions & 6 deletions pydantic/schema.py
Expand Up @@ -87,6 +87,19 @@
TypeModelSet = Set[TypeModelOrEnum]


def _apply_modify_schema(
modify_schema: Callable[..., Any], field: Optional[ModelField], field_schema: Dict[str, Any]
) -> None:
from inspect import signature

sig = signature(modify_schema)
args = set(sig.parameters.keys())
if 'field' in args or 'kwargs' in args:
modify_schema(field_schema, field=field)
else:
modify_schema(field_schema)


def schema(
models: Sequence[Union[Type['BaseModel'], Type['Dataclass']]],
*,
Expand Down Expand Up @@ -302,7 +315,7 @@ def get_field_schema_validations(field: ModelField) -> Dict[str, Any]:
f_schema.update(field.field_info.extra)
modify_schema = getattr(field.outer_type_, '__modify_schema__', None)
if modify_schema:
modify_schema(f_schema)
_apply_modify_schema(modify_schema, field, f_schema)
return f_schema


Expand Down Expand Up @@ -530,7 +543,7 @@ def field_type_schema(
field_type = field.outer_type_
modify_schema = getattr(field_type, '__modify_schema__', None)
if modify_schema:
modify_schema(f_schema)
_apply_modify_schema(modify_schema, field, f_schema)
return f_schema, definitions, nested_models


Expand All @@ -542,6 +555,7 @@ def model_process_schema(
ref_prefix: Optional[str] = None,
ref_template: str = default_ref_template,
known_models: TypeModelSet = None,
field: ModelField = None,
) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]:
"""
Used by ``model_schema()``, you probably should be using that function.
Expand All @@ -555,7 +569,7 @@ def model_process_schema(
known_models = known_models or set()
if lenient_issubclass(model, Enum):
model = cast(Type[Enum], model)
s = enum_process_schema(model)
s = enum_process_schema(model, field=field)
return s, {}, set()
model = cast(Type['BaseModel'], model)
s = {'title': model.__config__.title or model.__name__}
Expand Down Expand Up @@ -637,7 +651,7 @@ def model_type_schema(
return out_schema, definitions, nested_models


def enum_process_schema(enum: Type[Enum]) -> Dict[str, Any]:
def enum_process_schema(enum: Type[Enum], *, field: Optional[ModelField] = None) -> Dict[str, Any]:
"""
Take a single `enum` and generate its schema.
Expand All @@ -658,7 +672,7 @@ def enum_process_schema(enum: Type[Enum]) -> Dict[str, Any]:

modify_schema = getattr(enum, '__modify_schema__', None)
if modify_schema:
modify_schema(schema_)
_apply_modify_schema(modify_schema, field, schema_)

return schema_

Expand Down Expand Up @@ -834,7 +848,7 @@ def field_singleton_schema( # noqa: C901 (ignore complexity)
enum_name = model_name_map[field_type]
f_schema, schema_overrides = get_field_info_schema(field)
f_schema.update(get_schema_ref(enum_name, ref_prefix, ref_template, schema_overrides))
definitions[enum_name] = enum_process_schema(field_type)
definitions[enum_name] = enum_process_schema(field_type, field=field)
elif is_namedtuple(field_type):
sub_schema, *_ = model_process_schema(
field_type.__pydantic_model__,
Expand All @@ -843,6 +857,7 @@ def field_singleton_schema( # noqa: C901 (ignore complexity)
ref_prefix=ref_prefix,
ref_template=ref_template,
known_models=known_models,
field=field,
)
f_schema.update({'type': 'array', 'items': list(sub_schema['properties'].values())})
elif not hasattr(field_type, '__pydantic_model__'):
Expand All @@ -869,6 +884,7 @@ def field_singleton_schema( # noqa: C901 (ignore complexity)
ref_prefix=ref_prefix,
ref_template=ref_template,
known_models=known_models,
field=field,
)
definitions.update(sub_definitions)
definitions[model_name] = sub_schema
Expand Down
40 changes: 40 additions & 0 deletions tests/test_schema.py
Expand Up @@ -2562,3 +2562,43 @@ def resolve(self) -> 'Model': # noqa
},
'$ref': '#/definitions/Model',
}


@pytest.mark.skipif(
sys.version_info < (3, 7), reason='schema generation for generic fields is not available in python < 3.7'
)
def test_schema_for_generic_field_with_field_parameter():
T = TypeVar('T')

class GenModel(Generic[T]):
@classmethod
def __modify_schema__(cls, field_schema, field):
field_schema['title'] = f'GenModel with {field.sub_fields[0].type_.__name__}'

@classmethod
def __get_validators__(cls):
yield cls.validate

@classmethod
def validate(cls, v: Any):
return cls(*v)

class Model1(BaseModel):
data: GenModel[str]

assert Model1.schema() == {
'properties': {'data': {'allOf': [{'type': 'string'}], 'title': 'GenModel with str'}},
'required': ['data'],
'title': 'Model1',
'type': 'object',
}

class Model2(BaseModel):
data: GenModel[int]

assert Model2.schema() == {
'properties': {'data': {'allOf': [{'type': 'integer'}], 'title': 'GenModel with int'}},
'required': ['data'],
'title': 'Model2',
'type': 'object',
}

0 comments on commit add958c

Please sign in to comment.