From 1ef95b64a0e68eda99ea9e72957420f5a45e40f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=B4mulo=20Collopy?= Date: Sun, 23 Jul 2023 22:40:08 +0200 Subject: [PATCH] adds model_name option in ConfigDict --- pydantic/_internal/_config.py | 2 ++ pydantic/_internal/_model_construction.py | 1 + pydantic/config.py | 1 + tests/test_config.py | 28 +++++++++++++++++++++++ 4 files changed, 32 insertions(+) diff --git a/pydantic/_internal/_config.py b/pydantic/_internal/_config.py index 4975431c297..9dcbf255508 100644 --- a/pydantic/_internal/_config.py +++ b/pydantic/_internal/_config.py @@ -28,6 +28,7 @@ class ConfigWrapper: # all annotations are copied directly from ConfigDict, and should be kept up to date, a test will fail if they # stop matching title: str | None + model_name: str | None str_to_lower: bool str_to_upper: bool str_strip_whitespace: bool @@ -165,6 +166,7 @@ def __repr__(self): config_defaults = ConfigDict( title=None, + model_name=None, str_to_lower=False, str_to_upper=False, str_strip_whitespace=False, diff --git a/pydantic/_internal/_model_construction.py b/pydantic/_internal/_model_construction.py index 65fe758540f..d9578404030 100644 --- a/pydantic/_internal/_model_construction.py +++ b/pydantic/_internal/_model_construction.py @@ -93,6 +93,7 @@ def __new__( base_field_names, class_vars, base_private_attributes = mcs._collect_bases_data(bases) config_wrapper = ConfigWrapper.for_model(bases, namespace, kwargs) + cls_name = config_wrapper.config_dict.get("model_name") or cls_name namespace['model_config'] = config_wrapper.config_dict private_attributes = inspect_namespace( namespace, config_wrapper.ignored_types, class_vars, base_field_names diff --git a/pydantic/config.py b/pydantic/config.py index 09229e213d6..48fcf43acb7 100644 --- a/pydantic/config.py +++ b/pydantic/config.py @@ -128,6 +128,7 @@ class without an annotation and has a type that is not in this tuple (or otherwi """ title: str | None + model_name: str | None str_to_lower: bool str_to_upper: bool str_strip_whitespace: bool diff --git a/tests/test_config.py b/tests/test_config.py index b85692783c2..f9fee06c353 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -23,6 +23,7 @@ from pydantic.config import ConfigDict from pydantic.dataclasses import dataclass as pydantic_dataclass from pydantic.errors import PydanticUserError +from pydantic.v1.schema import get_model_name_map if sys.version_info < (3, 9): from typing_extensions import Annotated @@ -484,6 +485,33 @@ def test_invalid_config_keys(): def my_function(): pass +def test_config_model_name() -> None: + CLIENT_USER_MODEL_NAME = "ClientUser" + BUSINESS_USER_MODEL_NAME = "BusinessUser" + + + def _get_business_user_class(): + class User(BaseModel): + model_config = ConfigDict(model_name=BUSINESS_USER_MODEL_NAME) + + return User + + def _get_client_user_class(): + class User(BaseModel): + model_config = ConfigDict(model_name=CLIENT_USER_MODEL_NAME) + + return User + + BusinessUser = _get_business_user_class() + ClientUser = _get_client_user_class() + + name_map = get_model_name_map({BusinessUser, ClientUser}) + assert name_map[BusinessUser] == BUSINESS_USER_MODEL_NAME + assert name_map[ClientUser] == CLIENT_USER_MODEL_NAME + + assert BusinessUser().model_json_schema()["title"] == BUSINESS_USER_MODEL_NAME + assert ClientUser().model_json_schema()["title"] == CLIENT_USER_MODEL_NAME + def test_multiple_inheritance_config(): class Parent(BaseModel):