Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix regression in the creation of the __init__ of dataclasses #1812

Merged
merged 5 commits into from Oct 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions ChangeLog
Expand Up @@ -25,7 +25,10 @@ What's New in astroid 2.12.11?
==============================
Release date: TBA

* Fixed a regression in the creation of the ``__init__`` of dataclasses with
multiple inheritance.

Closes PyCQA/pylint#7434

What's New in astroid 2.12.10?
==============================
Expand Down
69 changes: 42 additions & 27 deletions astroid/brain/brain_dataclasses.py
Expand Up @@ -181,39 +181,54 @@ def _find_arguments_from_base_classes(
node: nodes.ClassDef, skippable_names: set[str]
) -> tuple[str, str]:
"""Iterate through all bases and add them to the list of arguments to add to the init."""
prev_pos_only = ""
prev_kw_only = ""
for base in node.mro():
pos_only_store: dict[str, tuple[str | None, str | None]] = {}
kw_only_store: dict[str, tuple[str | None, str | None]] = {}
# See TODO down below
# all_have_defaults = True

for base in reversed(node.mro()):
if not base.is_dataclass:
continue
try:
base_init: nodes.FunctionDef = base.locals["__init__"][0]
except KeyError:
continue

# Skip the self argument and check for duplicate arguments
arguments = base_init.args.format_args(skippable_names=skippable_names)
try:
new_prev_pos_only, new_prev_kw_only = arguments.split("*, ")
except ValueError:
new_prev_pos_only, new_prev_kw_only = arguments, ""

if new_prev_pos_only:
# The split on '*, ' can create a pos_only string that consists only of a comma
if new_prev_pos_only == ", ":
new_prev_pos_only = ""
elif not new_prev_pos_only.endswith(", "):
new_prev_pos_only += ", "

# Dataclasses put last seen arguments at the front of the init
prev_pos_only = new_prev_pos_only + prev_pos_only
prev_kw_only = new_prev_kw_only + prev_kw_only

# Add arguments to skippable arguments
skippable_names.update(arg.name for arg in base_init.args.args)
skippable_names.update(arg.name for arg in base_init.args.kwonlyargs)

return prev_pos_only, prev_kw_only
pos_only, kw_only = base_init.args._get_arguments_data()
for posarg, data in pos_only.items():
if posarg in skippable_names:
continue
# if data[1] is None:
# if all_have_defaults and pos_only_store:
# # TODO: This should return an Uninferable as this would raise
# # a TypeError at runtime. However, transforms can't return
# # Uninferables currently.
# pass
# all_have_defaults = False
pos_only_store[posarg] = data

for kwarg, data in kw_only.items():
if kwarg in skippable_names:
continue
kw_only_store[kwarg] = data

pos_only, kw_only = "", ""
for pos_arg, data in pos_only_store.items():
pos_only += pos_arg
if data[0]:
pos_only += ": " + data[0]
if data[1]:
pos_only += " = " + data[1]
pos_only += ", "
for kw_arg, data in kw_only_store.items():
kw_only += kw_arg
if data[0]:
kw_only += ": " + data[0]
if data[1]:
kw_only += " = " + data[1]
kw_only += ", "

return pos_only, kw_only


def _generate_dataclass_init(
Expand Down Expand Up @@ -282,7 +297,7 @@ def _generate_dataclass_init(
params_string += ", "

if prev_kw_only:
params_string += "*, " + prev_kw_only + ", "
params_string += "*, " + prev_kw_only
Comment on lines -285 to +300
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain this change? If it's harmless, then consider not making the change to reduce the diff. The elif/else branches below also have the same trailing separator.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prev_kw_only now takes care of this itself in the new helper function so there is no need to add it here anymore.

Edit: Let me know if you want me to change this. Otherwise I think this can be merged.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, that explains it. Merge ho!

if kw_only_decorated:
params_string += ", ".join(params) + ", "
elif kw_only_decorated:
Expand Down
69 changes: 69 additions & 0 deletions astroid/nodes/node_classes.py
Expand Up @@ -834,6 +834,75 @@ def format_args(self, *, skippable_names: set[str] | None = None) -> str:
result.append(f"**{self.kwarg}")
return ", ".join(result)

def _get_arguments_data(
self,
) -> tuple[
dict[str, tuple[str | None, str | None]],
dict[str, tuple[str | None, str | None]],
]:
"""Get the arguments as dictionary with information about typing and defaults.

The return tuple contains a dictionary for positional and keyword arguments with their typing
and their default value, if any.
The method follows a similar order as format_args but instead of formatting into a string it
returns the data that is used to do so.
"""
pos_only: dict[str, tuple[str | None, str | None]] = {}
kw_only: dict[str, tuple[str | None, str | None]] = {}

# Setup and match defaults with arguments
positional_only_defaults = []
positional_or_keyword_defaults = self.defaults
if self.defaults:
args = self.args or []
positional_or_keyword_defaults = self.defaults[-len(args) :]
positional_only_defaults = self.defaults[: len(self.defaults) - len(args)]

for index, posonly in enumerate(self.posonlyargs):
annotation, default = self.posonlyargs_annotations[index], None
if annotation is not None:
annotation = annotation.as_string()
if positional_only_defaults:
default = positional_only_defaults[index].as_string()
pos_only[posonly.name] = (annotation, default)

for index, arg in enumerate(self.args):
annotation, default = self.annotations[index], None
if annotation is not None:
annotation = annotation.as_string()
if positional_or_keyword_defaults:
defaults_offset = len(self.args) - len(positional_or_keyword_defaults)
default_index = index - defaults_offset
if (
default_index > -1
and positional_or_keyword_defaults[default_index] is not None
):
default = positional_or_keyword_defaults[default_index].as_string()
pos_only[arg.name] = (annotation, default)

if self.vararg:
annotation = self.varargannotation
if annotation is not None:
annotation = annotation.as_string()
pos_only[self.vararg] = (annotation, None)

for index, kwarg in enumerate(self.kwonlyargs):
annotation = self.kwonlyargs_annotations[index]
if annotation is not None:
annotation = annotation.as_string()
default = self.kw_defaults[index]
if default is not None:
default = default.as_string()
kw_only[kwarg.name] = (annotation, default)

if self.kwarg:
annotation = self.kwargannotation
if annotation is not None:
annotation = annotation.as_string()
kw_only[self.kwarg] = (annotation, None)

return pos_only, kw_only

def default_value(self, argname):
"""Get the default value for an argument.

Expand Down
70 changes: 70 additions & 0 deletions tests/unittest_brain_dataclasses.py
Expand Up @@ -918,6 +918,7 @@ def test_dataclass_with_multiple_inheritance() -> None:
"""Regression test for dataclasses with multiple inheritance.

Reported in https://github.com/PyCQA/pylint/issues/7427
Reported in https://github.com/PyCQA/pylint/issues/7434
"""
first, second, overwritten, overwriting, mixed = astroid.extract_node(
"""
Expand Down Expand Up @@ -991,6 +992,75 @@ class ChildWithMixedParents(BaseParent, NotADataclassParent):
assert [a.name for a in mixed_init.args.args] == ["self", "_abc", "ghi"]
assert [a.value for a in mixed_init.args.defaults] == [1, 3]

first = astroid.extract_node(
"""
from dataclasses import dataclass

@dataclass
class BaseParent:
required: bool

@dataclass
class FirstChild(BaseParent):
...

@dataclass
class SecondChild(BaseParent):
optional: bool = False

@dataclass
class GrandChild(FirstChild, SecondChild):
...

GrandChild.__init__ #@
"""
)

first_init: bases.UnboundMethod = next(first.infer())
assert [a.name for a in first_init.args.args] == ["self", "required", "optional"]
assert [a.value for a in first_init.args.defaults] == [False]


# Expected to fail for now: see the xfail reason and the docstring below.
@pytest.mark.xfail(reason="Transforms returning Uninferable isn't supported.")
def test_dataclass_non_default_argument_after_default() -> None:
"""Test that a non-default argument after a default argument is not allowed.

This should succeed, but the dataclass brain is a transform
which currently can't return an Uninferable correctly. Therefore, we can't
set the dataclass ClassDef node to be Uninferable currently.
Eventually it can be merged into test_dataclass_with_multiple_inheritance.
"""

# The hierarchy mixes a field without a default (``required``) with fields
# that have one (``optional``, ``other``) across several dataclass parents;
# the extracted node is the call that instantiates the combined class.
impossible = astroid.extract_node(
"""
from dataclasses import dataclass

@dataclass
class BaseParent:
required: bool

@dataclass
class FirstChild(BaseParent):
...

@dataclass
class SecondChild(BaseParent):
optional: bool = False

@dataclass
class ThirdChild:
other: bool = False

@dataclass
class ImpossibleGrandChild(FirstChild, SecondChild, ThirdChild):
...

ImpossibleGrandChild() #@
"""
)

# Target behaviour: inferring the instantiation yields Uninferable.
assert next(impossible.infer()) is Uninferable


def test_dataclass_inits_of_non_dataclasses() -> None:
"""Regression test for __init__ mangling for non dataclasses.
Expand Down