From 8b6ad8e5681d3d14a2a3d47e5c2fb05910294a29 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Sat, 27 Feb 2021 18:07:49 +0100 Subject: [PATCH] Squashed commit of the following: commit b151e807806d2f188c45ee450a7d84535e75ed4c Author: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Sat Feb 27 15:46:18 2021 +0100 Use flag to guard DuplicateBasesError commit 3bbf4d24275eb62ff8268bd2b7bd29926276ceaa Author: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Tue Feb 23 01:48:29 2021 +0100 Add test cases for duplicate bases error commit 934264d1730329b99b0d8fa8519f9ea359b58a41 Author: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Mon Feb 22 11:00:48 2021 +0100 Fix duplicate bases error for typing._GenericAlias * Fixes #905 --- ChangeLog | 4 +++ astroid/scoped_nodes.py | 34 +++++++++++++++++----- tests/unittest_scoped_nodes.py | 53 ++++++++++++++++++++++++++++++++-- 3 files changed, 81 insertions(+), 10 deletions(-) diff --git a/ChangeLog b/ChangeLog index a11e076644..0f8dbca338 100644 --- a/ChangeLog +++ b/ChangeLog @@ -11,6 +11,10 @@ Release Date: TBA Closes #895 #899 +* Use flag to guard DuplicateBasesError + + Closes #905 + What's New in astroid 2.5? ============================ Release Date: 2021-02-15 diff --git a/astroid/scoped_nodes.py b/astroid/scoped_nodes.py index bfa0b61484..7879cb97a3 100644 --- a/astroid/scoped_nodes.py +++ b/astroid/scoped_nodes.py @@ -99,13 +99,20 @@ def _c3_merge(sequences, cls, context): return None -def clean_duplicates_mro(sequences, cls, context): +def clean_duplicates_mro( + sequences, cls, context, raise_duplicate_bases_error: bool = True +): for sequence in sequences: names = [ (node.lineno, node.qname()) if node.name else None for node in sequence ] last_index = dict(map(reversed, enumerate(names))) - if names and names[0] is not None and last_index[names[0]] != 0: + if ( + raise_duplicate_bases_error is True + and names + and names[0] is not None + and last_index[names[0]] != 0 + ): raise exceptions.DuplicateBasesError( message="Duplicates found in MROs {mros} for {cls!r}.", mros=sequences, @@ -2816,7 +2823,7 @@ def slots(self): def grouped_slots(): # Not interested in object, since it can't have slots. - for cls in self.mro()[:-1]: + for cls in self.mro(raise_duplicate_bases_error=False)[:-1]: try: cls_slots = cls._slots() except NotImplementedError: @@ -2871,7 +2878,7 @@ def _inferred_bases(self, context=None): else: yield from baseobj.bases - def _compute_mro(self, context=None): + def _compute_mro(self, context=None, raise_duplicate_bases_error: bool = True): inferred_bases = list(self._inferred_bases(context=context)) bases_mro = [] for base in inferred_bases: @@ -2879,7 +2886,10 @@ def _compute_mro(self, context=None): continue try: - mro = base._compute_mro(context=context) + mro = base._compute_mro( + context=context, + raise_duplicate_bases_error=raise_duplicate_bases_error, + ) bases_mro.append(mro) except NotImplementedError: # Some classes have in their ancestors both newstyle and @@ -2891,10 +2901,16 @@ def _compute_mro(self, context=None): bases_mro.append(ancestors) unmerged_mro = [[self]] + bases_mro + [inferred_bases] - unmerged_mro = list(clean_duplicates_mro(unmerged_mro, self, context)) + unmerged_mro = list( + clean_duplicates_mro( + unmerged_mro, self, context, raise_duplicate_bases_error + ) + ) return _c3_merge(unmerged_mro, self, context) - def mro(self, context=None) -> List["ClassDef"]: + def mro( + self, context=None, raise_duplicate_bases_error: bool = True + ) -> List["ClassDef"]: """Get the method resolution order, using C3 linearization. :returns: The list of ancestors, sorted by the mro. @@ -2902,7 +2918,9 @@ def mro(self, context=None) -> List["ClassDef"]: :raises DuplicateBasesError: Duplicate bases in the same class base :raises InconsistentMroError: A class' MRO is inconsistent """ - return self._compute_mro(context=context) + return self._compute_mro( + context=context, raise_duplicate_bases_error=raise_duplicate_bases_error + ) def bool_value(self, context=None): """Determine the boolean value of this node. diff --git a/tests/unittest_scoped_nodes.py b/tests/unittest_scoped_nodes.py index 104cdae1a7..cd77afe7f7 100644 --- a/tests/unittest_scoped_nodes.py +++ b/tests/unittest_scoped_nodes.py @@ -1270,8 +1270,18 @@ class NodeBase(object): assert len(slots) == 3, slots assert [slot.value for slot in slots] == ["a", "b", "c"] - def assertEqualMro(self, klass, expected_mro): - self.assertEqual([member.name for member in klass.mro()], expected_mro) + def assertEqualMro( + self, klass, expected_mro, raise_duplicate_bases_error: bool = True + ): + self.assertEqual( + [ + member.name + for member in klass.mro( + raise_duplicate_bases_error=raise_duplicate_bases_error + ) + ], + expected_mro, + ) @unittest.skipUnless(HAS_SIX, "These tests require the six library") def test_with_metaclass_mro(self): @@ -1436,6 +1446,45 @@ class C(scope.A, scope.B): ) self.assertEqualMro(cls, ["C", "A", "B", "object"]) + @test_utils.require_version("3.7", "3.9") + def test_mro_with_duplicate_generic_alias(self): + """Catch false positive. Assert no error is thrown.""" + cls = builder.extract_node( + """ + from typing import Sized, Hashable + class Derived(Sized, Hashable): + def __init__(self): + self.var = 1 + """ + ) + self.assertEqualMro( + cls, + ["Derived", "_GenericAlias", "_Final", "object"], + raise_duplicate_bases_error=False, + ) + + @test_utils.require_version("3.9") + def test_mro_with_duplicate_generic_alias_2(self): + cls = builder.extract_node( + """ + from typing import Sized, Hashable + class Derived(Sized, Hashable): + def __init__(self): + self.var = 1 + """ + ) + self.assertEqualMro( + cls, + [ + "Derived", + "_SpecialGenericAlias", + "_BaseGenericAlias", + "_Final", + "object", + ], + raise_duplicate_bases_error=False, + ) + def test_generator_from_infer_call_result_parent(self): func = builder.extract_node( """