Skip to content

Commit

Permalink
Add optional field argument to __modify_schema__() (#3434)
Browse files Browse the repository at this point in the history
Co-authored-by: Samuel Colvin <s@muelcolvin.com>
  • Loading branch information
jasujm and samuelcolvin committed Dec 18, 2021
1 parent f36040a commit 63337fb
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 7 deletions.
2 changes: 2 additions & 0 deletions changes/3434-jasujm.md
@@ -0,0 +1,2 @@
When generating field schema, pass optional `field` argument (of type
`pydantic.fields.ModelField`) to `__modify_schema__()` if present
31 changes: 31 additions & 0 deletions docs/examples/schema_with_field.py
@@ -0,0 +1,31 @@
# output-json
from typing import Optional

from pydantic import BaseModel, Field
from pydantic.fields import ModelField


class RestrictedAlphabetStr(str):
@classmethod
def __get_validators__(cls):
yield cls.validate

@classmethod
def validate(cls, value, field: ModelField):
alphabet = field.field_info.extra['alphabet']
if any(c not in alphabet for c in value):
raise ValueError(f'{value!r} is not restricted to {alphabet!r}')
return cls(value)

@classmethod
def __modify_schema__(cls, field_schema, field: Optional[ModelField]):
if field:
alphabet = field.field_info.extra['alphabet']
field_schema['examples'] = [c * 3 for c in alphabet]


class MyModel(BaseModel):
value: RestrictedAlphabetStr = Field(alphabet='ABC')


print(MyModel.schema_json(indent=2))
15 changes: 15 additions & 0 deletions docs/usage/schema.md
Expand Up @@ -150,6 +150,21 @@ For versions of Python prior to 3.9, `typing_extensions.Annotated` can be used.
Custom field types can customise the schema generated for them using the `__modify_schema__` class method;
see [Custom Data Types](types.md#custom-data-types) for more details.

`__modify_schema__` can also take a `field` argument which will have type `Optional[ModelField]`.
*pydantic* will inspect the signature of `__modify_schema__` to determine whether the `field` argument should be
included.

```py
{!.tmp_examples/schema_with_field.py!}
```
_(This script is complete, it should run "as is")_

Outputs:

```json
{!.tmp_examples/schema_with_field.json!}
```

## JSON Schema Types

Types, custom field types, and constraints (like `max_length`) are mapped to the corresponding spec formats in the
Expand Down
30 changes: 23 additions & 7 deletions pydantic/schema.py
Expand Up @@ -90,6 +90,19 @@
TypeModelSet = Set[TypeModelOrEnum]


def _apply_modify_schema(
modify_schema: Callable[..., None], field: Optional[ModelField], field_schema: Dict[str, Any]
) -> None:
from inspect import signature

sig = signature(modify_schema)
args = set(sig.parameters.keys())
if 'field' in args or 'kwargs' in args:
modify_schema(field_schema, field=field)
else:
modify_schema(field_schema)


def schema(
models: Sequence[Union[Type['BaseModel'], Type['Dataclass']]],
*,
Expand Down Expand Up @@ -335,7 +348,7 @@ def get_field_schema_validations(field: ModelField) -> Dict[str, Any]:
f_schema.update(field.field_info.extra)
modify_schema = getattr(field.outer_type_, '__modify_schema__', None)
if modify_schema:
modify_schema(f_schema)
_apply_modify_schema(modify_schema, field, f_schema)
return f_schema


Expand Down Expand Up @@ -567,7 +580,7 @@ def field_type_schema(
field_type = field.outer_type_
modify_schema = getattr(field_type, '__modify_schema__', None)
if modify_schema:
modify_schema(f_schema)
_apply_modify_schema(modify_schema, field, f_schema)
return f_schema, definitions, nested_models


Expand All @@ -579,6 +592,7 @@ def model_process_schema(
ref_prefix: Optional[str] = None,
ref_template: str = default_ref_template,
known_models: TypeModelSet = None,
field: Optional[ModelField] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]:
"""
Used by ``model_schema()``, you probably should be using that function.
Expand All @@ -592,7 +606,7 @@ def model_process_schema(
known_models = known_models or set()
if lenient_issubclass(model, Enum):
model = cast(Type[Enum], model)
s = enum_process_schema(model)
s = enum_process_schema(model, field=field)
return s, {}, set()
model = cast(Type['BaseModel'], model)
s = {'title': model.__config__.title or model.__name__}
Expand Down Expand Up @@ -674,7 +688,7 @@ def model_type_schema(
return out_schema, definitions, nested_models


def enum_process_schema(enum: Type[Enum]) -> Dict[str, Any]:
def enum_process_schema(enum: Type[Enum], *, field: Optional[ModelField] = None) -> Dict[str, Any]:
"""
Take a single `enum` and generate its schema.
Expand All @@ -695,7 +709,7 @@ def enum_process_schema(enum: Type[Enum]) -> Dict[str, Any]:

modify_schema = getattr(enum, '__modify_schema__', None)
if modify_schema:
modify_schema(schema_)
_apply_modify_schema(modify_schema, field, schema_)

return schema_

Expand Down Expand Up @@ -871,7 +885,7 @@ def field_singleton_schema( # noqa: C901 (ignore complexity)
enum_name = model_name_map[field_type]
f_schema, schema_overrides = get_field_info_schema(field, schema_overrides)
f_schema.update(get_schema_ref(enum_name, ref_prefix, ref_template, schema_overrides))
definitions[enum_name] = enum_process_schema(field_type)
definitions[enum_name] = enum_process_schema(field_type, field=field)
elif is_namedtuple(field_type):
sub_schema, *_ = model_process_schema(
field_type.__pydantic_model__,
Expand All @@ -880,6 +894,7 @@ def field_singleton_schema( # noqa: C901 (ignore complexity)
ref_prefix=ref_prefix,
ref_template=ref_template,
known_models=known_models,
field=field,
)
items_schemas = list(sub_schema['properties'].values())
f_schema.update(
Expand All @@ -895,7 +910,7 @@ def field_singleton_schema( # noqa: C901 (ignore complexity)

modify_schema = getattr(field_type, '__modify_schema__', None)
if modify_schema:
modify_schema(f_schema)
_apply_modify_schema(modify_schema, field, f_schema)

if f_schema:
return f_schema, definitions, nested_models
Expand All @@ -914,6 +929,7 @@ def field_singleton_schema( # noqa: C901 (ignore complexity)
ref_prefix=ref_prefix,
ref_template=ref_template,
known_models=known_models,
field=field,
)
definitions.update(sub_definitions)
definitions[model_name] = sub_schema
Expand Down
22 changes: 22 additions & 0 deletions tests/test_schema.py
Expand Up @@ -32,6 +32,7 @@
from pydantic import BaseModel, Extra, Field, ValidationError, confrozenset, conlist, conset, validator
from pydantic.color import Color
from pydantic.dataclasses import dataclass
from pydantic.fields import ModelField
from pydantic.generics import GenericModel
from pydantic.networks import AnyUrl, EmailStr, IPvAnyAddress, IPvAnyInterface, IPvAnyNetwork, NameEmail, stricturl
from pydantic.schema import (
Expand Down Expand Up @@ -2628,6 +2629,27 @@ def resolve(self) -> 'Model': # noqa
}


def test_schema_with_field_parameter():
class RestrictedAlphabetStr(str):
@classmethod
def __modify_schema__(cls, field_schema, field: Optional[ModelField]):
assert isinstance(field, ModelField)
alphabet = field.field_info.extra['alphabet']
field_schema['examples'] = [c * 3 for c in alphabet]

class MyModel(BaseModel):
value: RestrictedAlphabetStr = Field(alphabet='ABC')

assert MyModel.schema() == {
'title': 'MyModel',
'type': 'object',
'properties': {
'value': {'title': 'Value', 'alphabet': 'ABC', 'examples': ['AAA', 'BBB', 'CCC'], 'type': 'string'}
},
'required': ['value'],
}


def test_discriminated_union():
class BlackCat(BaseModel):
pet_type: Literal['cat']
Expand Down

0 comments on commit 63337fb

Please sign in to comment.