From e41155d354a5943a559e3084c0b6f99b5e171231 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Wed, 7 Sep 2022 12:30:19 +0200 Subject: [PATCH] Fixed the ``__init__`` of ``dataclassess`` with multiple inheritance (#1774) --- ChangeLog | 4 + astroid/brain/brain_dataclasses.py | 62 ++++++++----- tests/unittest_brain_dataclasses.py | 132 ++++++++++++++++++++++++++++ 3 files changed, 178 insertions(+), 20 deletions(-) diff --git a/ChangeLog b/ChangeLog index 771e49491..b038fe7db 100644 --- a/ChangeLog +++ b/ChangeLog @@ -12,6 +12,10 @@ What's New in astroid 2.12.9? ============================= Release date: TBA +* Fixed creation of the ``__init__`` of ``dataclassess`` with multiple inheritance. + + Closes PyCQA/pylint#7427 + * Fixed a crash on ``namedtuples`` that use ``typename`` to specify their name. Closes PyCQA/pylint#7429 diff --git a/astroid/brain/brain_dataclasses.py b/astroid/brain/brain_dataclasses.py index 629024902..264957e00 100644 --- a/astroid/brain/brain_dataclasses.py +++ b/astroid/brain/brain_dataclasses.py @@ -177,6 +177,45 @@ def _check_generate_dataclass_init(node: nodes.ClassDef) -> bool: ) +def _find_arguments_from_base_classes( + node: nodes.ClassDef, skippable_names: set[str] +) -> tuple[str, str]: + """Iterate through all bases and add them to the list of arguments to add to the init.""" + prev_pos_only = "" + prev_kw_only = "" + for base in node.mro(): + if not base.is_dataclass: + continue + try: + base_init: nodes.FunctionDef = base.locals["__init__"][0] + except KeyError: + continue + + # Skip the self argument and check for duplicate arguments + arguments = base_init.args.format_args(skippable_names=skippable_names) + try: + new_prev_pos_only, new_prev_kw_only = arguments.split("*, ") + except ValueError: + new_prev_pos_only, new_prev_kw_only = arguments, "" + + if new_prev_pos_only: + # The split on '*, ' can crete a pos_only string that consists only of a comma + if new_prev_pos_only == ", ": + new_prev_pos_only = "" + elif not new_prev_pos_only.endswith(", "): + new_prev_pos_only += ", " + + # Dataclasses put last seen arguments at the front of the init + prev_pos_only = new_prev_pos_only + prev_pos_only + prev_kw_only = new_prev_kw_only + prev_kw_only + + # Add arguments to skippable arguments + skippable_names.update(arg.name for arg in base_init.args.args) + skippable_names.update(arg.name for arg in base_init.args.kwonlyargs) + + return prev_pos_only, prev_kw_only + + def _generate_dataclass_init( node: nodes.ClassDef, assigns: list[nodes.AnnAssign], kw_only_decorated: bool ) -> str: @@ -228,26 +267,9 @@ def _generate_dataclass_init( if not init_var: assignments.append(assignment_str) - try: - base = next(next(iter(node.bases)).infer()) - if not isinstance(base, nodes.ClassDef): - raise InferenceError - base_init: nodes.FunctionDef | None = base.locals["__init__"][0] - except (StopIteration, InferenceError, KeyError): - base_init = None - - prev_pos_only = "" - prev_kw_only = "" - if base_init and base.is_dataclass: - # Skip the self argument and check for duplicate arguments - arguments = base_init.args.format_args(skippable_names=assign_names)[6:] - try: - prev_pos_only, prev_kw_only = arguments.split("*, ") - except ValueError: - prev_pos_only, prev_kw_only = arguments, "" - - if prev_pos_only and not prev_pos_only.endswith(", "): - prev_pos_only += ", " + prev_pos_only, prev_kw_only = _find_arguments_from_base_classes( + node, set(assign_names + ["self"]) + ) # Construct the new init method paramter string params_string = "self, " diff --git a/tests/unittest_brain_dataclasses.py b/tests/unittest_brain_dataclasses.py index 2f59d4c33..7d69b3591 100644 --- a/tests/unittest_brain_dataclasses.py +++ b/tests/unittest_brain_dataclasses.py @@ -912,3 +912,135 @@ class GoodExampleClass(GoodExampleParentClass): good_init: bases.UnboundMethod = next(good_node.infer()) assert bad_init.args.defaults assert [a.name for a in good_init.args.args] == ["self", "xyz"] + + +def test_dataclass_with_multiple_inheritance() -> None: + """Regression test for dataclasses with multiple inheritance. + + Reported in https://github.com/PyCQA/pylint/issues/7427 + """ + first, second, overwritten, overwriting, mixed = astroid.extract_node( + """ + from dataclasses import dataclass + + @dataclass + class BaseParent: + _abc: int = 1 + + @dataclass + class AnotherParent: + ef: int = 2 + + @dataclass + class FirstChild(BaseParent, AnotherParent): + ghi: int = 3 + + @dataclass + class ConvolutedParent(AnotherParent): + '''Convoluted Parent''' + + @dataclass + class SecondChild(BaseParent, ConvolutedParent): + jkl: int = 4 + + @dataclass + class OverwritingParent: + ef: str = "2" + + @dataclass + class OverwrittenChild(OverwritingParent, AnotherParent): + '''Overwritten Child''' + + @dataclass + class OverwritingChild(BaseParent, AnotherParent): + _abc: float = 1.0 + ef: float = 2.0 + + class NotADataclassParent: + ef: int = 2 + + @dataclass + class ChildWithMixedParents(BaseParent, NotADataclassParent): + ghi: int = 3 + + FirstChild.__init__ #@ + SecondChild.__init__ #@ + OverwrittenChild.__init__ #@ + OverwritingChild.__init__ #@ + ChildWithMixedParents.__init__ #@ + """ + ) + + first_init: bases.UnboundMethod = next(first.infer()) + assert [a.name for a in first_init.args.args] == ["self", "ef", "_abc", "ghi"] + assert [a.value for a in first_init.args.defaults] == [2, 1, 3] + + second_init: bases.UnboundMethod = next(second.infer()) + assert [a.name for a in second_init.args.args] == ["self", "ef", "_abc", "jkl"] + assert [a.value for a in second_init.args.defaults] == [2, 1, 4] + + overwritten_init: bases.UnboundMethod = next(overwritten.infer()) + assert [a.name for a in overwritten_init.args.args] == ["self", "ef"] + assert [a.value for a in overwritten_init.args.defaults] == ["2"] + + overwriting_init: bases.UnboundMethod = next(overwriting.infer()) + assert [a.name for a in overwriting_init.args.args] == ["self", "_abc", "ef"] + assert [a.value for a in overwriting_init.args.defaults] == [1.0, 2.0] + + mixed_init: bases.UnboundMethod = next(mixed.infer()) + assert [a.name for a in mixed_init.args.args] == ["self", "_abc", "ghi"] + assert [a.value for a in mixed_init.args.defaults] == [1, 3] + + +def test_dataclass_inits_of_non_dataclasses() -> None: + """Regression test for __init__ mangling for non dataclasses. + + Regression test against changes tested in test_dataclass_with_multiple_inheritance + """ + first, second, third = astroid.extract_node( + """ + from dataclasses import dataclass + + @dataclass + class DataclassParent: + _abc: int = 1 + + + class NotADataclassParent: + ef: int = 2 + + + class FirstChild(DataclassParent, NotADataclassParent): + ghi: int = 3 + + + class SecondChild(DataclassParent, NotADataclassParent): + ghi: int = 3 + + def __init__(self, ef: int = 3): + self.ef = ef + + + class ThirdChild(NotADataclassParent, DataclassParent): + ghi: int = 3 + + def __init__(self, ef: int = 3): + self.ef = ef + + FirstChild.__init__ #@ + SecondChild.__init__ #@ + ThirdChild.__init__ #@ + """ + ) + + first_init: bases.UnboundMethod = next(first.infer()) + assert [a.name for a in first_init.args.args] == ["self", "_abc"] + assert [a.value for a in first_init.args.defaults] == [1] + + second_init: bases.UnboundMethod = next(second.infer()) + assert [a.name for a in second_init.args.args] == ["self", "ef"] + assert [a.value for a in second_init.args.defaults] == [3] + + third_init: bases.UnboundMethod = next(third.infer()) + assert [a.name for a in third_init.args.args] == ["self", "ef"] + assert [a.value for a in third_init.args.defaults] == [3]