-
Notifications
You must be signed in to change notification settings - Fork 986
/
test_dice_loss.py
110 lines (95 loc) · 3.59 KB
/
test_dice_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from parameterized import parameterized
from monai.losses import DiceLoss
# Sigmoid activation, background channel kept; pred and ground are (1, 1, 2, 2).
TEST_CASE_1 = [
    dict(include_background=True, do_sigmoid=True),
    dict(
        pred=torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]),
        ground=torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]),
        smooth=1e-6,
    ),
    0.307576,
]
# Batch of two samples with sigmoid activation and background kept;
# pred and ground are (2, 1, 2, 2).
TEST_CASE_2 = [
    dict(include_background=True, do_sigmoid=True),
    dict(
        pred=torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]),
        ground=torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),
        smooth=1e-4,
    ),
    0.416657,
]
# Two-channel prediction (2, 2, 3) against a single-channel label map (2, 1, 3).
# With to_onehot_y the labels are presumably expanded to pred's channel count,
# and include_background=False drops channel 0. The prediction's second channel
# equals the labels exactly, so with smooth=0 the expected loss is 0.
TEST_CASE_3 = [
    dict(include_background=False, to_onehot_y=True),
    dict(
        pred=torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),
        ground=torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),
        smooth=0.0,
    ),
    0.0,
]
# Sigmoid activation with one-hot labels and the background channel kept.
# pred is (2, 2, 3) and ground is (2, 1, 3); the second sample's logits are all zero.
TEST_CASE_4 = [
    dict(include_background=True, to_onehot_y=True, do_sigmoid=True),
    dict(
        pred=torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
        ground=torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
        smooth=1e-4,
    ),
    0.435050,
]
# Softmax activation instead of sigmoid, with one-hot labels and background kept.
# pred is (2, 2, 3) and ground is (2, 1, 3).
TEST_CASE_5 = [
    dict(include_background=True, to_onehot_y=True, do_softmax=True),
    dict(
        pred=torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
        ground=torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
        smooth=1e-4,
    ),
    0.383713,
]
# Sigmoid activation with background kept; pred and ground are (1, 1, 2, 2).
# Previously this case was a verbatim copy of TEST_CASE_1 and added no coverage.
# It now uses a ground truth that is the complement of the prediction's
# high-confidence pixels (a low-overlap case).
# Only the two off-diagonal pixels overlap, each with sigmoid(-1) ~= 0.268941.
# Under the 1 - (2*I + smooth) / (sum(G) + sum(P) + smooth) formula that
# TEST_CASE_1's expected value is consistent with, the loss is ~= sigmoid(1).
TEST_CASE_6 = [  # shape: (1, 1, 2, 2), (1, 1, 2, 2)
    {"include_background": True, "do_sigmoid": True},
    {
        "pred": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]),
        "ground": torch.tensor([[[[0.0, 1.0], [1.0, 0.0]]]]),
        "smooth": 1e-6,
    },
    0.731058,
]
# Same tensors as the first sigmoid case, but with squared_pred enabled and
# smooth=1e-5; pred and ground are (1, 1, 2, 2).
TEST_CASE_7 = [
    dict(include_background=True, do_sigmoid=True, squared_pred=True),
    dict(
        pred=torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]),
        ground=torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]),
        smooth=1e-5,
    ),
    0.178337,
]
# Same tensors as the first sigmoid case, with the jaccard option enabled;
# pred and ground are (1, 1, 2, 2).
# NOTE(review): the expected value is negative. A loss defined as 1 - IoU would
# stay within [0, 1], so this pins whatever the installed DiceLoss computes for
# jaccard=True. Confirm that this is the intended behaviour and not a library bug.
TEST_CASE_8 = [
    dict(include_background=True, do_sigmoid=True, jaccard=True),
    dict(
        pred=torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]),
        ground=torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]),
        smooth=1e-5,
    ),
    -0.059094,
]
class TestDiceLoss(unittest.TestCase):
    """Parameterized value checks for ``monai.losses.DiceLoss``."""

    @parameterized.expand(
        [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]
    )
    def test_shape(self, input_param, input_data, expected_val):
        """Check that DiceLoss returns ``expected_val`` for one parameter set.

        Despite its name, this test compares loss *values*, not shapes. The
        name is kept so existing test IDs stay stable.

        Args:
            input_param: keyword arguments for the ``DiceLoss`` constructor.
            input_data: keyword arguments for the loss call (``pred``, ``ground``, ``smooth``).
            expected_val: expected scalar loss, compared to 5 decimal places.
        """
        loss = DiceLoss(**input_param)
        # Call the module instance rather than ``.forward`` directly, so any
        # registered hooks run. This is the usage recommended for torch.nn.Module.
        result = loss(**input_data)
        self.assertAlmostEqual(result.item(), expected_val, places=5)
if __name__ == "__main__":
    # Run this module's test cases when executed as a script. "__main__" is
    # unittest.main's default module and is spelled out here for clarity.
    unittest.main(module="__main__")