diff --git a/ChangeLog b/ChangeLog index 9b1ab00a96..7389327cb6 100644 --- a/ChangeLog +++ b/ChangeLog @@ -12,6 +12,10 @@ Release date: TBA .. Put bug fixes that should not wait for a new minor version here +* pyreverse: Show class has-a relationships inferred from the type-hint + +Closes #4744 + * Added ``ignored-parents`` option to the design checker to ignore specific classes from the ``too-many-ancestors`` check (R0901). diff --git a/doc/whatsnew/2.10.rst b/doc/whatsnew/2.10.rst index cadfc3fa64..26b792a6ab 100644 --- a/doc/whatsnew/2.10.rst +++ b/doc/whatsnew/2.10.rst @@ -40,6 +40,8 @@ Extensions Other Changes ============= +* Pyreverse - Show class has-a relationships inferred from type-hints + * Performance of the Similarity checker has been improved. * Added ``time.clock`` to deprecated functions/methods for python 3.3 diff --git a/pylint/pyreverse/utils.py b/pylint/pyreverse/utils.py index 06f8d3b7e2..6eb9d43ccf 100644 --- a/pylint/pyreverse/utils.py +++ b/pylint/pyreverse/utils.py @@ -269,9 +269,9 @@ def infer_node(node: Union[astroid.AssignAttr, astroid.AssignName]) -> set: otherwise return a set of the inferred types using the NodeNG.infer method""" ann = get_annotation(node) - if ann: - return {ann} try: + if ann: + return set(ann.infer()) return set(node.infer()) except astroid.InferenceError: - return set() + return {ann} if ann else set() diff --git a/tests/data/classes_No_Name.dot b/tests/data/classes_No_Name.dot index 8867b4e416..c5eb70f0c4 100644 --- a/tests/data/classes_No_Name.dot +++ b/tests/data/classes_No_Name.dot @@ -3,10 +3,12 @@ charset="utf-8" rankdir=BT "0" [label="{Ancestor|attr : str\lcls_member\l|get_value()\lset_value(value)\l}", shape="record"]; "1" [label="{DoNothing|\l|}", shape="record"]; -"2" [label="{Interface|\l|get_value()\lset_value(value)\l}", shape="record"]; -"3" [label="{Specialization|TYPE : str\lrelation\ltop : str\l|}", shape="record"]; -"3" -> "0" [arrowhead="empty", arrowtail="none"]; -"0" -> "2" [arrowhead="empty", arrowtail="node", style="dashed"]; +"2" [label="{DoNothing2|\l|}", shape="record"]; +"3" [label="{Interface|\l|get_value()\lset_value(value)\l}", shape="record"]; +"4" [label="{Specialization|TYPE : str\lrelation\lrelation2\ltop : str\l|}", shape="record"]; +"4" -> "0" [arrowhead="empty", arrowtail="none"]; +"0" -> "3" [arrowhead="empty", arrowtail="node", style="dashed"]; "1" -> "0" [arrowhead="diamond", arrowtail="none", fontcolor="green", label="cls_member", style="solid"]; -"1" -> "3" [arrowhead="diamond", arrowtail="none", fontcolor="green", label="relation", style="solid"]; +"1" -> "4" [arrowhead="diamond", arrowtail="none", fontcolor="green", label="relation", style="solid"]; +"2" -> "4" [arrowhead="diamond", arrowtail="none", fontcolor="green", label="relation2", style="solid"]; } diff --git a/tests/data/clientmodule_test.py b/tests/data/clientmodule_test.py index 40db2e77ef..82deaaf6f8 100644 --- a/tests/data/clientmodule_test.py +++ b/tests/data/clientmodule_test.py @@ -1,5 +1,5 @@ """ docstring for file clientmodule.py """ -from data.suppliermodule_test import Interface, DoNothing +from data.suppliermodule_test import Interface, DoNothing, DoNothing2 class Ancestor: """ Ancestor method """ @@ -23,7 +23,8 @@ class Specialization(Ancestor): TYPE = 'final class' top = 'class' - def __init__(self, value, _id): + def __init__(self, value, _id, relation2: DoNothing2): Ancestor.__init__(self, value) self._id = _id self.relation = DoNothing() + self.relation2 = relation2 diff --git a/tests/data/suppliermodule_test.py b/tests/data/suppliermodule_test.py index 24dc9a02fe..6af30fa08a 100644 --- a/tests/data/suppliermodule_test.py +++ b/tests/data/suppliermodule_test.py @@ -8,3 +8,5 @@ def set_value(self, value): raise NotImplementedError class DoNothing: pass + +class DoNothing2: pass diff --git a/tests/unittest_pyreverse_diadefs.py b/tests/unittest_pyreverse_diadefs.py index fbcf7fae32..4355b703ac 100644 --- a/tests/unittest_pyreverse_diadefs.py +++ b/tests/unittest_pyreverse_diadefs.py @@ -99,6 +99,7 @@ class TestDefaultDiadefGenerator: _should_rels = [ ("association", "DoNothing", "Ancestor"), ("association", "DoNothing", "Specialization"), + ("association", "DoNothing2", "Specialization"), ("implements", "Ancestor", "Interface"), ("specialization", "Specialization", "Ancestor"), ] @@ -142,6 +143,7 @@ def test_known_values1(HANDLER, PROJECT): assert classes == [ (True, "Ancestor"), (True, "DoNothing"), + (True, "DoNothing2"), (True, "Interface"), (True, "Specialization"), ] @@ -170,6 +172,7 @@ def test_known_values3(HANDLER, PROJECT): (True, "data.clientmodule_test.Ancestor"), (True, special), (True, "data.suppliermodule_test.DoNothing"), + (True, "data.suppliermodule_test.DoNothing2"), ] @@ -184,6 +187,7 @@ def test_known_values4(HANDLER, PROJECT): assert classes == [ (True, "Ancestor"), (True, "DoNothing"), + (True, "DoNothing2"), (True, "Specialization"), ] diff --git a/tests/unittest_pyreverse_inspector.py b/tests/unittest_pyreverse_inspector.py index bc37bea713..a927e18ad9 100644 --- a/tests/unittest_pyreverse_inspector.py +++ b/tests/unittest_pyreverse_inspector.py @@ -62,9 +62,9 @@ def test_instance_attrs_resolution(project): klass = project.get_module("data.clientmodule_test")["Specialization"] assert hasattr(klass, "instance_attrs_type") type_dict = klass.instance_attrs_type - assert len(type_dict) == 2 + assert len(type_dict) == 3 keys = sorted(type_dict.keys()) - assert keys == ["_id", "relation"] + assert keys == ["_id", "relation", "relation2"] assert isinstance(type_dict["relation"][0], astroid.bases.Instance), type_dict[ "relation" ] diff --git a/tests/unittest_pyreverse_writer.py b/tests/unittest_pyreverse_writer.py index 9e6cb755bd..8a33eb7a35 100644 --- a/tests/unittest_pyreverse_writer.py +++ b/tests/unittest_pyreverse_writer.py @@ -204,3 +204,21 @@ def test_infer_node_2(mock_infer, mock_get_annotation): mock_infer.return_value = "x" assert infer_node(node) == set("x") assert mock_infer.called + + +def test_infer_node_3(): + """Return a set containing an astroid.ClassDef object when the attribute + has a type annotation""" + node = astroid.extract_node( + """ + class Component: + pass + + class Composite: + def __init__(self, component: Component): + self.component = component + """ + ) + instance_attr = node.instance_attrs.get("component")[0] + assert isinstance(infer_node(instance_attr), set) + assert isinstance(infer_node(instance_attr).pop(), astroid.ClassDef)