diff --git a/changes/2422-PrettyWood.md b/changes/2422-PrettyWood.md new file mode 100644 index 00000000000..32e01e3446a --- /dev/null +++ b/changes/2422-PrettyWood.md @@ -0,0 +1 @@ +do not overwrite `__hash__` when explicitly declared diff --git a/pydantic/main.py b/pydantic/main.py index 6c818d38e22..5fd83f7bb9c 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -231,6 +231,7 @@ def __new__(mcs, name, bases, namespace, **kwargs): # noqa C901 slots: SetStr = namespace.get('__slots__', ()) slots = {slots} if isinstance(slots, str) else set(slots) class_vars: SetStr = set() + hash_func: Optional[Callable[[Any], int]] = None for base in reversed(bases): if _is_base_model_class_defined and issubclass(base, BaseModel) and base != BaseModel: @@ -241,6 +242,7 @@ def __new__(mcs, name, bases, namespace, **kwargs): # noqa C901 post_root_validators += base.__post_root_validators__ private_attributes.update(base.__private_attributes__) class_vars.update(base.__class_vars__) + hash_func = base.__hash__ config_kwargs = {key: kwargs.pop(key) for key in kwargs.keys() & BaseConfig.__dict__.keys()} config_from_namespace = namespace.get('Config') @@ -332,6 +334,9 @@ def is_untouched(v: Any) -> bool: json_encoder = pydantic_encoder pre_rv_new, post_rv_new = extract_root_validators(namespace) + if hash_func is None: + hash_func = generate_hash_function(config.frozen) + exclude_from_namespace = fields | private_attributes.keys() | {'__slots__'} new_namespace = { '__config__': config, @@ -344,7 +349,7 @@ def is_untouched(v: Any) -> bool: '__custom_root_type__': _custom_root_type, '__private_attributes__': private_attributes, '__slots__': slots | private_attributes.keys(), - '__hash__': generate_hash_function(config.frozen), + '__hash__': hash_func, '__class_vars__': class_vars, **{n: v for n, v in namespace.items() if n not in exclude_from_namespace}, } diff --git a/tests/test_main.py b/tests/test_main.py index e193772b1f1..30ac7d3daa6 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -395,6 +395,20 @@ class TestModel(BaseModel): assert "unhashable type: 'TestModel'" in exc_info.value.args[0] +def test_with_explicit_hash(): + class Foo(BaseModel): + x: int + + def __hash__(self): + return self.x ** 2 + + class Bar(Foo): + y: float + + assert hash(Foo(x=2)) == 4 + assert hash(Bar(x=2, y=1.1)) == 4 + + def test_frozen_with_hashable_fields_are_hashable(): class TestModel(BaseModel): a: int = 10