diff --git a/ChangeLog b/ChangeLog index c10ae3152..1210e850c 100644 --- a/ChangeLog +++ b/ChangeLog @@ -25,7 +25,10 @@ What's New in astroid 2.12.11? ============================== Release date: TBA +* Fixed a regression in the creation of the ``__init__`` of dataclasses with + multiple inheritance. + Closes PyCQA/pylint#7434 What's New in astroid 2.12.10? ============================== diff --git a/astroid/brain/brain_dataclasses.py b/astroid/brain/brain_dataclasses.py index 264957e00..5d3c34610 100644 --- a/astroid/brain/brain_dataclasses.py +++ b/astroid/brain/brain_dataclasses.py @@ -181,9 +181,12 @@ def _find_arguments_from_base_classes( node: nodes.ClassDef, skippable_names: set[str] ) -> tuple[str, str]: """Iterate through all bases and add them to the list of arguments to add to the init.""" - prev_pos_only = "" - prev_kw_only = "" - for base in node.mro(): + pos_only_store: dict[str, tuple[str | None, str | None]] = {} + kw_only_store: dict[str, tuple[str | None, str | None]] = {} + # See TODO down below + # all_have_defaults = True + + for base in reversed(node.mro()): if not base.is_dataclass: continue try: @@ -191,29 +194,41 @@ def _find_arguments_from_base_classes( except KeyError: continue - # Skip the self argument and check for duplicate arguments - arguments = base_init.args.format_args(skippable_names=skippable_names) - try: - new_prev_pos_only, new_prev_kw_only = arguments.split("*, ") - except ValueError: - new_prev_pos_only, new_prev_kw_only = arguments, "" - - if new_prev_pos_only: - # The split on '*, ' can crete a pos_only string that consists only of a comma - if new_prev_pos_only == ", ": - new_prev_pos_only = "" - elif not new_prev_pos_only.endswith(", "): - new_prev_pos_only += ", " - - # Dataclasses put last seen arguments at the front of the init - prev_pos_only = new_prev_pos_only + prev_pos_only - prev_kw_only = new_prev_kw_only + prev_kw_only - - # Add arguments to skippable arguments - skippable_names.update(arg.name for arg in base_init.args.args) - skippable_names.update(arg.name for arg in base_init.args.kwonlyargs) - - return prev_pos_only, prev_kw_only + pos_only, kw_only = base_init.args._get_arguments_data() + for posarg, data in pos_only.items(): + if posarg in skippable_names: + continue + # if data[1] is None: + # if all_have_defaults and pos_only_store: + # # TODO: This should return an Uninferable as this would raise + # # a TypeError at runtime. However, transforms can't return + # # Uninferables currently. + # pass + # all_have_defaults = False + pos_only_store[posarg] = data + + for kwarg, data in kw_only.items(): + if kwarg in skippable_names: + continue + kw_only_store[kwarg] = data + + pos_only, kw_only = "", "" + for pos_arg, data in pos_only_store.items(): + pos_only += pos_arg + if data[0]: + pos_only += ": " + data[0] + if data[1]: + pos_only += " = " + data[1] + pos_only += ", " + for kw_arg, data in kw_only_store.items(): + kw_only += kw_arg + if data[0]: + kw_only += ": " + data[0] + if data[1]: + kw_only += " = " + data[1] + kw_only += ", " + + return pos_only, kw_only def _generate_dataclass_init( @@ -282,7 +297,7 @@ def _generate_dataclass_init( params_string += ", " if prev_kw_only: - params_string += "*, " + prev_kw_only + ", " + params_string += "*, " + prev_kw_only if kw_only_decorated: params_string += ", ".join(params) + ", " elif kw_only_decorated: diff --git a/astroid/nodes/node_classes.py b/astroid/nodes/node_classes.py index a4686688b..2f515dbe9 100644 --- a/astroid/nodes/node_classes.py +++ b/astroid/nodes/node_classes.py @@ -834,6 +834,75 @@ def format_args(self, *, skippable_names: set[str] | None = None) -> str: result.append(f"**{self.kwarg}") return ", ".join(result) + def _get_arguments_data( + self, + ) -> tuple[ + dict[str, tuple[str | None, str | None]], + dict[str, tuple[str | None, str | None]], + ]: + """Get the arguments as dictionary with information about typing and defaults. + + The return tuple contains a dictionary for positional and keyword arguments with their typing + and their default value, if any. + The method follows a similar order as format_args but instead of formatting into a string it + returns the data that is used to do so. + """ + pos_only: dict[str, tuple[str | None, str | None]] = {} + kw_only: dict[str, tuple[str | None, str | None]] = {} + + # Setup and match defaults with arguments + positional_only_defaults = [] + positional_or_keyword_defaults = self.defaults + if self.defaults: + args = self.args or [] + positional_or_keyword_defaults = self.defaults[-len(args) :] + positional_only_defaults = self.defaults[: len(self.defaults) - len(args)] + + for index, posonly in enumerate(self.posonlyargs): + annotation, default = self.posonlyargs_annotations[index], None + if annotation is not None: + annotation = annotation.as_string() + if positional_only_defaults: + default = positional_only_defaults[index].as_string() + pos_only[posonly.name] = (annotation, default) + + for index, arg in enumerate(self.args): + annotation, default = self.annotations[index], None + if annotation is not None: + annotation = annotation.as_string() + if positional_or_keyword_defaults: + defaults_offset = len(self.args) - len(positional_or_keyword_defaults) + default_index = index - defaults_offset + if ( + default_index > -1 + and positional_or_keyword_defaults[default_index] is not None + ): + default = positional_or_keyword_defaults[default_index].as_string() + pos_only[arg.name] = (annotation, default) + + if self.vararg: + annotation = self.varargannotation + if annotation is not None: + annotation = annotation.as_string() + pos_only[self.vararg] = (annotation, None) + + for index, kwarg in enumerate(self.kwonlyargs): + annotation = self.kwonlyargs_annotations[index] + if annotation is not None: + annotation = annotation.as_string() + default = self.kw_defaults[index] + if default is not None: + default = default.as_string() + kw_only[kwarg.name] = (annotation, default) + + if self.kwarg: + annotation = self.kwargannotation + if annotation is not None: + annotation = annotation.as_string() + kw_only[self.kwarg] = (annotation, None) + + return pos_only, kw_only + def default_value(self, argname): """Get the default value for an argument. diff --git a/tests/unittest_brain_dataclasses.py b/tests/unittest_brain_dataclasses.py index 7d69b3591..a65a8dec0 100644 --- a/tests/unittest_brain_dataclasses.py +++ b/tests/unittest_brain_dataclasses.py @@ -918,6 +918,7 @@ def test_dataclass_with_multiple_inheritance() -> None: """Regression test for dataclasses with multiple inheritance. Reported in https://github.com/PyCQA/pylint/issues/7427 + Reported in https://github.com/PyCQA/pylint/issues/7434 """ first, second, overwritten, overwriting, mixed = astroid.extract_node( """ @@ -991,6 +992,75 @@ class ChildWithMixedParents(BaseParent, NotADataclassParent): assert [a.name for a in mixed_init.args.args] == ["self", "_abc", "ghi"] assert [a.value for a in mixed_init.args.defaults] == [1, 3] + first = astroid.extract_node( + """ + from dataclasses import dataclass + + @dataclass + class BaseParent: + required: bool + + @dataclass + class FirstChild(BaseParent): + ... + + @dataclass + class SecondChild(BaseParent): + optional: bool = False + + @dataclass + class GrandChild(FirstChild, SecondChild): + ... + + GrandChild.__init__ #@ + """ + ) + + first_init: bases.UnboundMethod = next(first.infer()) + assert [a.name for a in first_init.args.args] == ["self", "required", "optional"] + assert [a.value for a in first_init.args.defaults] == [False] + + +@pytest.mark.xfail(reason="Transforms returning Uninferable isn't supported.") +def test_dataclass_non_default_argument_after_default() -> None: + """Test that a non-default argument after a default argument is not allowed. + + This should succeed, but the dataclass brain is a transform + which currently can't return an Uninferable correctly. Therefore, we can't + set the dataclass ClassDef node to be Uninferable currently. + Eventually it can be merged into test_dataclass_with_multiple_inheritance. + """ + + impossible = astroid.extract_node( + """ + from dataclasses import dataclass + + @dataclass + class BaseParent: + required: bool + + @dataclass + class FirstChild(BaseParent): + ... + + @dataclass + class SecondChild(BaseParent): + optional: bool = False + + @dataclass + class ThirdChild: + other: bool = False + + @dataclass + class ImpossibleGrandChild(FirstChild, SecondChild, ThirdChild): + ... + + ImpossibleGrandChild() #@ + """ + ) + + assert next(impossible.infer()) is Uninferable + def test_dataclass_inits_of_non_dataclasses() -> None: """Regression test for __init__ mangling for non dataclasses.