Skip to content

Commit

Permalink
Fix regression in the creation of the __init__ of dataclasses (#1812
Browse files Browse the repository at this point in the history
)

Co-authored-by: Jacob Walls <jacobtylerwalls@gmail.com>
  • Loading branch information
DanielNoord and jacobtylerwalls committed Oct 4, 2022
1 parent 2a1b0d3 commit 1ffe400
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 27 deletions.
3 changes: 3 additions & 0 deletions ChangeLog
Expand Up @@ -25,7 +25,10 @@ What's New in astroid 2.12.11?
==============================
Release date: TBA

* Fixed a regression in the creation of the ``__init__`` of dataclasses with
multiple inheritance.

Closes PyCQA/pylint#7434

What's New in astroid 2.12.10?
==============================
Expand Down
69 changes: 42 additions & 27 deletions astroid/brain/brain_dataclasses.py
Expand Up @@ -181,39 +181,54 @@ def _find_arguments_from_base_classes(
node: nodes.ClassDef, skippable_names: set[str]
) -> tuple[str, str]:
"""Iterate through all bases and add them to the list of arguments to add to the init."""
prev_pos_only = ""
prev_kw_only = ""
for base in node.mro():
pos_only_store: dict[str, tuple[str | None, str | None]] = {}
kw_only_store: dict[str, tuple[str | None, str | None]] = {}
# See TODO down below
# all_have_defaults = True

for base in reversed(node.mro()):
if not base.is_dataclass:
continue
try:
base_init: nodes.FunctionDef = base.locals["__init__"][0]
except KeyError:
continue

# Skip the self argument and check for duplicate arguments
arguments = base_init.args.format_args(skippable_names=skippable_names)
try:
new_prev_pos_only, new_prev_kw_only = arguments.split("*, ")
except ValueError:
new_prev_pos_only, new_prev_kw_only = arguments, ""

if new_prev_pos_only:
# The split on '*, ' can crete a pos_only string that consists only of a comma
if new_prev_pos_only == ", ":
new_prev_pos_only = ""
elif not new_prev_pos_only.endswith(", "):
new_prev_pos_only += ", "

# Dataclasses put last seen arguments at the front of the init
prev_pos_only = new_prev_pos_only + prev_pos_only
prev_kw_only = new_prev_kw_only + prev_kw_only

# Add arguments to skippable arguments
skippable_names.update(arg.name for arg in base_init.args.args)
skippable_names.update(arg.name for arg in base_init.args.kwonlyargs)

return prev_pos_only, prev_kw_only
pos_only, kw_only = base_init.args._get_arguments_data()
for posarg, data in pos_only.items():
if posarg in skippable_names:
continue
# if data[1] is None:
# if all_have_defaults and pos_only_store:
# # TODO: This should return an Uninferable as this would raise
# # a TypeError at runtime. However, transforms can't return
# # Uninferables currently.
# pass
# all_have_defaults = False
pos_only_store[posarg] = data

for kwarg, data in kw_only.items():
if kwarg in skippable_names:
continue
kw_only_store[kwarg] = data

pos_only, kw_only = "", ""
for pos_arg, data in pos_only_store.items():
pos_only += pos_arg
if data[0]:
pos_only += ": " + data[0]
if data[1]:
pos_only += " = " + data[1]
pos_only += ", "
for kw_arg, data in kw_only_store.items():
kw_only += kw_arg
if data[0]:
kw_only += ": " + data[0]
if data[1]:
kw_only += " = " + data[1]
kw_only += ", "

return pos_only, kw_only


def _generate_dataclass_init(
Expand Down Expand Up @@ -282,7 +297,7 @@ def _generate_dataclass_init(
params_string += ", "

if prev_kw_only:
params_string += "*, " + prev_kw_only + ", "
params_string += "*, " + prev_kw_only
if kw_only_decorated:
params_string += ", ".join(params) + ", "
elif kw_only_decorated:
Expand Down
69 changes: 69 additions & 0 deletions astroid/nodes/node_classes.py
Expand Up @@ -834,6 +834,75 @@ def format_args(self, *, skippable_names: set[str] | None = None) -> str:
result.append(f"**{self.kwarg}")
return ", ".join(result)

def _get_arguments_data(
self,
) -> tuple[
dict[str, tuple[str | None, str | None]],
dict[str, tuple[str | None, str | None]],
]:
"""Get the arguments as dictionary with information about typing and defaults.
The return tuple contains a dictionary for positional and keyword arguments with their typing
and their default value, if any.
The method follows a similar order as format_args but instead of formatting into a string it
returns the data that is used to do so.
"""
pos_only: dict[str, tuple[str | None, str | None]] = {}
kw_only: dict[str, tuple[str | None, str | None]] = {}

# Setup and match defaults with arguments
positional_only_defaults = []
positional_or_keyword_defaults = self.defaults
if self.defaults:
args = self.args or []
positional_or_keyword_defaults = self.defaults[-len(args) :]
positional_only_defaults = self.defaults[: len(self.defaults) - len(args)]

for index, posonly in enumerate(self.posonlyargs):
annotation, default = self.posonlyargs_annotations[index], None
if annotation is not None:
annotation = annotation.as_string()
if positional_only_defaults:
default = positional_only_defaults[index].as_string()
pos_only[posonly.name] = (annotation, default)

for index, arg in enumerate(self.args):
annotation, default = self.annotations[index], None
if annotation is not None:
annotation = annotation.as_string()
if positional_or_keyword_defaults:
defaults_offset = len(self.args) - len(positional_or_keyword_defaults)
default_index = index - defaults_offset
if (
default_index > -1
and positional_or_keyword_defaults[default_index] is not None
):
default = positional_or_keyword_defaults[default_index].as_string()
pos_only[arg.name] = (annotation, default)

if self.vararg:
annotation = self.varargannotation
if annotation is not None:
annotation = annotation.as_string()
pos_only[self.vararg] = (annotation, None)

for index, kwarg in enumerate(self.kwonlyargs):
annotation = self.kwonlyargs_annotations[index]
if annotation is not None:
annotation = annotation.as_string()
default = self.kw_defaults[index]
if default is not None:
default = default.as_string()
kw_only[kwarg.name] = (annotation, default)

if self.kwarg:
annotation = self.kwargannotation
if annotation is not None:
annotation = annotation.as_string()
kw_only[self.kwarg] = (annotation, None)

return pos_only, kw_only

def default_value(self, argname):
"""Get the default value for an argument.
Expand Down
70 changes: 70 additions & 0 deletions tests/unittest_brain_dataclasses.py
Expand Up @@ -918,6 +918,7 @@ def test_dataclass_with_multiple_inheritance() -> None:
"""Regression test for dataclasses with multiple inheritance.
Reported in https://github.com/PyCQA/pylint/issues/7427
Reported in https://github.com/PyCQA/pylint/issues/7434
"""
first, second, overwritten, overwriting, mixed = astroid.extract_node(
"""
Expand Down Expand Up @@ -991,6 +992,75 @@ class ChildWithMixedParents(BaseParent, NotADataclassParent):
assert [a.name for a in mixed_init.args.args] == ["self", "_abc", "ghi"]
assert [a.value for a in mixed_init.args.defaults] == [1, 3]

first = astroid.extract_node(
"""
from dataclasses import dataclass
@dataclass
class BaseParent:
required: bool
@dataclass
class FirstChild(BaseParent):
...
@dataclass
class SecondChild(BaseParent):
optional: bool = False
@dataclass
class GrandChild(FirstChild, SecondChild):
...
GrandChild.__init__ #@
"""
)

first_init: bases.UnboundMethod = next(first.infer())
assert [a.name for a in first_init.args.args] == ["self", "required", "optional"]
assert [a.value for a in first_init.args.defaults] == [False]


@pytest.mark.xfail(reason="Transforms returning Uninferable isn't supported.")
def test_dataclass_non_default_argument_after_default() -> None:
"""Test that a non-default argument after a default argument is not allowed.
This should succeed, but the dataclass brain is a transform
which currently can't return an Uninferable correctly. Therefore, we can't
set the dataclass ClassDef node to be Uninferable currently.
Eventually it can be merged into test_dataclass_with_multiple_inheritance.
"""

impossible = astroid.extract_node(
"""
from dataclasses import dataclass
@dataclass
class BaseParent:
required: bool
@dataclass
class FirstChild(BaseParent):
...
@dataclass
class SecondChild(BaseParent):
optional: bool = False
@dataclass
class ThirdChild:
other: bool = False
@dataclass
class ImpossibleGrandChild(FirstChild, SecondChild, ThirdChild):
...
ImpossibleGrandChild() #@
"""
)

assert next(impossible.infer()) is Uninferable


def test_dataclass_inits_of_non_dataclasses() -> None:
"""Regression test for __init__ mangling for non dataclasses.
Expand Down

0 comments on commit 1ffe400

Please sign in to comment.