Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle has-a relationships for type-hinted arguments in class diagrams #4745

4 changes: 4 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ Release date: TBA
..
Put bug fixes that should not wait for a new minor version here

* pyreverse: Show class has-a relationships inferred from the type-hint

Closes #4744

* Added ``ignored-parents`` option to the design checker to ignore specific
classes from the ``too-many-ancestors`` check (R0901).

Expand Down
2 changes: 2 additions & 0 deletions doc/whatsnew/2.10.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ Extensions
Other Changes
=============

* Pyreverse - Show class has-a relationships inferred from type-hints

* Performance of the Similarity checker has been improved.

* Added ``time.clock`` to deprecated functions/methods for python 3.3
Expand Down
6 changes: 3 additions & 3 deletions pylint/pyreverse/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,9 +269,9 @@ def infer_node(node: Union[astroid.AssignAttr, astroid.AssignName]) -> set:
otherwise return a set of the inferred types using the NodeNG.infer method"""

ann = get_annotation(node)
if ann:
return {ann}
try:
if ann:
return set(ann.infer())
return set(node.infer())
except astroid.InferenceError:
return set()
return {ann} if ann else set()
12 changes: 7 additions & 5 deletions tests/data/classes_No_Name.dot
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ charset="utf-8"
rankdir=BT
"0" [label="{Ancestor|attr : str\lcls_member\l|get_value()\lset_value(value)\l}", shape="record"];
"1" [label="{DoNothing|\l|}", shape="record"];
"2" [label="{Interface|\l|get_value()\lset_value(value)\l}", shape="record"];
"3" [label="{Specialization|TYPE : str\lrelation\ltop : str\l|}", shape="record"];
"3" -> "0" [arrowhead="empty", arrowtail="none"];
"0" -> "2" [arrowhead="empty", arrowtail="node", style="dashed"];
"2" [label="{DoNothing2|\l|}", shape="record"];
"3" [label="{Interface|\l|get_value()\lset_value(value)\l}", shape="record"];
"4" [label="{Specialization|TYPE : str\lrelation\lrelation2\ltop : str\l|}", shape="record"];
"4" -> "0" [arrowhead="empty", arrowtail="none"];
"0" -> "3" [arrowhead="empty", arrowtail="node", style="dashed"];
"1" -> "0" [arrowhead="diamond", arrowtail="none", fontcolor="green", label="cls_member", style="solid"];
"1" -> "3" [arrowhead="diamond", arrowtail="none", fontcolor="green", label="relation", style="solid"];
"1" -> "4" [arrowhead="diamond", arrowtail="none", fontcolor="green", label="relation", style="solid"];
"2" -> "4" [arrowhead="diamond", arrowtail="none", fontcolor="green", label="relation2", style="solid"];
}
5 changes: 3 additions & 2 deletions tests/data/clientmodule_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
""" docstring for file clientmodule.py """
from data.suppliermodule_test import Interface, DoNothing
from data.suppliermodule_test import Interface, DoNothing, DoNothing2

class Ancestor:
""" Ancestor method """
Expand All @@ -23,7 +23,8 @@ class Specialization(Ancestor):
TYPE = 'final class'
top = 'class'

def __init__(self, value, _id):
def __init__(self, value, _id, relation2: DoNothing2):
Ancestor.__init__(self, value)
self._id = _id
self.relation = DoNothing()
self.relation2 = relation2
2 changes: 2 additions & 0 deletions tests/data/suppliermodule_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ def set_value(self, value):
raise NotImplementedError

class DoNothing: pass

class DoNothing2: pass
4 changes: 4 additions & 0 deletions tests/unittest_pyreverse_diadefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class TestDefaultDiadefGenerator:
_should_rels = [
("association", "DoNothing", "Ancestor"),
("association", "DoNothing", "Specialization"),
("association", "DoNothing2", "Specialization"),
("implements", "Ancestor", "Interface"),
("specialization", "Specialization", "Ancestor"),
]
Expand Down Expand Up @@ -142,6 +143,7 @@ def test_known_values1(HANDLER, PROJECT):
assert classes == [
(True, "Ancestor"),
(True, "DoNothing"),
(True, "DoNothing2"),
(True, "Interface"),
(True, "Specialization"),
]
Expand Down Expand Up @@ -170,6 +172,7 @@ def test_known_values3(HANDLER, PROJECT):
(True, "data.clientmodule_test.Ancestor"),
(True, special),
(True, "data.suppliermodule_test.DoNothing"),
(True, "data.suppliermodule_test.DoNothing2"),
]


Expand All @@ -184,6 +187,7 @@ def test_known_values4(HANDLER, PROJECT):
assert classes == [
(True, "Ancestor"),
(True, "DoNothing"),
(True, "DoNothing2"),
(True, "Specialization"),
]

Expand Down
4 changes: 2 additions & 2 deletions tests/unittest_pyreverse_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ def test_instance_attrs_resolution(project):
klass = project.get_module("data.clientmodule_test")["Specialization"]
assert hasattr(klass, "instance_attrs_type")
type_dict = klass.instance_attrs_type
assert len(type_dict) == 2
assert len(type_dict) == 3
keys = sorted(type_dict.keys())
assert keys == ["_id", "relation"]
assert keys == ["_id", "relation", "relation2"]
assert isinstance(type_dict["relation"][0], astroid.bases.Instance), type_dict[
"relation"
]
Expand Down
18 changes: 18 additions & 0 deletions tests/unittest_pyreverse_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,21 @@ def test_infer_node_2(mock_infer, mock_get_annotation):
mock_infer.return_value = "x"
assert infer_node(node) == set("x")
assert mock_infer.called


def test_infer_node_3():
"""Return a set containing an astroid.ClassDef object when the attribute
has a type annotation"""
node = astroid.extract_node(
"""
class Component:
pass

class Composite:
def __init__(self, component: Component):
self.component = component
"""
)
instance_attr = node.instance_attrs.get("component")[0]
assert isinstance(infer_node(instance_attr), set)
assert isinstance(infer_node(instance_attr).pop(), astroid.ClassDef)