Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for programmatic title generation #9183

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
4 changes: 4 additions & 0 deletions pydantic/_internal/_config.py
Expand Up @@ -57,6 +57,8 @@ class ConfigWrapper:
# to construct error `loc`s, default `True`
loc_by_alias: bool
alias_generator: Callable[[str], str] | AliasGenerator | None
class_title_generator: Callable[[str], str] | None
field_title_generator: Callable[[str], str] | None
ignored_types: tuple[type, ...]
allow_inf_nan: bool
json_schema_extra: JsonDict | JsonSchemaExtraCallable | None
Expand Down Expand Up @@ -240,6 +242,8 @@ def push(self, config_wrapper: ConfigWrapper | ConfigDict | None):
from_attributes=False,
loc_by_alias=True,
alias_generator=None,
class_title_generator=None,
field_title_generator=None,
ignored_types=(),
allow_inf_nan=True,
json_schema_extra=None,
Expand Down
81 changes: 73 additions & 8 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -102,7 +102,6 @@
ModifyCoreSchemaWrapHandler = GetCoreSchemaHandler
GetCoreSchemaFunction = Callable[[Any, ModifyCoreSchemaWrapHandler], core_schema.CoreSchema]


TUPLE_TYPES: list[type] = [tuple, typing.Tuple]
LIST_TYPES: list[type] = [list, typing.List, collections.abc.MutableSequence]
SET_TYPES: list[type] = [set, typing.Set, collections.abc.MutableSet]
Expand Down Expand Up @@ -202,19 +201,26 @@ def apply_each_item_validators(


def modify_model_json_schema(
schema_or_field: CoreSchemaOrField, handler: GetJsonSchemaHandler, *, cls: Any
schema_or_field: CoreSchemaOrField,
handler: GetJsonSchemaHandler,
*,
cls: Any,
title: str | None = None,
) -> JsonSchemaValue:
"""Add title and description for model-like classes' JSON schema.

Args:
schema_or_field: The schema data to generate a JSON schema from.
handler: The `GetCoreSchemaHandler` instance.
cls: The model-like class.
title: The title to set for the model's schema, defaults to the models name
NeevCohen marked this conversation as resolved.
Show resolved Hide resolved

Returns:
JsonSchemaValue: The updated JSON schema.
"""
from ..dataclasses import is_pydantic_dataclass
from ..main import BaseModel
from ._dataclasses import is_builtin_dataclass

json_schema = handler(schema_or_field)
original_schema = handler.resolve_ref_schema(json_schema)
Expand All @@ -223,10 +229,12 @@ def modify_model_json_schema(
ref = original_schema['$ref']
original_schema.clear()
original_schema['allOf'] = [{'$ref': ref}]
if 'title' not in original_schema:
if title is not None:
original_schema['title'] = title
elif 'title' not in original_schema:
original_schema['title'] = cls.__name__
# BaseModel; don't use cls.__doc__ as it will contain the verbose class signature by default
docstring = None if cls is BaseModel else cls.__doc__
# BaseModel + Dataclass; don't use cls.__doc__ as it will contain the verbose class signature by default
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

docstring = None if cls is BaseModel or is_builtin_dataclass(cls) or is_pydantic_dataclass(cls) else cls.__doc__
if docstring and 'description' not in original_schema:
original_schema['description'] = inspect.cleandoc(docstring)
return json_schema
Expand Down Expand Up @@ -527,7 +535,8 @@ def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema:
)
config_wrapper = ConfigWrapper(cls.model_config, check=False)
core_config = config_wrapper.core_config(cls)
metadata = build_metadata_dict(js_functions=[partial(modify_model_json_schema, cls=cls)])
title = self._get_class_title_from_config(cls, config_wrapper)
metadata = build_metadata_dict(js_functions=[partial(modify_model_json_schema, cls=cls, title=title)])

model_validators = decorators.model_validators.values()

Expand Down Expand Up @@ -604,6 +613,26 @@ def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema:
self.defs.definitions[model_ref] = schema
return core_schema.definition_reference_schema(model_ref)

@staticmethod
def _get_class_title_from_config(
cls: type[BaseModel | StandardDataclass], config_wrapper: ConfigWrapper | None = None
) -> str | None:
"""Get the title of a class if `class_title_generator` or `title` are set in the config, else return None"""
if config_wrapper is None:
return None

if config_wrapper.title:
return config_wrapper.title

class_title_generator = config_wrapper.class_title_generator
if class_title_generator:
title = class_title_generator(cls.__name__)
if not isinstance(title, str):
raise TypeError(f'class_title_generator {class_title_generator} must return str, not {title.__class__}')
return title

return None

def _unpack_refs_defs(self, schema: CoreSchema) -> CoreSchema:
"""Unpack all 'definitions' schemas into `GenerateSchema.defs.definitions`
and return the inner schema.
Expand Down Expand Up @@ -1034,6 +1063,23 @@ def _apply_alias_generator_to_computed_field_info(
if computed_field_info.alias_priority == 1:
computed_field_info.alias = _get_first_non_null(serialization_alias, alias)

@staticmethod
def _apply_field_title_generator_to_field_info(
field_title_generator: Callable[[str], str], field_info: FieldInfo | ComputedFieldInfo, field_name: str
) -> None:
"""Apply a field_title_generator on a FieldInfo or ComputedFieldInfo instance if appropriate
Args:
field_title_generator: A callable that takes a string and returns a string.
field_info: The FieldInfo or ComputedField instance to which the title_generator is (maybe) applied.
field_name: The name of the field from which to generate the title.
"""
if field_info.title_priority is None or field_info.title_priority <= 1 or field_info.title is None:
title = field_title_generator(field_name)
if not isinstance(title, str):
raise TypeError(f'field_title_generator {field_title_generator} must return str, not {title.__class__}')

field_info.title = title

def _common_field_schema( # C901
self, name: str, field_info: FieldInfo, decorators: DecoratorInfos
) -> _CommonField:
Expand Down Expand Up @@ -1105,6 +1151,10 @@ def set_discriminator(schema: CoreSchema) -> CoreSchema:
schema = self._apply_field_serializers(
schema, filter_field_decorator_info_by_field(decorators.field_serializers.values(), name)
)
field_title_generator = field_info.field_title_generator or self._config_wrapper.field_title_generator
if field_title_generator is not None:
self._apply_field_title_generator_to_field_info(field_title_generator, field_info, name)

NeevCohen marked this conversation as resolved.
Show resolved Hide resolved
json_schema_updates = {
'title': field_info.title,
'description': field_info.description,
Expand Down Expand Up @@ -1274,14 +1324,20 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co
and field_name in field_docstrings
):
field_info.description = field_docstrings[field_name]
field_title_generator = (
field_info.field_title_generator or self._config_wrapper.field_title_generator
)
if field_title_generator is not None:
self._apply_field_title_generator_to_field_info(field_title_generator, field_info, field_name)
fields[field_name] = self._generate_td_field_schema(
field_name, field_info, decorators, required=required
)

title = self._get_class_title_from_config(typed_dict_cls, self._config_wrapper)
metadata = build_metadata_dict(
js_functions=[partial(modify_model_json_schema, cls=typed_dict_cls)], typed_dict_cls=typed_dict_cls
js_functions=[partial(modify_model_json_schema, cls=typed_dict_cls, title=title)],
typed_dict_cls=typed_dict_cls,
)

td_schema = core_schema.typed_dict_schema(
fields,
computed_fields=[
Expand Down Expand Up @@ -1577,6 +1633,11 @@ def _dataclass_schema(
model_validators = decorators.model_validators.values()
inner_schema = apply_model_validators(inner_schema, model_validators, 'inner')

title = self._get_class_title_from_config(dataclass, self._config_wrapper)
metadata = build_metadata_dict(
js_functions=[partial(modify_model_json_schema, cls=dataclass, title=title)]
)

dc_schema = core_schema.dataclass_schema(
dataclass,
inner_schema,
Expand All @@ -1585,6 +1646,7 @@ def _dataclass_schema(
fields=[field.name for field in dataclasses.fields(dataclass)],
slots=has_slots,
config=core_config,
metadata=metadata,
)
schema = self._apply_model_serializers(dc_schema, decorators.model_serializers.values())
schema = apply_model_validators(schema, model_validators, 'outer')
Expand Down Expand Up @@ -1706,6 +1768,9 @@ def _computed_field_schema(
self._apply_alias_generator_to_computed_field_info(
alias_generator=alias_generator, computed_field_info=d.info, computed_field_name=d.cls_var_name
)
field_title_generator = d.info.field_title_generator or self._config_wrapper.field_title_generator
if field_title_generator is not None:
self._apply_field_title_generator_to_field_info(field_title_generator, d.info, d.cls_var_name)

def set_computed_field_metadata(schema: CoreSchemaOrField, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
json_schema = handler(schema)
Expand Down
7 changes: 7 additions & 0 deletions pydantic/config.py
Expand Up @@ -33,11 +33,18 @@ class ConfigDict(TypedDict, total=False):
title: str | None
"""The title for the generated JSON schema, defaults to the model's name"""

class_title_generator: Callable[[str], str] | None
"""A callable that takes a class name and returns the title for it. Defaults to `None`."""

field_title_generator: Callable[[str], str] | None
"""A callable that takes a field name and returns title for it. Defaults to `None`."""

str_to_lower: bool
"""Whether to convert all characters to lowercase for str types. Defaults to `False`."""

str_to_upper: bool
"""Whether to convert all characters to uppercase for str types. Defaults to `False`."""

str_strip_whitespace: bool
"""Whether to strip leading and trailing whitespace for str types."""

Expand Down
31 changes: 30 additions & 1 deletion pydantic/fields.py
Expand Up @@ -53,6 +53,8 @@ class _FromFieldInfoInputs(typing_extensions.TypedDict, total=False):
validation_alias: str | AliasPath | AliasChoices | None
serialization_alias: str | None
title: str | None
title_priority: int | None
field_title_generator: typing_extensions.Callable[[str], str] | None
description: str | None
examples: list[Any] | None
exclude: bool | None
Expand Down Expand Up @@ -105,6 +107,8 @@ class FieldInfo(_repr.Representation):
validation_alias: The validation alias of the field.
serialization_alias: The serialization alias of the field.
title: The title of the field.
title_priority: Priority of the field's title. This affects whether a title generator is used.
field_title_generator: A callable that takes a field name and returns title for it.
description: The description of the field.
examples: List of examples of the field.
exclude: Whether to exclude the field from the model serialization.
Expand All @@ -129,6 +133,8 @@ class FieldInfo(_repr.Representation):
validation_alias: str | AliasPath | AliasChoices | None
serialization_alias: str | None
title: str | None
title_priority: int | None
field_title_generator: typing.Callable[[str], str] | None
description: str | None
examples: list[Any] | None
exclude: bool | None
Expand All @@ -152,6 +158,8 @@ class FieldInfo(_repr.Representation):
'validation_alias',
'serialization_alias',
'title',
'title_priority',
'field_title_generator',
'description',
'examples',
'exclude',
Expand Down Expand Up @@ -213,6 +221,8 @@ def __init__(self, **kwargs: Unpack[_FieldInfoInputs]) -> None:
self.serialization_alias = kwargs.pop('serialization_alias', None)
alias_is_set = any(alias is not None for alias in (self.alias, self.validation_alias, self.serialization_alias))
self.alias_priority = kwargs.pop('alias_priority', None) or 2 if alias_is_set else None
self.field_title_generator = kwargs.pop('field_title_generator', None)
self.title_priority = kwargs.pop('title_priority', None) or 2 if self.title is not None else None
NeevCohen marked this conversation as resolved.
Show resolved Hide resolved
self.description = kwargs.pop('description', None)
self.examples = kwargs.pop('examples', None)
self.exclude = kwargs.pop('exclude', None)
Expand Down Expand Up @@ -633,6 +643,7 @@ class _EmptyKwargs(typing_extensions.TypedDict):
validation_alias=None,
serialization_alias=None,
title=None,
title_priority=None,
description=None,
examples=None,
exclude=None,
Expand Down Expand Up @@ -668,6 +679,8 @@ def Field( # noqa: C901
validation_alias: str | AliasPath | AliasChoices | None = _Unset,
serialization_alias: str | None = _Unset,
title: str | None = _Unset,
title_priority: int | None = _Unset,
field_title_generator: typing_extensions.Callable[[str], str] | None = _Unset,
description: str | None = _Unset,
examples: list[Any] | None = _Unset,
exclude: bool | None = _Unset,
Expand Down Expand Up @@ -714,6 +727,8 @@ def Field( # noqa: C901
validation_alias: Like `alias`, but only affects validation, not serialization.
serialization_alias: Like `alias`, but only affects serialization, not validation.
title: Human-readable title.
title_priority: Priority of the field's title. This affects whether a title generator is used.
field_title_generator: A callable that takes a field name and returns title for it.
description: Human-readable description.
examples: Example values for this field.
exclude: Whether to exclude the field from the model serialization.
Expand Down Expand Up @@ -830,6 +845,8 @@ def Field( # noqa: C901
validation_alias=validation_alias,
serialization_alias=serialization_alias,
title=title,
title_priority=title_priority,
field_title_generator=field_title_generator,
description=description,
examples=examples,
exclude=exclude,
Expand Down Expand Up @@ -969,6 +986,7 @@ class ComputedFieldInfo:
alias: The alias of the property to be used during serialization.
alias_priority: The priority of the alias. This affects whether an alias generator is used.
title: Title of the computed field to include in the serialization JSON schema.
title_priority: Priority of the title. This affects whether a title generator is used.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
title_priority: Priority of the title. This affects whether a title generator is used.
title_priority: Priority of the title. This affects whether a title generator is used.
field_title_generator: A callable that takes a field name and returns title for it.

description: Description of the computed field to include in the serialization JSON schema.
deprecated: A deprecation message, an instance of `warnings.deprecated` or the `typing_extensions.deprecated` backport,
or a boolean. If `True`, a default deprecation message will be emitted when accessing the field.
Expand All @@ -983,6 +1001,8 @@ class ComputedFieldInfo:
alias: str | None
alias_priority: int | None
title: str | None
title_priority: int | None
field_title_generator: typing.Callable[[str], str] | None
description: str | None
deprecated: Deprecated | str | bool | None
examples: list[Any] | None
Expand Down Expand Up @@ -1022,6 +1042,8 @@ def computed_field(
alias: str | None = None,
alias_priority: int | None = None,
title: str | None = None,
title_priority: int | None = None,
field_title_generator: typing.Callable[[str], str] | None = None,
description: str | None = None,
deprecated: Deprecated | str | bool | None = None,
examples: list[Any] | None = None,
Expand All @@ -1044,6 +1066,8 @@ def computed_field(
alias: str | None = None,
alias_priority: int | None = None,
title: str | None = None,
title_priority: int | None = None,
field_title_generator: typing.Callable[[str], str] | None = None,
description: str | None = None,
deprecated: Deprecated | str | bool | None = None,
examples: list[Any] | None = None,
Expand Down Expand Up @@ -1174,6 +1198,8 @@ def _private_property(self) -> int:
alias: alias to use when serializing this computed field, only used when `by_alias=True`
alias_priority: priority of the alias. This affects whether an alias generator is used
title: Title to use when including this computed field in JSON Schema
title_priority: Priority of the title. This affects whether a title generator is used.
field_title_generator: A callable that takes a field name and returns title for it.
description: Description to use when including this computed field in JSON Schema, defaults to the function's
docstring
deprecated: A deprecation message (or an instance of `warnings.deprecated` or the `typing_extensions.deprecated` backport).
Expand All @@ -1193,7 +1219,7 @@ def _private_property(self) -> int:
"""

def dec(f: Any) -> Any:
nonlocal description, deprecated, return_type, alias_priority
nonlocal description, deprecated, return_type, alias_priority, title_priority
unwrapped = _decorators.unwrap_wrapped_function(f)

if description is None and unwrapped.__doc__:
Expand All @@ -1205,6 +1231,7 @@ def dec(f: Any) -> Any:
# if the function isn't already decorated with `@property` (or another descriptor), then we wrap it now
f = _decorators.ensure_property(f)
alias_priority = (alias_priority or 2) if alias is not None else None
title_priority = (title_priority or 2) if title is not None else None

if repr is None:
repr_: bool = not _wrapped_property_is_private(property_=f)
Expand All @@ -1217,6 +1244,8 @@ def dec(f: Any) -> Any:
alias,
alias_priority,
title,
title_priority,
field_title_generator,
description,
deprecated,
examples,
Expand Down
6 changes: 4 additions & 2 deletions tests/test_json_schema.py
Expand Up @@ -232,8 +232,10 @@ class Model(BaseModel):

def test_schema_repr():
s = Field(4, title='Foo is Great')
assert str(s) == "annotation=NoneType required=False default=4 title='Foo is Great'"
assert repr(s) == "FieldInfo(annotation=NoneType, required=False, default=4, title='Foo is Great')"
assert str(s) == "annotation=NoneType required=False default=4 title='Foo is Great' title_priority=2"
assert (
repr(s) == "FieldInfo(annotation=NoneType, required=False, default=4, title='Foo is Great', title_priority=2)"
)


def test_schema_class_by_alias():
Expand Down