Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix inheritance of hash function for frozen models #6789

Merged
merged 4 commits into from Jul 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
28 changes: 22 additions & 6 deletions pydantic/_internal/_model_construction.py
Expand Up @@ -116,12 +116,8 @@ def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:
namespace['__class_vars__'] = class_vars
namespace['__private_attributes__'] = {**base_private_attributes, **private_attributes}

if '__hash__' not in namespace and config_wrapper.frozen:

def hash_func(self: Any) -> int:
return hash(self.__class__) + hash(tuple(self.__dict__.values()))

namespace['__hash__'] = hash_func
if config_wrapper.frozen:
set_default_hash_func(namespace, bases)

cls: type[BaseModel] = super().__new__(mcs, cls_name, bases, namespace, **kwargs) # type: ignore

Expand Down Expand Up @@ -359,6 +355,26 @@ def inspect_namespace( # noqa C901
return private_attributes


def set_default_hash_func(namespace: dict[str, Any], bases: tuple[type[Any], ...]) -> None:
if '__hash__' in namespace:
return

base_hash_func = None
for base in bases:
base_hash_func = getattr(base, '__hash__', PydanticUndefined)
if base_hash_func is not PydanticUndefined:
break

if base_hash_func is None:
# This will be the case for `BaseModel` since it defines `__eq__` but not `__hash__`.
# In this case, we generate a standard hash function, generally for use with frozen models.

def hash_func(self: Any) -> int:
return hash(self.__class__) + hash(tuple(self.__dict__.values()))

namespace['__hash__'] = hash_func


def set_model_fields(
cls: type[BaseModel], bases: tuple[type[Any], ...], config_wrapper: ConfigWrapper, types_namespace: dict[str, Any]
) -> None:
Expand Down
25 changes: 25 additions & 0 deletions tests/test_main.py
Expand Up @@ -574,6 +574,31 @@ class TestModel(BaseModel):
assert hash(m) != hash(m4)


def test_hash_method_is_inherited_for_frozen_models():
from functools import lru_cache

class MyBaseModel(BaseModel):
"""A base model with sensible configurations."""

model_config = ConfigDict(frozen=True)

def __hash__(self):
return hash(id(self))

class MySubClass(MyBaseModel):
x: Dict[str, int]

@lru_cache(maxsize=None)
def cached_method(self):
return len(self.x)

my_instance = MySubClass(x={'a': 1, 'b': 2})
assert my_instance.cached_method() == 2

object.__setattr__(my_instance, 'x', {}) # can't change the "normal" way due to frozen
assert my_instance.cached_method() == 2


@pytest.fixture(name='ValidateAssignmentModel', scope='session')
def validate_assignment_fixture():
class ValidateAssignmentModel(BaseModel):
Expand Down