Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed the __init__ of dataclassess with multiple inheritance #1774

Merged
merged 1 commit into from Sep 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions ChangeLog
Expand Up @@ -17,6 +17,10 @@ What's New in astroid 2.12.9?
=============================
Release date: TBA

* Fixed creation of the ``__init__`` of ``dataclassess`` with multiple inheritance.

Closes PyCQA/pylint#7427

* Fixed a crash on ``namedtuples`` that use ``typename`` to specify their name.

Closes PyCQA/pylint#7429
Expand Down
62 changes: 42 additions & 20 deletions astroid/brain/brain_dataclasses.py
Expand Up @@ -177,6 +177,45 @@ def _check_generate_dataclass_init(node: nodes.ClassDef) -> bool:
)


def _find_arguments_from_base_classes(
node: nodes.ClassDef, skippable_names: set[str]
) -> tuple[str, str]:
"""Iterate through all bases and add them to the list of arguments to add to the init."""
prev_pos_only = ""
prev_kw_only = ""
for base in node.mro():
if not base.is_dataclass:
continue
try:
base_init: nodes.FunctionDef = base.locals["__init__"][0]
except KeyError:
continue

# Skip the self argument and check for duplicate arguments
arguments = base_init.args.format_args(skippable_names=skippable_names)
try:
new_prev_pos_only, new_prev_kw_only = arguments.split("*, ")
except ValueError:
new_prev_pos_only, new_prev_kw_only = arguments, ""

if new_prev_pos_only:
# The split on '*, ' can crete a pos_only string that consists only of a comma
if new_prev_pos_only == ", ":
new_prev_pos_only = ""
elif not new_prev_pos_only.endswith(", "):
new_prev_pos_only += ", "

# Dataclasses put last seen arguments at the front of the init
prev_pos_only = new_prev_pos_only + prev_pos_only
prev_kw_only = new_prev_kw_only + prev_kw_only

# Add arguments to skippable arguments
skippable_names.update(arg.name for arg in base_init.args.args)
skippable_names.update(arg.name for arg in base_init.args.kwonlyargs)

return prev_pos_only, prev_kw_only


def _generate_dataclass_init(
node: nodes.ClassDef, assigns: list[nodes.AnnAssign], kw_only_decorated: bool
) -> str:
Expand Down Expand Up @@ -228,26 +267,9 @@ def _generate_dataclass_init(
if not init_var:
assignments.append(assignment_str)

try:
base = next(next(iter(node.bases)).infer())
if not isinstance(base, nodes.ClassDef):
raise InferenceError
base_init: nodes.FunctionDef | None = base.locals["__init__"][0]
except (StopIteration, InferenceError, KeyError):
base_init = None

prev_pos_only = ""
prev_kw_only = ""
if base_init and base.is_dataclass:
# Skip the self argument and check for duplicate arguments
arguments = base_init.args.format_args(skippable_names=assign_names)[6:]
try:
prev_pos_only, prev_kw_only = arguments.split("*, ")
except ValueError:
prev_pos_only, prev_kw_only = arguments, ""

if prev_pos_only and not prev_pos_only.endswith(", "):
prev_pos_only += ", "
prev_pos_only, prev_kw_only = _find_arguments_from_base_classes(
node, set(assign_names + ["self"])
)

# Construct the new init method paramter string
params_string = "self, "
Expand Down
132 changes: 132 additions & 0 deletions tests/unittest_brain_dataclasses.py
Expand Up @@ -912,3 +912,135 @@ class GoodExampleClass(GoodExampleParentClass):
good_init: bases.UnboundMethod = next(good_node.infer())
assert bad_init.args.defaults
assert [a.name for a in good_init.args.args] == ["self", "xyz"]


def test_dataclass_with_multiple_inheritance() -> None:
"""Regression test for dataclasses with multiple inheritance.

Reported in https://github.com/PyCQA/pylint/issues/7427
"""
first, second, overwritten, overwriting, mixed = astroid.extract_node(
"""
from dataclasses import dataclass

@dataclass
class BaseParent:
_abc: int = 1

@dataclass
class AnotherParent:
ef: int = 2

@dataclass
class FirstChild(BaseParent, AnotherParent):
ghi: int = 3

@dataclass
class ConvolutedParent(AnotherParent):
'''Convoluted Parent'''

@dataclass
class SecondChild(BaseParent, ConvolutedParent):
jkl: int = 4

@dataclass
class OverwritingParent:
ef: str = "2"

@dataclass
class OverwrittenChild(OverwritingParent, AnotherParent):
'''Overwritten Child'''

@dataclass
class OverwritingChild(BaseParent, AnotherParent):
_abc: float = 1.0
ef: float = 2.0

class NotADataclassParent:
ef: int = 2

@dataclass
class ChildWithMixedParents(BaseParent, NotADataclassParent):
ghi: int = 3

FirstChild.__init__ #@
SecondChild.__init__ #@
OverwrittenChild.__init__ #@
OverwritingChild.__init__ #@
ChildWithMixedParents.__init__ #@
"""
)

first_init: bases.UnboundMethod = next(first.infer())
assert [a.name for a in first_init.args.args] == ["self", "ef", "_abc", "ghi"]
assert [a.value for a in first_init.args.defaults] == [2, 1, 3]

second_init: bases.UnboundMethod = next(second.infer())
assert [a.name for a in second_init.args.args] == ["self", "ef", "_abc", "jkl"]
assert [a.value for a in second_init.args.defaults] == [2, 1, 4]

overwritten_init: bases.UnboundMethod = next(overwritten.infer())
assert [a.name for a in overwritten_init.args.args] == ["self", "ef"]
assert [a.value for a in overwritten_init.args.defaults] == ["2"]

overwriting_init: bases.UnboundMethod = next(overwriting.infer())
assert [a.name for a in overwriting_init.args.args] == ["self", "_abc", "ef"]
assert [a.value for a in overwriting_init.args.defaults] == [1.0, 2.0]

mixed_init: bases.UnboundMethod = next(mixed.infer())
assert [a.name for a in mixed_init.args.args] == ["self", "_abc", "ghi"]
assert [a.value for a in mixed_init.args.defaults] == [1, 3]


def test_dataclass_inits_of_non_dataclasses() -> None:
"""Regression test for __init__ mangling for non dataclasses.

Regression test against changes tested in test_dataclass_with_multiple_inheritance
"""
first, second, third = astroid.extract_node(
"""
from dataclasses import dataclass

@dataclass
class DataclassParent:
_abc: int = 1


class NotADataclassParent:
ef: int = 2


class FirstChild(DataclassParent, NotADataclassParent):
ghi: int = 3


class SecondChild(DataclassParent, NotADataclassParent):
ghi: int = 3

def __init__(self, ef: int = 3):
self.ef = ef


class ThirdChild(NotADataclassParent, DataclassParent):
ghi: int = 3

def __init__(self, ef: int = 3):
self.ef = ef

FirstChild.__init__ #@
SecondChild.__init__ #@
ThirdChild.__init__ #@
"""
)

first_init: bases.UnboundMethod = next(first.infer())
assert [a.name for a in first_init.args.args] == ["self", "_abc"]
assert [a.value for a in first_init.args.defaults] == [1]

second_init: bases.UnboundMethod = next(second.infer())
assert [a.name for a in second_init.args.args] == ["self", "ef"]
assert [a.value for a in second_init.args.defaults] == [3]

third_init: bases.UnboundMethod = next(third.infer())
assert [a.name for a in third_init.args.args] == ["self", "ef"]
assert [a.value for a in third_init.args.defaults] == [3]