diff --git a/changelog/9820.bugfix.rst b/changelog/9820.bugfix.rst new file mode 100644 index 00000000000..03af940f2c4 --- /dev/null +++ b/changelog/9820.bugfix.rst @@ -0,0 +1 @@ +Fix comparison of ``dataclasses`` with ``InitVar``. diff --git a/src/_pytest/assertion/util.py b/src/_pytest/assertion/util.py index 03167ddd471..b1f168767be 100644 --- a/src/_pytest/assertion/util.py +++ b/src/_pytest/assertion/util.py @@ -437,8 +437,10 @@ def _compare_eq_cls(left: Any, right: Any, verbose: int) -> List[str]: if not has_default_eq(left): return [] if isdatacls(left): - all_fields = left.__dataclass_fields__ - fields_to_check = [field for field, info in all_fields.items() if info.compare] + import dataclasses + + all_fields = dataclasses.fields(left) + fields_to_check = [info.name for info in all_fields if info.compare] elif isattrs(left): all_fields = left.__attrs_attrs__ fields_to_check = [field.name for field in all_fields if getattr(field, "eq")] diff --git a/testing/example_scripts/dataclasses/test_compare_initvar.py b/testing/example_scripts/dataclasses/test_compare_initvar.py new file mode 100644 index 00000000000..d859634ddd5 --- /dev/null +++ b/testing/example_scripts/dataclasses/test_compare_initvar.py @@ -0,0 +1,12 @@ +from dataclasses import dataclass +from dataclasses import InitVar + + +@dataclass +class Foo: + init_only: InitVar[int] + real_attr: int + + +def test_demonstrate(): + assert Foo(1, 2) == Foo(1, 3) diff --git a/testing/test_assertion.py b/testing/test_assertion.py index d37ee72a2ec..40a3e0021d0 100644 --- a/testing/test_assertion.py +++ b/testing/test_assertion.py @@ -882,6 +882,13 @@ def test_data_classes_with_custom_eq(self, pytester: Pytester) -> None: result.assert_outcomes(failed=1, passed=0) result.stdout.no_re_match_line(".*Differing attributes.*") + def test_data_classes_with_initvar(self, pytester: Pytester) -> None: + p = pytester.copy_example("dataclasses/test_compare_initvar.py") + # issue 9820 + result = pytester.runpytest(p, "-vv") + result.assert_outcomes(failed=1, passed=0) + result.stdout.no_re_match_line(".*AttributeError.*") + class TestAssert_reprcompare_attrsclass: def test_attrs(self) -> None: