Skip to content

Commit

Permalink
Squashed commit of the following:
Browse files Browse the repository at this point in the history
commit b151e80
Author: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
Date:   Sat Feb 27 15:46:18 2021 +0100

    Use flag to guard DuplicateBasesError

commit 3bbf4d2
Author: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
Date:   Tue Feb 23 01:48:29 2021 +0100

    Add test cases for duplicate bases error

commit 934264d
Author: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
Date:   Mon Feb 22 11:00:48 2021 +0100

    Fix duplicate bases error for typing._GenericAlias

    * Fixes pylint-dev#905
  • Loading branch information
cdce8p committed Feb 27, 2021
1 parent 1f862cb commit 8b6ad8e
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 10 deletions.
4 changes: 4 additions & 0 deletions ChangeLog
Expand Up @@ -11,6 +11,10 @@ Release Date: TBA

Closes #895 #899

* Use flag to guard DuplicateBasesError

Closes #905

What's New in astroid 2.5?
============================
Release Date: 2021-02-15
Expand Down
34 changes: 26 additions & 8 deletions astroid/scoped_nodes.py
Expand Up @@ -99,13 +99,20 @@ def _c3_merge(sequences, cls, context):
return None


def clean_duplicates_mro(sequences, cls, context):
def clean_duplicates_mro(
sequences, cls, context, raise_duplicate_bases_error: bool = True
):
for sequence in sequences:
names = [
(node.lineno, node.qname()) if node.name else None for node in sequence
]
last_index = dict(map(reversed, enumerate(names)))
if names and names[0] is not None and last_index[names[0]] != 0:
if (
raise_duplicate_bases_error is True
and names
and names[0] is not None
and last_index[names[0]] != 0
):
raise exceptions.DuplicateBasesError(
message="Duplicates found in MROs {mros} for {cls!r}.",
mros=sequences,
Expand Down Expand Up @@ -2816,7 +2823,7 @@ def slots(self):

def grouped_slots():
# Not interested in object, since it can't have slots.
for cls in self.mro()[:-1]:
for cls in self.mro(raise_duplicate_bases_error=False)[:-1]:
try:
cls_slots = cls._slots()
except NotImplementedError:
Expand Down Expand Up @@ -2871,15 +2878,18 @@ def _inferred_bases(self, context=None):
else:
yield from baseobj.bases

def _compute_mro(self, context=None):
def _compute_mro(self, context=None, raise_duplicate_bases_error: bool = True):
inferred_bases = list(self._inferred_bases(context=context))
bases_mro = []
for base in inferred_bases:
if base is self:
continue

try:
mro = base._compute_mro(context=context)
mro = base._compute_mro(
context=context,
raise_duplicate_bases_error=raise_duplicate_bases_error,
)
bases_mro.append(mro)
except NotImplementedError:
# Some classes have in their ancestors both newstyle and
Expand All @@ -2891,18 +2901,26 @@ def _compute_mro(self, context=None):
bases_mro.append(ancestors)

unmerged_mro = [[self]] + bases_mro + [inferred_bases]
unmerged_mro = list(clean_duplicates_mro(unmerged_mro, self, context))
unmerged_mro = list(
clean_duplicates_mro(
unmerged_mro, self, context, raise_duplicate_bases_error
)
)
return _c3_merge(unmerged_mro, self, context)

def mro(self, context=None) -> List["ClassDef"]:
def mro(
self, context=None, raise_duplicate_bases_error: bool = True
) -> List["ClassDef"]:
"""Get the method resolution order, using C3 linearization.
:returns: The list of ancestors, sorted by the mro.
:rtype: list(NodeNG)
:raises DuplicateBasesError: Duplicate bases in the same class base
:raises InconsistentMroError: A class' MRO is inconsistent
"""
return self._compute_mro(context=context)
return self._compute_mro(
context=context, raise_duplicate_bases_error=raise_duplicate_bases_error
)

def bool_value(self, context=None):
"""Determine the boolean value of this node.
Expand Down
53 changes: 51 additions & 2 deletions tests/unittest_scoped_nodes.py
Expand Up @@ -1270,8 +1270,18 @@ class NodeBase(object):
assert len(slots) == 3, slots
assert [slot.value for slot in slots] == ["a", "b", "c"]

def assertEqualMro(self, klass, expected_mro):
self.assertEqual([member.name for member in klass.mro()], expected_mro)
def assertEqualMro(
self, klass, expected_mro, raise_duplicate_bases_error: bool = True
):
self.assertEqual(
[
member.name
for member in klass.mro(
raise_duplicate_bases_error=raise_duplicate_bases_error
)
],
expected_mro,
)

@unittest.skipUnless(HAS_SIX, "These tests require the six library")
def test_with_metaclass_mro(self):
Expand Down Expand Up @@ -1436,6 +1446,45 @@ class C(scope.A, scope.B):
)
self.assertEqualMro(cls, ["C", "A", "B", "object"])

@test_utils.require_version("3.7", "3.9")
def test_mro_with_duplicate_generic_alias(self):
"""Catch false positive. Assert no error is thrown."""
cls = builder.extract_node(
"""
from typing import Sized, Hashable
class Derived(Sized, Hashable):
def __init__(self):
self.var = 1
"""
)
self.assertEqualMro(
cls,
["Derived", "_GenericAlias", "_Final", "object"],
raise_duplicate_bases_error=False,
)

@test_utils.require_version("3.9")
def test_mro_with_duplicate_generic_alias_2(self):
cls = builder.extract_node(
"""
from typing import Sized, Hashable
class Derived(Sized, Hashable):
def __init__(self):
self.var = 1
"""
)
self.assertEqualMro(
cls,
[
"Derived",
"_SpecialGenericAlias",
"_BaseGenericAlias",
"_Final",
"object",
],
raise_duplicate_bases_error=False,
)

def test_generator_from_infer_call_result_parent(self):
func = builder.extract_node(
"""
Expand Down

0 comments on commit 8b6ad8e

Please sign in to comment.