Skip to content

Commit

Permalink
fix unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
DesmonDay committed Aug 18, 2022
1 parent e851920 commit e9c50fb
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 24 deletions.
38 changes: 27 additions & 11 deletions python/paddle/fluid/tests/unittests/test_graph_reindex.py
Expand Up @@ -165,11 +165,19 @@ def test_reindex_result_static(self):
])
reindex_src_1, reindex_dst_1, out_nodes_1, reindex_src_2, \
reindex_dst_2, out_nodes_2 = ret
np.testing.assert_allclose(self.reindex_src, reindex_src_1, rtol=1e-05)
np.testing.assert_allclose(self.reindex_dst, reindex_dst_1, rtol=1e-05)
np.testing.assert_allclose(self.reindex_src,
reindex_src_1,
rtol=1e-05)
np.testing.assert_allclose(self.reindex_dst,
reindex_dst_1,
rtol=1e-05)
np.testing.assert_allclose(self.out_nodes, out_nodes_1, rtol=1e-05)
np.testing.assert_allclose(self.reindex_src, reindex_src_2, rtol=1e-05)
np.testing.assert_allclose(self.reindex_dst, reindex_dst_2, rtol=1e-05)
np.testing.assert_allclose(self.reindex_src,
reindex_src_2,
rtol=1e-05)
np.testing.assert_allclose(self.reindex_dst,
reindex_dst_2,
rtol=1e-05)
np.testing.assert_allclose(self.out_nodes, out_nodes_2, rtol=1e-05)


Expand Down Expand Up @@ -206,9 +214,9 @@ def test_reindex_result(self):

reindex_src, reindex_dst, out_nodes = \
paddle.geometric.graph_reindex(x, neighbors, count)
self.assertTrue(np.allclose(self.reindex_src, reindex_src))
self.assertTrue(np.allclose(self.reindex_dst, reindex_dst))
self.assertTrue(np.allclose(self.out_nodes, out_nodes))
np.testing.assert_allclose(self.reindex_src, reindex_src, rtol=1e-05)
np.testing.assert_allclose(self.reindex_dst, reindex_dst, rtol=1e-05)
np.testing.assert_allclose(self.out_nodes, out_nodes, rtol=1e-05)

reindex_src, reindex_dst, out_nodes = \
paddle.geometric.graph_reindex(x, neighbors, count,
Expand Down Expand Up @@ -354,11 +362,19 @@ def test_reindex_result_static(self):
])
reindex_src_1, reindex_dst_1, out_nodes_1, reindex_src_2, \
reindex_dst_2, out_nodes_2 = ret
np.testing.assert_allclose(self.reindex_src, reindex_src_1, rtol=1e-05)
np.testing.assert_allclose(self.reindex_dst, reindex_dst_1, rtol=1e-05)
np.testing.assert_allclose(self.reindex_src,
reindex_src_1,
rtol=1e-05)
np.testing.assert_allclose(self.reindex_dst,
reindex_dst_1,
rtol=1e-05)
np.testing.assert_allclose(self.out_nodes, out_nodes_1, rtol=1e-05)
np.testing.assert_allclose(self.reindex_src, reindex_src_2, rtol=1e-05)
np.testing.assert_allclose(self.reindex_dst, reindex_dst_2, rtol=1e-05)
np.testing.assert_allclose(self.reindex_src,
reindex_src_2,
rtol=1e-05)
np.testing.assert_allclose(self.reindex_dst,
reindex_dst_2,
rtol=1e-05)
np.testing.assert_allclose(self.out_nodes, out_nodes_2, rtol=1e-05)

def test_heter_reindex_result_static(self):
Expand Down
26 changes: 13 additions & 13 deletions python/paddle/fluid/tests/unittests/test_segment_ops.py
Expand Up @@ -308,9 +308,7 @@ def test_static(self):
fetch_list=[res_sum, res_mean, res_max, res_min])

for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret):
self.assertTrue(
np.allclose(np_res, ret_res, atol=1e-6), "two value is\
{}\n{}, check diff!".format(np_res, ret_res))
np.testing.assert_allclose(np_res, ret_res, rtol=1e-05, atol=1e-06)

def test_dygraph(self):
device = paddle.CPUPlace()
Expand All @@ -331,9 +329,10 @@ def test_dygraph(self):
ret = [res_sum, res_mean, res_max, res_min]

for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret):
self.assertTrue(
np.allclose(np_res, ret_res.numpy(), atol=1e-6), "two value is\
{}\n{}, check diff!".format(np_res, ret_res))
np.testing.assert_allclose(np_res,
ret_res.numpy(),
rtol=1e-05,
atol=1e-06)

def test_dygraph_cpu_float16(self):
device = paddle.CPUPlace()
Expand All @@ -353,9 +352,10 @@ def test_dygraph_cpu_float16(self):

ret = [res_sum, res_mean, res_max, res_min]
for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret):
self.assertTrue(
np.allclose(np_res, ret_res.numpy(), atol=1e-6), "two value is\
{}\n{}, check diff!".format(np_res, ret_res))
np.testing.assert_allclose(np_res,
ret_res.numpy(),
rtol=1e-05,
atol=1e-06)

def test_dygraph_cuda_float16(self):
if core.is_compiled_with_cuda():
Expand All @@ -377,10 +377,10 @@ def test_dygraph_cuda_float16(self):
ret = [res_sum, res_mean, res_max, res_min]

for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret):
self.assertTrue(
np.allclose(np_res, ret_res.numpy(), atol=1e-6),
"two value is\
{}\n{}, check diff!".format(np_res, ret_res))
np.testing.assert_allclose(np_res,
ret_res.numpy(),
rtol=1e-05,
atol=1e-06)


if __name__ == '__main__':
Expand Down

0 comments on commit e9c50fb

Please sign in to comment.