From 8f657f742b2627ed8b68598adbd3a96610b91924 Mon Sep 17 00:00:00 2001 From: Siming Dai <908660116@qq.com> Date: Mon, 29 Aug 2022 12:47:03 +0800 Subject: [PATCH] [geometric]Move graph-related incubate api to geometric (#44970) * move incubate to geometric * add paddle.geometric * fix unittest bug * add float16 support for segment op * change reindex and sample neighbors flag name * add heter graph reindex * move sample_neighbors.py to neighbors.py * delete khop_sampler in geometric * delete unused code * change sample_neighbors api input order * fix en doc * fix unittest * fix unittest * change reindex * fix division by 0 * delete unnecessary input argument * delete final_state --- .../kernels/cpu/segment_pool_grad_kernel.cc | 3 +- paddle/phi/kernels/cpu/segment_pool_kernel.cc | 3 +- paddle/phi/kernels/funcs/segment_pooling.cc | 5 + paddle/phi/kernels/funcs/segment_pooling.cu | 7 +- .../kernels/gpu/segment_pool_grad_kernel.cu | 3 +- paddle/phi/kernels/gpu/segment_pool_kernel.cu | 3 +- .../kernels/impl/segment_pool_kernel_impl.h | 2 +- .../tests/unittests/test_graph_reindex.py | 271 +++++++++++++++++ .../unittests/test_graph_sample_neighbors.py | 193 ++++++++++++ .../fluid/tests/unittests/test_segment_ops.py | 112 ++++++- python/paddle/geometric/__init__.py | 14 + python/paddle/geometric/math.py | 264 +++++++++++++++++ .../geometric/message_passing/send_recv.py | 27 +- python/paddle/geometric/reindex.py | 279 ++++++++++++++++++ python/paddle/geometric/sampling/__init__.py | 19 ++ python/paddle/geometric/sampling/neighbors.py | 144 +++++++++ .../incubate/operators/graph_reindex.py | 7 +- .../operators/graph_sample_neighbors.py | 6 + python/paddle/incubate/tensor/math.py | 17 ++ python/setup.py.in | 1 + 20 files changed, 1351 insertions(+), 29 deletions(-) create mode 100644 python/paddle/geometric/math.py create mode 100644 python/paddle/geometric/reindex.py create mode 100644 python/paddle/geometric/sampling/__init__.py create mode 100644 python/paddle/geometric/sampling/neighbors.py diff --git a/paddle/phi/kernels/cpu/segment_pool_grad_kernel.cc b/paddle/phi/kernels/cpu/segment_pool_grad_kernel.cc index 744ec7805fa60..f96916dfec042 100644 --- a/paddle/phi/kernels/cpu/segment_pool_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/segment_pool_grad_kernel.cc @@ -25,4 +25,5 @@ PD_REGISTER_KERNEL(segment_pool_grad, float, double, int, - int64_t) {} + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/cpu/segment_pool_kernel.cc b/paddle/phi/kernels/cpu/segment_pool_kernel.cc index 541ccd3436548..d7d24f10e2bbf 100644 --- a/paddle/phi/kernels/cpu/segment_pool_kernel.cc +++ b/paddle/phi/kernels/cpu/segment_pool_kernel.cc @@ -25,4 +25,5 @@ PD_REGISTER_KERNEL(segment_pool, float, double, int, - int64_t) {} + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/funcs/segment_pooling.cc b/paddle/phi/kernels/funcs/segment_pooling.cc index e6bd371935622..c5d0944d8c190 100644 --- a/paddle/phi/kernels/funcs/segment_pooling.cc +++ b/paddle/phi/kernels/funcs/segment_pooling.cc @@ -145,6 +145,7 @@ class SegmentPoolGradFunctor { }; using CPU = phi::CPUContext; +using float16 = phi::dtype::float16; template class SegmentPoolFunctor; template class SegmentPoolFunctor; template class SegmentPoolFunctor; @@ -153,6 +154,8 @@ template class SegmentPoolFunctor; template class SegmentPoolFunctor; template class SegmentPoolFunctor; template class SegmentPoolFunctor; +template class SegmentPoolFunctor; +template class SegmentPoolFunctor; template class SegmentPoolGradFunctor; template class SegmentPoolGradFunctor; @@ -162,6 +165,8 @@ template class SegmentPoolGradFunctor; template class SegmentPoolGradFunctor; template class SegmentPoolGradFunctor; template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/segment_pooling.cu b/paddle/phi/kernels/funcs/segment_pooling.cu index 1012ca413ed60..0b608367bbe74 100644 --- a/paddle/phi/kernels/funcs/segment_pooling.cu +++ b/paddle/phi/kernels/funcs/segment_pooling.cu @@ -324,7 +324,7 @@ class SegmentPoolFunctor { const std::string pooltype = "SUM") { if (pooltype == "MEAN") { // Sum the segment id num first - T DimTileSize = 8; + IndexT DimTileSize = 8; auto input_length_size = segment_ids.numel(); auto total_stripe_count = (input_length_size + DimTileSize - 1) / DimTileSize; @@ -440,6 +440,7 @@ class SegmentPoolGradFunctor { }; using GPU = phi::GPUContext; +using float16 = phi::dtype::float16; template class SegmentPoolFunctor; template class SegmentPoolFunctor; template class SegmentPoolFunctor; @@ -448,6 +449,8 @@ template class SegmentPoolFunctor; template class SegmentPoolFunctor; template class SegmentPoolFunctor; template class SegmentPoolFunctor; +template class SegmentPoolFunctor; +template class SegmentPoolFunctor; template class SegmentPoolGradFunctor; template class SegmentPoolGradFunctor; @@ -457,6 +460,8 @@ template class SegmentPoolGradFunctor; template class SegmentPoolGradFunctor; template class SegmentPoolGradFunctor; template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/gpu/segment_pool_grad_kernel.cu b/paddle/phi/kernels/gpu/segment_pool_grad_kernel.cu index 3cfbcc83cb13e..86800ed840528 100644 --- a/paddle/phi/kernels/gpu/segment_pool_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/segment_pool_grad_kernel.cu @@ -26,4 +26,5 @@ PD_REGISTER_KERNEL(segment_pool_grad, float, double, int, - int64_t) {} + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/segment_pool_kernel.cu b/paddle/phi/kernels/gpu/segment_pool_kernel.cu index 55ab822f5f54e..9da2ef519a42b 100644 --- a/paddle/phi/kernels/gpu/segment_pool_kernel.cu +++ b/paddle/phi/kernels/gpu/segment_pool_kernel.cu @@ -26,4 +26,5 @@ PD_REGISTER_KERNEL(segment_pool, float, double, int, - int64_t) {} + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/impl/segment_pool_kernel_impl.h b/paddle/phi/kernels/impl/segment_pool_kernel_impl.h index c1671a1b37adf..48111913c7822 100644 --- a/paddle/phi/kernels/impl/segment_pool_kernel_impl.h +++ b/paddle/phi/kernels/impl/segment_pool_kernel_impl.h @@ -97,7 +97,7 @@ void SegmentKernelLaunchHelper(const Context& dev_ctx, out->Resize({dims}); dev_ctx.template Alloc(out); - T init_value = 0; + T init_value = static_cast(0); if (pooltype == "MAX") { init_value = static_cast(-FLT_MAX); } else if (pooltype == "MIN") { diff --git a/python/paddle/fluid/tests/unittests/test_graph_reindex.py b/python/paddle/fluid/tests/unittests/test_graph_reindex.py index 11078fae5b65d..668d1525c81a7 100644 --- a/python/paddle/fluid/tests/unittests/test_graph_reindex.py +++ b/python/paddle/fluid/tests/unittests/test_graph_reindex.py @@ -181,5 +181,276 @@ def test_reindex_result_static(self): np.testing.assert_allclose(self.out_nodes, out_nodes_2, rtol=1e-05) +class TestGeometricGraphReindex(unittest.TestCase): + + def setUp(self): + self.x = np.arange(5).astype("int64") + self.neighbors = np.random.randint(100, size=20).astype("int64") + self.count = np.array([2, 8, 4, 3, 3], dtype="int32") + + # Get numpy result. + out_nodes = list(self.x) + for neighbor in self.neighbors: + if neighbor not in out_nodes: + out_nodes.append(neighbor) + self.out_nodes = np.array(out_nodes, dtype="int64") + reindex_dict = {node: ind for ind, node in enumerate(self.out_nodes)} + self.reindex_src = np.array( + [reindex_dict[node] for node in self.neighbors]) + reindex_dst = [] + for node, c in zip(self.x, self.count): + for i in range(c): + reindex_dst.append(reindex_dict[node]) + self.reindex_dst = np.array(reindex_dst, dtype="int64") + self.num_nodes = np.max(np.concatenate([self.x, self.neighbors])) + 1 + + def test_reindex_result(self): + paddle.disable_static() + x = paddle.to_tensor(self.x) + neighbors = paddle.to_tensor(self.neighbors) + count = paddle.to_tensor(self.count) + value_buffer = paddle.full([self.num_nodes], -1, dtype="int32") + index_buffer = paddle.full([self.num_nodes], -1, dtype="int32") + + reindex_src, reindex_dst, out_nodes = \ + paddle.geometric.reindex_graph(x, neighbors, count) + np.testing.assert_allclose(self.reindex_src, reindex_src, rtol=1e-05) + np.testing.assert_allclose(self.reindex_dst, reindex_dst, rtol=1e-05) + np.testing.assert_allclose(self.out_nodes, out_nodes, rtol=1e-05) + + reindex_src, reindex_dst, out_nodes = \ + paddle.geometric.reindex_graph(x, neighbors, count, + value_buffer, index_buffer) + np.testing.assert_allclose(self.reindex_src, reindex_src, rtol=1e-05) + np.testing.assert_allclose(self.reindex_dst, reindex_dst, rtol=1e-05) + np.testing.assert_allclose(self.out_nodes, out_nodes, rtol=1e-05) + + def test_heter_reindex_result(self): + paddle.disable_static() + x = paddle.to_tensor(self.x) + neighbors = paddle.to_tensor(self.neighbors) + neighbors = paddle.concat([neighbors, neighbors]) + count = paddle.to_tensor(self.count) + count = paddle.concat([count, count]) + + reindex_src, reindex_dst, out_nodes = \ + paddle.geometric.reindex_graph(x, neighbors, count) + np.testing.assert_allclose(self.reindex_src, + reindex_src[:self.neighbors.shape[0]], + rtol=1e-05) + np.testing.assert_allclose(self.reindex_src, + reindex_src[self.neighbors.shape[0]:], + rtol=1e-05) + np.testing.assert_allclose(self.reindex_dst, + reindex_dst[:self.neighbors.shape[0]], + rtol=1e-05) + np.testing.assert_allclose(self.reindex_dst, + reindex_dst[self.neighbors.shape[0]:], + rtol=1e-05) + np.testing.assert_allclose(self.out_nodes, out_nodes, rtol=1e-05) + + def test_heter_reindex_result_v2(self): + paddle.disable_static() + x = np.arange(5).astype("int64") + neighbors1 = np.random.randint(100, size=20).astype("int64") + count1 = np.array([2, 8, 4, 3, 3], dtype="int32") + neighbors2 = np.random.randint(100, size=20).astype("int64") + count2 = np.array([4, 5, 1, 6, 4], dtype="int32") + neighbors = np.concatenate([neighbors1, neighbors2]) + counts = np.concatenate([count1, count2]) + + # Get numpy result. + out_nodes = list(x) + for neighbor in neighbors: + if neighbor not in out_nodes: + out_nodes.append(neighbor) + out_nodes = np.array(out_nodes, dtype="int64") + reindex_dict = {node: ind for ind, node in enumerate(out_nodes)} + reindex_src = np.array([reindex_dict[node] for node in neighbors]) + reindex_dst = [] + for count in [count1, count2]: + for node, c in zip(x, count): + for i in range(c): + reindex_dst.append(reindex_dict[node]) + reindex_dst = np.array(reindex_dst, dtype="int64") + + reindex_src_, reindex_dst_, out_nodes_ = \ + paddle.geometric.reindex_graph(paddle.to_tensor(x), + paddle.to_tensor(neighbors), + paddle.to_tensor(counts)) + np.testing.assert_allclose(reindex_src, reindex_src_, rtol=1e-05) + np.testing.assert_allclose(reindex_dst, reindex_dst_, rtol=1e-05) + np.testing.assert_allclose(out_nodes, out_nodes_, rtol=1e-05) + + def test_heter_reindex_result_v3(self): + paddle.disable_static() + x = np.arange(5).astype("int64") + neighbors1 = np.random.randint(100, size=20).astype("int64") + count1 = np.array([2, 8, 4, 3, 3], dtype="int32") + neighbors2 = np.random.randint(100, size=20).astype("int64") + count2 = np.array([4, 5, 1, 6, 4], dtype="int32") + neighbors = np.concatenate([neighbors1, neighbors2]) + count = np.concatenate([count1, count2]) + + # Get numpy result. + out_nodes = list(x) + for neighbor in neighbors: + if neighbor not in out_nodes: + out_nodes.append(neighbor) + out_nodes = np.array(out_nodes, dtype="int64") + reindex_dict = {node: ind for ind, node in enumerate(out_nodes)} + reindex_src = np.array([reindex_dict[node] for node in neighbors]) + reindex_dst = [] + for count in [count1, count2]: + for node, c in zip(x, count): + for i in range(c): + reindex_dst.append(reindex_dict[node]) + reindex_dst = np.array(reindex_dst, dtype="int64") + + neighbors = [paddle.to_tensor(neighbors1), paddle.to_tensor(neighbors2)] + count = [paddle.to_tensor(count1), paddle.to_tensor(count2)] + reindex_src_, reindex_dst_, out_nodes_ = \ + paddle.geometric.reindex_heter_graph(paddle.to_tensor(x), + neighbors, count) + np.testing.assert_allclose(reindex_src, reindex_src_, rtol=1e-05) + np.testing.assert_allclose(reindex_dst, reindex_dst_, rtol=1e-05) + np.testing.assert_allclose(out_nodes, out_nodes_, rtol=1e-05) + + def test_reindex_result_static(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(name="x", + shape=self.x.shape, + dtype=self.x.dtype) + neighbors = paddle.static.data(name="neighbors", + shape=self.neighbors.shape, + dtype=self.neighbors.dtype) + count = paddle.static.data(name="count", + shape=self.count.shape, + dtype=self.count.dtype) + value_buffer = paddle.static.data(name="value_buffer", + shape=[self.num_nodes], + dtype="int32") + index_buffer = paddle.static.data(name="index_buffer", + shape=[self.num_nodes], + dtype="int32") + + reindex_src_1, reindex_dst_1, out_nodes_1 = \ + paddle.geometric.reindex_graph(x, neighbors, count) + reindex_src_2, reindex_dst_2, out_nodes_2 = \ + paddle.geometric.reindex_graph(x, neighbors, count, + value_buffer, index_buffer) + + exe = paddle.static.Executor(paddle.CPUPlace()) + ret = exe.run(feed={ + 'x': + self.x, + 'neighbors': + self.neighbors, + 'count': + self.count, + 'value_buffer': + np.full([self.num_nodes], -1, dtype="int32"), + 'index_buffer': + np.full([self.num_nodes], -1, dtype="int32") + }, + fetch_list=[ + reindex_src_1, reindex_dst_1, out_nodes_1, + reindex_src_2, reindex_dst_2, out_nodes_2 + ]) + reindex_src_1, reindex_dst_1, out_nodes_1, reindex_src_2, \ + reindex_dst_2, out_nodes_2 = ret + np.testing.assert_allclose(self.reindex_src, + reindex_src_1, + rtol=1e-05) + np.testing.assert_allclose(self.reindex_dst, + reindex_dst_1, + rtol=1e-05) + np.testing.assert_allclose(self.out_nodes, out_nodes_1, rtol=1e-05) + np.testing.assert_allclose(self.reindex_src, + reindex_src_2, + rtol=1e-05) + np.testing.assert_allclose(self.reindex_dst, + reindex_dst_2, + rtol=1e-05) + np.testing.assert_allclose(self.out_nodes, out_nodes_2, rtol=1e-05) + + def test_heter_reindex_result_static(self): + paddle.enable_static() + np_x = np.arange(5).astype("int64") + np_neighbors1 = np.random.randint(100, size=20).astype("int64") + np_count1 = np.array([2, 8, 4, 3, 3], dtype="int32") + np_neighbors2 = np.random.randint(100, size=20).astype("int64") + np_count2 = np.array([4, 5, 1, 6, 4], dtype="int32") + np_neighbors = np.concatenate([np_neighbors1, np_neighbors2]) + np_count = np.concatenate([np_count1, np_count2]) + + # Get numpy result. + out_nodes = list(np_x) + for neighbor in np_neighbors: + if neighbor not in out_nodes: + out_nodes.append(neighbor) + out_nodes = np.array(out_nodes, dtype="int64") + reindex_dict = {node: ind for ind, node in enumerate(out_nodes)} + reindex_src = np.array([reindex_dict[node] for node in np_neighbors]) + reindex_dst = [] + for count in [np_count1, np_count2]: + for node, c in zip(np_x, count): + for i in range(c): + reindex_dst.append(reindex_dict[node]) + reindex_dst = np.array(reindex_dst, dtype="int64") + + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(name="x", shape=[5], dtype="int64") + neighbors1 = paddle.static.data(name="neighbors1", + shape=[20], + dtype="int64") + count1 = paddle.static.data(name="count1", shape=[5], dtype="int32") + neighbors2 = paddle.static.data(name="neighbors2", + shape=[20], + dtype="int64") + count2 = paddle.static.data(name="count2", shape=[5], dtype="int32") + value_buffer = paddle.static.data(name="value_buffer", + shape=[5], + dtype="int32") + index_buffer = paddle.static.data(name="index_buffer", + shape=[5], + dtype="int32") + + reindex_src_1, reindex_dst_1, out_nodes_1 = \ + paddle.geometric.reindex_heter_graph(x, + [neighbors1, neighbors2], + [count1, count2]) + reindex_src_2, reindex_dst_2, out_nodes_2 = \ + paddle.geometric.reindex_heter_graph(x, + [neighbors1, neighbors2], + [count1, count2], + value_buffer, index_buffer) + + exe = paddle.static.Executor(paddle.CPUPlace()) + ret = exe.run(feed={ + 'x': np_x, + 'neighbors1': np_neighbors1, + 'count1': np_count1, + 'neighbors2': np_neighbors2, + 'count2': np_count2, + 'value_buffer': np.full([5], -1, dtype="int32"), + 'index_buffer': np.full([5], -1, dtype="int32") + }, + fetch_list=[ + reindex_src_1, reindex_dst_1, out_nodes_1, + reindex_src_2, reindex_dst_2, out_nodes_2 + ]) + + reindex_src_1, reindex_dst_1, out_nodes_1, reindex_src_2, \ + reindex_dst_2, out_nodes_2 = ret + np.testing.assert_allclose(reindex_src, reindex_src_1, rtol=1e-05) + np.testing.assert_allclose(reindex_dst, reindex_dst_1, rtol=1e-05) + np.testing.assert_allclose(out_nodes, out_nodes_1, rtol=1e-05) + np.testing.assert_allclose(reindex_src, reindex_src_2, rtol=1e-05) + np.testing.assert_allclose(reindex_dst, reindex_dst_2, rtol=1e-05) + np.testing.assert_allclose(out_nodes, out_nodes_2, rtol=1e-05) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_graph_sample_neighbors.py b/python/paddle/fluid/tests/unittests/test_graph_sample_neighbors.py index f84513506b324..557d41993b878 100644 --- a/python/paddle/fluid/tests/unittests/test_graph_sample_neighbors.py +++ b/python/paddle/fluid/tests/unittests/test_graph_sample_neighbors.py @@ -221,5 +221,198 @@ def test_sample_result_with_eids(self): fetch_list=[out_neighbors, out_count, out_eids]) +class TestGeometricGraphSampleNeighbors(unittest.TestCase): + + def setUp(self): + num_nodes = 20 + edges = np.random.randint(num_nodes, size=(100, 2)) + edges = np.unique(edges, axis=0) + self.edges_id = np.arange(0, len(edges)).astype("int64") + sorted_edges = edges[np.argsort(edges[:, 1])] + + # Calculate dst index cumsum counts, also means colptr + dst_count = np.zeros(num_nodes) + dst_src_dict = {} + for dst in range(0, num_nodes): + true_index = sorted_edges[:, 1] == dst + dst_count[dst] = np.sum(true_index) + dst_src_dict[dst] = sorted_edges[:, 0][true_index] + dst_count = dst_count.astype("int64") + colptr = np.cumsum(dst_count) + colptr = np.insert(colptr, 0, 0) + + self.row = sorted_edges[:, 0].astype("int64") + self.colptr = colptr.astype("int64") + self.nodes = np.unique(np.random.randint(num_nodes, + size=5)).astype("int64") + self.sample_size = 5 + self.dst_src_dict = dst_src_dict + + def test_sample_result(self): + paddle.disable_static() + row = paddle.to_tensor(self.row) + colptr = paddle.to_tensor(self.colptr) + nodes = paddle.to_tensor(self.nodes) + + out_neighbors, out_count = paddle.geometric.sample_neighbors( + row, colptr, nodes, sample_size=self.sample_size) + out_count_cumsum = paddle.cumsum(out_count) + for i in range(len(out_count)): + if i == 0: + neighbors = out_neighbors[0:out_count_cumsum[i]] + else: + neighbors = out_neighbors[ + out_count_cumsum[i - 1]:out_count_cumsum[i]] + # Ensure the correct sample size. + self.assertTrue( + out_count[i] == self.sample_size + or out_count[i] == len(self.dst_src_dict[self.nodes[i]])) + # Ensure no repetitive sample neighbors. + self.assertTrue( + neighbors.shape[0] == paddle.unique(neighbors).shape[0]) + # Ensure the correct sample neighbors. + in_neighbors = np.isin(neighbors.numpy(), + self.dst_src_dict[self.nodes[i]]) + self.assertTrue(np.sum(in_neighbors) == in_neighbors.shape[0]) + + def test_sample_result_fisher_yates_sampling(self): + paddle.disable_static() + if fluid.core.is_compiled_with_cuda(): + row = paddle.to_tensor(self.row) + colptr = paddle.to_tensor(self.colptr) + nodes = paddle.to_tensor(self.nodes) + perm_buffer = paddle.to_tensor(self.edges_id) + + out_neighbors, out_count = paddle.geometric.sample_neighbors( + row, + colptr, + nodes, + perm_buffer=perm_buffer, + sample_size=self.sample_size) + out_count_cumsum = paddle.cumsum(out_count) + for i in range(len(out_count)): + if i == 0: + neighbors = out_neighbors[0:out_count_cumsum[i]] + else: + neighbors = out_neighbors[ + out_count_cumsum[i - 1]:out_count_cumsum[i]] + # Ensure the correct sample size. + self.assertTrue( + out_count[i] == self.sample_size + or out_count[i] == len(self.dst_src_dict[self.nodes[i]])) + # Ensure no repetitive sample neighbors. + self.assertTrue( + neighbors.shape[0] == paddle.unique(neighbors).shape[0]) + # Ensure the correct sample neighbors. + in_neighbors = np.isin(neighbors.numpy(), + self.dst_src_dict[self.nodes[i]]) + self.assertTrue(np.sum(in_neighbors) == in_neighbors.shape[0]) + + def test_sample_result_static(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + row = paddle.static.data(name="row", + shape=self.row.shape, + dtype=self.row.dtype) + colptr = paddle.static.data(name="colptr", + shape=self.colptr.shape, + dtype=self.colptr.dtype) + nodes = paddle.static.data(name="nodes", + shape=self.nodes.shape, + dtype=self.nodes.dtype) + + out_neighbors, out_count = paddle.geometric.sample_neighbors( + row, colptr, nodes, sample_size=self.sample_size) + exe = paddle.static.Executor(paddle.CPUPlace()) + ret = exe.run(feed={ + 'row': self.row, + 'colptr': self.colptr, + 'nodes': self.nodes + }, + fetch_list=[out_neighbors, out_count]) + out_neighbors, out_count = ret + out_count_cumsum = np.cumsum(out_count) + out_neighbors = np.split(out_neighbors, out_count_cumsum)[:-1] + for neighbors, node, count in zip(out_neighbors, self.nodes, + out_count): + self.assertTrue(count == self.sample_size + or count == len(self.dst_src_dict[node])) + self.assertTrue( + neighbors.shape[0] == np.unique(neighbors).shape[0]) + in_neighbors = np.isin(neighbors, self.dst_src_dict[node]) + self.assertTrue(np.sum(in_neighbors) == in_neighbors.shape[0]) + + def test_raise_errors(self): + paddle.disable_static() + row = paddle.to_tensor(self.row) + colptr = paddle.to_tensor(self.colptr) + nodes = paddle.to_tensor(self.nodes) + + def check_eid_error(): + paddle.geometric.sample_neighbors(row, + colptr, + nodes, + sample_size=self.sample_size, + return_eids=True) + + self.assertRaises(ValueError, check_eid_error) + + def test_sample_result_with_eids(self): + paddle.disable_static() + row = paddle.to_tensor(self.row) + colptr = paddle.to_tensor(self.colptr) + nodes = paddle.to_tensor(self.nodes) + eids = paddle.to_tensor(self.edges_id) + perm_buffer = paddle.to_tensor(self.edges_id) + + out_neighbors, out_count, out_eids = paddle.geometric.sample_neighbors( + row, + colptr, + nodes, + eids=eids, + sample_size=self.sample_size, + return_eids=True) + + out_neighbors, out_count, out_eids = paddle.geometric.sample_neighbors( + row, + colptr, + nodes, + eids=eids, + perm_buffer=perm_buffer, + sample_size=self.sample_size, + return_eids=True) + + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + row = paddle.static.data(name="row", + shape=self.row.shape, + dtype=self.row.dtype) + colptr = paddle.static.data(name="colptr", + shape=self.colptr.shape, + dtype=self.colptr.dtype) + nodes = paddle.static.data(name="nodes", + shape=self.nodes.shape, + dtype=self.nodes.dtype) + eids = paddle.static.data(name="eids", + shape=self.edges_id.shape, + dtype=self.nodes.dtype) + + out_neighbors, out_count, out_eids = paddle.geometric.sample_neighbors( + row, + colptr, + nodes, + sample_size=self.sample_size, + eids=eids, + return_eids=True) + exe = paddle.static.Executor(paddle.CPUPlace()) + ret = exe.run(feed={ + 'row': self.row, + 'colptr': self.colptr, + 'nodes': self.nodes, + 'eids': self.edges_id + }, + fetch_list=[out_neighbors, out_count, out_eids]) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_segment_ops.py b/python/paddle/fluid/tests/unittests/test_segment_ops.py index d1a9aa6ae2ff1..69de02b24ca91 100644 --- a/python/paddle/fluid/tests/unittests/test_segment_ops.py +++ b/python/paddle/fluid/tests/unittests/test_segment_ops.py @@ -19,6 +19,7 @@ import numpy as np import paddle +import paddle.fluid.core as core from op_test import OpTest @@ -75,13 +76,13 @@ def compute_segment_min_max(x, segment_ids, pooltype="MAX"): def segment_pool_split(X, SegmentIds, pooltype): if pooltype == "SUM": - return paddle.incubate.tensor.segment_sum(X, SegmentIds) + return paddle.geometric.segment_sum(X, SegmentIds) elif pooltype == "MEAN": - return paddle.incubate.tensor.segment_mean(X, SegmentIds) + return paddle.geometric.segment_mean(X, SegmentIds) elif pooltype == "MIN": - return paddle.incubate.tensor.segment_min(X, SegmentIds) + return paddle.geometric.segment_min(X, SegmentIds) elif pooltype == "MAX": - return paddle.incubate.tensor.segment_max(X, SegmentIds) + return paddle.geometric.segment_max(X, SegmentIds) class TestSegmentOps(OpTest): @@ -279,6 +280,109 @@ def test_dygraph(self): atol=1e-06) +class API_GeometricSegmentOpsTest(unittest.TestCase): + + def test_static(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(name="x", shape=[3, 3], dtype="float32") + y = paddle.static.data(name='y', shape=[3], dtype='int32') + + res_sum = paddle.geometric.segment_sum(x, y) + res_mean = paddle.geometric.segment_mean(x, y) + res_max = paddle.geometric.segment_max(x, y) + res_min = paddle.geometric.segment_min(x, y) + + exe = paddle.static.Executor(paddle.CPUPlace()) + data1 = np.array([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32') + data2 = np.array([0, 0, 1], dtype="int32") + + np_sum = np.array([[4, 4, 4], [4, 5, 6]], dtype="float32") + np_mean = np.array([[2, 2, 2], [4, 5, 6]], dtype="float32") + np_max = np.array([[3, 2, 3], [4, 5, 6]], dtype="float32") + np_min = np.array([[1, 2, 1], [4, 5, 6]], dtype="float32") + + ret = exe.run(feed={ + 'x': data1, + 'y': data2 + }, + fetch_list=[res_sum, res_mean, res_max, res_min]) + + for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret): + np.testing.assert_allclose(np_res, ret_res, rtol=1e-05, atol=1e-06) + + def test_dygraph(self): + device = paddle.CPUPlace() + with paddle.fluid.dygraph.guard(device): + x = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], + dtype='float32') + y = paddle.to_tensor([0, 0, 1], dtype="int32") + res_sum = paddle.geometric.segment_sum(x, y) + res_mean = paddle.geometric.segment_mean(x, y) + res_max = paddle.geometric.segment_max(x, y) + res_min = paddle.geometric.segment_min(x, y) + + np_sum = np.array([[4, 4, 4], [4, 5, 6]], dtype="float32") + np_mean = np.array([[2, 2, 2], [4, 5, 6]], dtype="float32") + np_max = np.array([[3, 2, 3], [4, 5, 6]], dtype="float32") + np_min = np.array([[1, 2, 1], [4, 5, 6]], dtype="float32") + + ret = [res_sum, res_mean, res_max, res_min] + + for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret): + np.testing.assert_allclose(np_res, + ret_res.numpy(), + rtol=1e-05, + atol=1e-06) + + def test_dygraph_cpu_float16(self): + device = paddle.CPUPlace() + with paddle.fluid.dygraph.guard(device): + x = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], + dtype='float16') + y = paddle.to_tensor([0, 0, 1], dtype="int32") + res_sum = paddle.geometric.segment_sum(x, y) + res_mean = paddle.geometric.segment_mean(x, y) + res_max = paddle.geometric.segment_max(x, y) + res_min = paddle.geometric.segment_min(x, y) + + np_sum = np.array([[4, 4, 4], [4, 5, 6]], dtype="float16") + np_mean = np.array([[2, 2, 2], [4, 5, 6]], dtype="float16") + np_max = np.array([[3, 2, 3], [4, 5, 6]], dtype="float16") + np_min = np.array([[1, 2, 1], [4, 5, 6]], dtype="float16") + + ret = [res_sum, res_mean, res_max, res_min] + for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret): + np.testing.assert_allclose(np_res, + ret_res.numpy(), + rtol=1e-05, + atol=1e-06) + + def test_dygraph_cuda_float16(self): + if core.is_compiled_with_cuda(): + device = paddle.CUDAPlace(0) + with paddle.fluid.dygraph.guard(device): + x = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], + dtype='float16') + y = paddle.to_tensor([0, 0, 1], dtype="int32") + res_sum = paddle.geometric.segment_sum(x, y) + res_mean = paddle.geometric.segment_mean(x, y) + res_max = paddle.geometric.segment_max(x, y) + res_min = paddle.geometric.segment_min(x, y) + + np_sum = np.array([[4, 4, 4], [4, 5, 6]], dtype="float16") + np_mean = np.array([[2, 2, 2], [4, 5, 6]], dtype="float16") + np_max = np.array([[3, 2, 3], [4, 5, 6]], dtype="float16") + np_min = np.array([[1, 2, 1], [4, 5, 6]], dtype="float16") + + ret = [res_sum, res_mean, res_max, res_min] + + for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret): + np.testing.assert_allclose(np_res, + ret_res.numpy(), + rtol=1e-05, + atol=1e-06) + + if __name__ == '__main__': paddle.enable_static() unittest.main() diff --git a/python/paddle/geometric/__init__.py b/python/paddle/geometric/__init__.py index 2119a257c8ab0..9618bc57a203e 100644 --- a/python/paddle/geometric/__init__.py +++ b/python/paddle/geometric/__init__.py @@ -15,9 +15,23 @@ from .message_passing import send_u_recv # noqa: F401 from .message_passing import send_ue_recv # noqa: F401 from .message_passing import send_uv # noqa: F401 +from .math import segment_sum # noqa: F401 +from .math import segment_mean # noqa: F401 +from .math import segment_min # noqa: F401 +from .math import segment_max # noqa: F401 +from .reindex import reindex_graph # noqa: F401 +from .reindex import reindex_heter_graph # noqa: F401 +from .sampling import sample_neighbors # noqa: F401 __all__ = [ 'send_u_recv', 'send_ue_recv', 'send_uv', + 'segment_sum', + 'segment_mean', + 'segment_min', + 'segment_max', + 'reindex_graph', + 'reindex_heter_graph', + 'sample_neighbors', ] diff --git a/python/paddle/geometric/math.py b/python/paddle/geometric/math.py new file mode 100644 index 0000000000000..7a6db7d10aa99 --- /dev/null +++ b/python/paddle/geometric/math.py @@ -0,0 +1,264 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.fluid.layer_helper import LayerHelper, _non_static_mode +from paddle.fluid.data_feeder import check_variable_and_dtype +from paddle import _C_ops, _legacy_C_ops +from paddle.fluid.framework import _in_legacy_dygraph, in_dygraph_mode + +__all__ = [] + + +def segment_sum(data, segment_ids, name=None): + r""" + Segment Sum Operator. + + This operator sums the elements of input `data` which with + the same index in `segment_ids`. + It computes a tensor such that $out_i = \\sum_{j} data_{j}$ + where sum is over j such that `segment_ids[j] == i`. + + Args: + data (Tensor): A tensor, available data type float32, float64, int32, int64, float16. + segment_ids (Tensor): A 1-D tensor, which have the same size + with the first dimension of input data. + Available data type is int32, int64. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + output (Tensor): the reduced result. + + Examples: + + .. code-block:: python + + import paddle + data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32') + segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32') + out = paddle.geometric.segment_sum(data, segment_ids) + #Outputs: [[4., 4., 4.], [4., 5., 6.]] + + """ + if in_dygraph_mode(): + return _C_ops.segment_pool(data, segment_ids, "SUM")[0] + if _in_legacy_dygraph(): + out, tmp = _legacy_C_ops.segment_pool(data, segment_ids, 'pooltype', + "SUM") + return out + + check_variable_and_dtype( + data, "X", ("float32", "float64", "int32", "int64", "float16"), + "segment_pool") + check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), + "segment_pool") + + helper = LayerHelper("segment_sum", **locals()) + out = helper.create_variable_for_type_inference(dtype=data.dtype) + summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype) + helper.append_op(type="segment_pool", + inputs={ + "X": data, + "SegmentIds": segment_ids + }, + outputs={ + "Out": out, + "SummedIds": summed_ids + }, + attrs={"pooltype": "SUM"}) + return out + + +def segment_mean(data, segment_ids, name=None): + r""" + Segment mean Operator. + + Ihis operator calculate the mean value of input `data` which + with the same index in `segment_ids`. + It computes a tensor such that $out_i = \\frac{1}{n_i} \\sum_{j} data[j]$ + where sum is over j such that 'segment_ids[j] == i' and $n_i$ is the number + of all index 'segment_ids[j] == i'. + + Args: + data (tensor): a tensor, available data type float32, float64, int32, int64, float16. + segment_ids (tensor): a 1-d tensor, which have the same size + with the first dimension of input data. + available data type is int32, int64. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + output (Tensor): the reduced result. + + Examples: + + .. code-block:: python + + import paddle + data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32') + segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32') + out = paddle.geometric.segment_mean(data, segment_ids) + #Outputs: [[2., 2., 2.], [4., 5., 6.]] + + """ + + if in_dygraph_mode(): + return _C_ops.segment_pool(data, segment_ids, "MEAN")[0] + if _in_legacy_dygraph(): + out, tmp = _legacy_C_ops.segment_pool(data, segment_ids, 'pooltype', + "MEAN") + return out + + check_variable_and_dtype( + data, "X", ("float32", "float64", "int32", "int64", "float16"), + "segment_pool") + check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), + "segment_pool") + + helper = LayerHelper("segment_mean", **locals()) + out = helper.create_variable_for_type_inference(dtype=data.dtype) + summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype) + helper.append_op(type="segment_pool", + inputs={ + "X": data, + "SegmentIds": segment_ids + }, + outputs={ + "Out": out, + "SummedIds": summed_ids + }, + attrs={"pooltype": "MEAN"}) + return out + + +def segment_min(data, segment_ids, name=None): + r""" + Segment min operator. + + This operator calculate the minimum elements of input `data` which with + the same index in `segment_ids`. + It computes a tensor such that $out_i = \\min_{j} data_{j}$ + where min is over j such that `segment_ids[j] == i`. + + Args: + data (tensor): a tensor, available data type float32, float64, int32, int64, float16. + segment_ids (tensor): a 1-d tensor, which have the same size + with the first dimension of input data. + available data type is int32, int64. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + output (Tensor): the reduced result. + + Examples: + + .. code-block:: python + + import paddle + data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32') + segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32') + out = paddle.geometric.segment_min(data, segment_ids) + #Outputs: [[1., 2., 1.], [4., 5., 6.]] + + """ + + if in_dygraph_mode(): + return _C_ops.segment_pool(data, segment_ids, "MIN")[0] + if _in_legacy_dygraph(): + out, tmp = _legacy_C_ops.segment_pool(data, segment_ids, 'pooltype', + "MIN") + return out + + check_variable_and_dtype( + data, "X", ("float32", "float64", "int32", "int64", "float16"), + "segment_pool") + check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), + "segment_pool") + + helper = LayerHelper("segment_min", **locals()) + out = helper.create_variable_for_type_inference(dtype=data.dtype) + summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype) + helper.append_op(type="segment_pool", + inputs={ + "X": data, + "SegmentIds": segment_ids + }, + outputs={ + "Out": out, + "SummedIds": summed_ids + }, + attrs={"pooltype": "MIN"}) + return out + + +def segment_max(data, segment_ids, name=None): + r""" + Segment max operator. + + This operator calculate the maximum elements of input `data` which with + the same index in `segment_ids`. + It computes a tensor such that $out_i = \\max_{j} data_{j}$ + where max is over j such that `segment_ids[j] == i`. + + Args: + data (tensor): a tensor, available data type float32, float64, int32, int64, float16. + segment_ids (tensor): a 1-d tensor, which have the same size + with the first dimension of input data. + available data type is int32, int64. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + output (Tensor): the reduced result. + + Examples: + + .. code-block:: python + + import paddle + data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32') + segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32') + out = paddle.geometric.segment_max(data, segment_ids) + #Outputs: [[3., 2., 3.], [4., 5., 6.]] + + """ + + if in_dygraph_mode(): + return _C_ops.segment_pool(data, segment_ids, "MAX")[0] + if _in_legacy_dygraph(): + out, tmp = _legacy_C_ops.segment_pool(data, segment_ids, 'pooltype', + "MAX") + return out + + check_variable_and_dtype( + data, "X", ("float32", "float64", "int32", "int64", "float16"), + "segment_pool") + check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), + "segment_pool") + + helper = LayerHelper("segment_max", **locals()) + out = helper.create_variable_for_type_inference(dtype=data.dtype) + summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype) + helper.append_op(type="segment_pool", + inputs={ + "X": data, + "SegmentIds": segment_ids + }, + outputs={ + "Out": out, + "SummedIds": summed_ids + }, + attrs={"pooltype": "MAX"}) + return out diff --git a/python/paddle/geometric/message_passing/send_recv.py b/python/paddle/geometric/message_passing/send_recv.py index db2e9e7b21402..03a272aa6af08 100644 --- a/python/paddle/geometric/message_passing/send_recv.py +++ b/python/paddle/geometric/message_passing/send_recv.py @@ -82,30 +82,26 @@ def send_u_recv(x, the 0th dimension. Examples: - .. code-block:: python import paddle x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32") indexes = paddle.to_tensor([[0, 1], [1, 2], [2, 1], [0, 0]], dtype="int32") - src_index = indexes[:, 0] - dst_index = indexes[:, 1] + src_index, dst_index = indexes[:, 0], indexes[:, 1] out = paddle.geometric.send_u_recv(x, src_index, dst_index, reduce_op="sum") # Outputs: [[0., 2., 3.], [2., 8., 10.], [1., 4., 5.]] x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32") indexes = paddle.to_tensor([[0, 1], [2, 1], [0, 0]], dtype="int32") - src_index = indexes[:, 0] - dst_index = indexes[:, 1] + src_index, dst_index = indexes[:, 0], indexes[:, 1] out_size = paddle.max(dst_index) + 1 out = paddle.geometric.send_u_recv(x, src_index, dst_index, reduce_op="sum", out_size=out_size) # Outputs: [[0., 2., 3.], [[2., 8., 10.]]] x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32") indexes = paddle.to_tensor([[0, 1], [2, 1], [0, 0]], dtype="int32") - src_index = indexes[:, 0] - dst_index = indexes[:, 1] + src_index, dst_index = indexes[:, 0], indexes[:, 1] out = paddle.geometric.send_u_recv(x, src_index, dst_index, reduce_op="sum") # Outputs: [[0., 2., 3.], [2., 8., 10.], [0., 0., 0.]] @@ -233,7 +229,6 @@ def send_ue_recv(x, the 0th dimension. Examples: - .. code-block:: python import paddle @@ -241,16 +236,14 @@ def send_ue_recv(x, x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32") y = paddle.to_tensor([1, 1, 1, 1], dtype="float32") indexes = paddle.to_tensor([[0, 1], [1, 2], [2, 1], [0, 0]], dtype="int32") - src_index = indexes[:, 0] - dst_index = indexes[:, 1] + src_index, dst_index = indexes[:, 0], indexes[:, 1] out = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, message_op="add", reduce_op="sum") # Outputs: [[1., 3., 4.], [4., 10., 12.], [2., 5., 6.]] x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32") y = paddle.to_tensor([1, 1, 1], dtype="float32") indexes = paddle.to_tensor([[0, 1], [2, 1], [0, 0]], dtype="int32") - src_index = indexes[:, 0] - dst_index = indexes[:, 1] + src_index, dst_index = indexes[:, 0], indexes[:, 1] out_size = paddle.max(dst_index) + 1 out = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, message_op="add", reduce_op="sum", out_size=out_size) # Outputs: [[1., 3., 4.], [[4., 10., 12.]]] @@ -258,8 +251,7 @@ def send_ue_recv(x, x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32") y = paddle.to_tensor([1, 1, 1], dtype="float32") indexes = paddle.to_tensor([[0, 1], [2, 1], [0, 0]], dtype="int32") - src_index = indexes[:, 0] - dst_index = indexes[:, 1] + src_index, dst_index = indexes[:, 0], indexes[:, 1] out = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, message_op="add", reduce_op="sum") # Outputs: [[1., 3., 4.], [4., 10., 12.], [0., 0., 0.]] @@ -282,7 +274,7 @@ def send_ue_recv(x, y = -y if message_op == "div": message_op = 'mul' - y = 1. / y + y = 1. / (y + 1e-12) # TODO(daisiming): Should we add judgement for out_size: max(dst_index) + 1. @@ -381,7 +373,7 @@ def send_uv(x, y, src_index, dst_index, message_op="add", name=None): src_index (Tensor): An 1-D tensor, and the available data type is int32, int64. dst_index (Tensor): An 1-D tensor, and should have the same shape as `src_index`. The available data type is int32, int64. - message_op (Tensor): Different message ops for x and y, including `add`, `sub`, `mul` and `div`. + message_op (str): Different message ops for x and y, including `add`, `sub`, `mul` and `div`. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. @@ -389,7 +381,6 @@ def send_uv(x, y, src_index, dst_index, message_op="add", name=None): out (Tensor): The output tensor. Examples: - .. code-block:: python import paddle @@ -416,7 +407,7 @@ def send_uv(x, y, src_index, dst_index, message_op="add", name=None): y = -y if message_op == 'div': message_op = 'mul' - y = 1. / y + y = 1. / (y + 1e-12) if in_dygraph_mode(): return _C_ops.graph_send_uv(x, y, src_index, dst_index, diff --git a/python/paddle/geometric/reindex.py b/python/paddle/geometric/reindex.py new file mode 100644 index 0000000000000..9580ff5c4ee1f --- /dev/null +++ b/python/paddle/geometric/reindex.py @@ -0,0 +1,279 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.framework import _non_static_mode, Variable +from paddle.fluid.data_feeder import check_variable_and_dtype +from paddle.fluid import core +from paddle import _C_ops, _legacy_C_ops + +__all__ = [] + + +def reindex_graph(x, + neighbors, + count, + value_buffer=None, + index_buffer=None, + name=None): + """ + Reindex Graph API. + + This API is mainly used in Graph Learning domain, which should be used + in conjunction with `graph_sample_neighbors` API. And the main purpose + is to reindex the ids information of the input nodes, and return the + corresponding graph edges after reindex. + + **Notes**: + The number in x should be unique, otherwise it would cause potential errors. + We will reindex all the nodes from 0. + + Take input nodes x = [0, 1, 2] as an example. + If we have neighbors = [8, 9, 0, 4, 7, 6, 7], and count = [2, 3, 2], + then we know that the neighbors of 0 is [8, 9], the neighbors of 1 + is [0, 4, 7], and the neighbors of 2 is [6, 7]. + Then after graph_reindex, we will have 3 different outputs: + 1. reindex_src: [3, 4, 0, 5, 6, 7, 6] + 2. reindex_dst: [0, 0, 1, 1, 1, 2, 2] + 3. out_nodes: [0, 1, 2, 8, 9, 4, 7, 6] + We can see that the numbers in `reindex_src` and `reindex_dst` is the corresponding index + of nodes in `out_nodes`. + + Args: + x (Tensor): The input nodes which we sample neighbors for. The available + data type is int32, int64. + neighbors (Tensor): The neighbors of the input nodes `x`. The data type + should be the same with `x`. + count (Tensor): The neighbor count of the input nodes `x`. And the + data type should be int32. + value_buffer (Tensor|None): Value buffer for hashtable. The data type should be int32, + and should be filled with -1. Only useful for gpu version. + index_buffer (Tensor|None): Index buffer for hashtable. The data type should be int32, + and should be filled with -1. Only useful for gpu version. + `value_buffer` and `index_buffer` should be both not None + if you want to speed up by using hashtable buffer. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + reindex_src (Tensor): The source node index of graph edges after reindex. + reindex_dst (Tensor): The destination node index of graph edges after reindex. + out_nodes (Tensor): The index of unique input nodes and neighbors before reindex, + where we put the input nodes `x` in the front, and put neighbor + nodes in the back. + + Examples: + + .. code-block:: python + + import paddle + + x = [0, 1, 2] + neighbors = [8, 9, 0, 4, 7, 6, 7] + count = [2, 3, 2] + x = paddle.to_tensor(x, dtype="int64") + neighbors = paddle.to_tensor(neighbors, dtype="int64") + count = paddle.to_tensor(count, dtype="int32") + + reindex_src, reindex_dst, out_nodes = \ + paddle.geometric.reindex_graph(x, neighbors, count) + # reindex_src: [3, 4, 0, 5, 6, 7, 6] + # reindex_dst: [0, 0, 1, 1, 1, 2, 2] + # out_nodes: [0, 1, 2, 8, 9, 4, 7, 6] + + """ + use_buffer_hashtable = True if value_buffer is not None \ + and index_buffer is not None else False + + if _non_static_mode(): + reindex_src, reindex_dst, out_nodes = \ + _legacy_C_ops.graph_reindex(x, neighbors, count, value_buffer, index_buffer, + "flag_buffer_hashtable", use_buffer_hashtable) + return reindex_src, reindex_dst, out_nodes + + check_variable_and_dtype(x, "X", ("int32", "int64"), "graph_reindex") + check_variable_and_dtype(neighbors, "Neighbors", ("int32", "int64"), + "graph_reindex") + check_variable_and_dtype(count, "Count", ("int32"), "graph_reindex") + + if use_buffer_hashtable: + check_variable_and_dtype(value_buffer, "HashTable_Value", ("int32"), + "graph_reindex") + check_variable_and_dtype(index_buffer, "HashTable_Index", ("int32"), + "graph_reindex") + + helper = LayerHelper("reindex_graph", **locals()) + reindex_src = helper.create_variable_for_type_inference(dtype=x.dtype) + reindex_dst = helper.create_variable_for_type_inference(dtype=x.dtype) + out_nodes = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op(type="graph_reindex", + inputs={ + "X": + x, + "Neighbors": + neighbors, + "Count": + count, + "HashTable_Value": + value_buffer if use_buffer_hashtable else None, + "HashTable_Index": + index_buffer if use_buffer_hashtable else None, + }, + outputs={ + "Reindex_Src": reindex_src, + "Reindex_Dst": reindex_dst, + "Out_Nodes": out_nodes + }, + attrs={"flag_buffer_hashtable": use_buffer_hashtable}) + return reindex_src, reindex_dst, out_nodes + + +def reindex_heter_graph(x, + neighbors, + count, + value_buffer=None, + index_buffer=None, + name=None): + """ + Reindex HeterGraph API. + + This API is mainly used in Graph Learning domain, which should be used + in conjunction with `graph_sample_neighbors` API. And the main purpose + is to reindex the ids information of the input nodes, and return the + corresponding graph edges after reindex. + + **Notes**: + The number in x should be unique, otherwise it would cause potential errors. + We support multi-edge-types neighbors reindexing in reindex_heter_graph api. + We will reindex all the nodes from 0. + + Take input nodes x = [0, 1, 2] as an example. + For graph A, suppose we have neighbors = [8, 9, 0, 4, 7, 6, 7], and count = [2, 3, 2], + then we know that the neighbors of 0 is [8, 9], the neighbors of 1 + is [0, 4, 7], and the neighbors of 2 is [6, 7]. + For graph B, suppose we have neighbors = [0, 2, 3, 5, 1], and count = [1, 3, 1], + then we know that the neighbors of 0 is [0], the neighbors of 1 is [2, 3, 5], + and the neighbors of 3 is [1]. + We will get following outputs: + 1. reindex_src: [3, 4, 0, 5, 6, 7, 6, 0, 2, 8, 9, 1] + 2. reindex_dst: [0, 0, 1, 1, 1, 2, 2, 0, 1, 1, 1, 2] + 3. out_nodes: [0, 1, 2, 8, 9, 4, 7, 6, 3, 5] + + Args: + x (Tensor): The input nodes which we sample neighbors for. The available + data type is int32, int64. + neighbors (list|tuple): The neighbors of the input nodes `x` from different graphs. + The data type should be the same with `x`. + count (list|tuple): The neighbor counts of the input nodes `x` from different graphs. + And the data type should be int32. + value_buffer (Tensor|None): Value buffer for hashtable. The data type should be int32, + and should be filled with -1. Only useful for gpu version. + index_buffer (Tensor|None): Index buffer for hashtable. The data type should be int32, + and should be filled with -1. Only useful for gpu version. + `value_buffer` and `index_buffer` should be both not None + if you want to speed up by using hashtable buffer. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + reindex_src (Tensor): The source node index of graph edges after reindex. + reindex_dst (Tensor): The destination node index of graph edges after reindex. + out_nodes (Tensor): The index of unique input nodes and neighbors before reindex, + where we put the input nodes `x` in the front, and put neighbor + nodes in the back. + + Examples: + + .. code-block:: python + + import paddle + + x = [0, 1, 2] + neighbors_a = [8, 9, 0, 4, 7, 6, 7] + count_a = [2, 3, 2] + x = paddle.to_tensor(x, dtype="int64") + neighbors_a = paddle.to_tensor(neighbors_a, dtype="int64") + count_a = paddle.to_tensor(count_a, dtype="int32") + + neighbors_b = [0, 2, 3, 5, 1] + count_b = [1, 3, 1] + neighbors_b = paddle.to_tensor(neighbors_b, dtype="int64") + count_b = paddle.to_tensor(count_b, dtype="int32") + + neighbors = [neighbors_a, neighbors_b] + count = [count_a, count_b] + reindex_src, reindex_dst, out_nodes = \ + paddle.geometric.reindex_heter_graph(x, neighbors, count) + # reindex_src: [3, 4, 0, 5, 6, 7, 6, 0, 2, 8, 9, 1] + # reindex_dst: [0, 0, 1, 1, 1, 2, 2, 0, 1, 1, 1, 2] + # out_nodes: [0, 1, 2, 8, 9, 4, 7, 6, 3, 5] + + """ + use_buffer_hashtable = True if value_buffer is not None \ + and index_buffer is not None else False + + if _non_static_mode(): + neighbors = paddle.concat(neighbors, axis=0) + count = paddle.concat(count, axis=0) + reindex_src, reindex_dst, out_nodes = \ + _legacy_C_ops.graph_reindex(x, neighbors, count, value_buffer, index_buffer, + "flag_buffer_hashtable", use_buffer_hashtable) + return reindex_src, reindex_dst, out_nodes + + if isinstance(neighbors, Variable): + neighbors = [neighbors] + if isinstance(count, Variable): + count = [count] + + neighbors = paddle.concat(neighbors, axis=0) + count = paddle.concat(count, axis=0) + + check_variable_and_dtype(x, "X", ("int32", "int64"), "heter_graph_reindex") + check_variable_and_dtype(neighbors, "Neighbors", ("int32", "int64"), + "graph_reindex") + check_variable_and_dtype(count, "Count", ("int32"), "graph_reindex") + + if use_buffer_hashtable: + check_variable_and_dtype(value_buffer, "HashTable_Value", ("int32"), + "graph_reindex") + check_variable_and_dtype(index_buffer, "HashTable_Index", ("int32"), + "graph_reindex") + + helper = LayerHelper("reindex_heter_graph", **locals()) + reindex_src = helper.create_variable_for_type_inference(dtype=x.dtype) + reindex_dst = helper.create_variable_for_type_inference(dtype=x.dtype) + out_nodes = helper.create_variable_for_type_inference(dtype=x.dtype) + neighbors = paddle.concat(neighbors, axis=0) + count = paddle.concat(count, axis=0) + helper.append_op(type="graph_reindex", + inputs={ + "X": + x, + "Neighbors": + neighbors, + "Count": + count, + "HashTable_Value": + value_buffer if use_buffer_hashtable else None, + "HashTable_Index": + index_buffer if use_buffer_hashtable else None, + }, + outputs={ + "Reindex_Src": reindex_src, + "Reindex_Dst": reindex_dst, + "Out_Nodes": out_nodes + }, + attrs={"flag_buffer_hashtable": use_buffer_hashtable}) + return reindex_src, reindex_dst, out_nodes diff --git a/python/paddle/geometric/sampling/__init__.py b/python/paddle/geometric/sampling/__init__.py new file mode 100644 index 0000000000000..282fb8fd56792 --- /dev/null +++ b/python/paddle/geometric/sampling/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .neighbors import sample_neighbors # noqa: F401 + +__all__ = [ + 'sample_neighbors', +] diff --git a/python/paddle/geometric/sampling/neighbors.py b/python/paddle/geometric/sampling/neighbors.py new file mode 100644 index 0000000000000..a9619d54a852e --- /dev/null +++ b/python/paddle/geometric/sampling/neighbors.py @@ -0,0 +1,144 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.framework import _non_static_mode +from paddle.fluid.data_feeder import check_variable_and_dtype +from paddle import _C_ops, _legacy_C_ops + +__all__ = [] + + +def sample_neighbors(row, + colptr, + input_nodes, + sample_size=-1, + eids=None, + return_eids=False, + perm_buffer=None, + name=None): + """ + Graph Sample Neighbors API. + + This API is mainly used in Graph Learning domain, and the main purpose is to + provide high performance of graph sampling method. For example, we get the + CSC(Compressed Sparse Column) format of the input graph edges as `row` and + `colptr`, so as to convert graph data into a suitable format for sampling. + `input_nodes` means the nodes we need to sample neighbors, and `sample_sizes` + means the number of neighbors and number of layers we want to sample. + + Besides, we support fisher-yates sampling in GPU version. + + Args: + row (Tensor): One of the components of the CSC format of the input graph, and + the shape should be [num_edges, 1] or [num_edges]. The available + data type is int32, int64. + colptr (Tensor): One of the components of the CSC format of the input graph, + and the shape should be [num_nodes + 1, 1] or [num_nodes + 1]. + The data type should be the same with `row`. + input_nodes (Tensor): The input nodes we need to sample neighbors for, and the + data type should be the same with `row`. + sample_size (int): The number of neighbors we need to sample. Default value is -1, + which means returning all the neighbors of the input nodes. + eids (Tensor): The eid information of the input graph. If return_eids is True, + then `eids` should not be None. The data type should be the + same with `row`. Default is None. + return_eids (bool): Whether to return eid information of sample edges. Default is False. + perm_buffer (Tensor): Permutation buffer for fisher-yates sampling. If `use_perm_buffer` + is True, then `perm_buffer` should not be None. The data type should + be the same with `row`. If not None, we will use fiser-yates sampling + to speed up. Only useful for gpu version. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + out_neighbors (Tensor): The sample neighbors of the input nodes. + out_count (Tensor): The number of sampling neighbors of each input node, and the shape + should be the same with `input_nodes`. + out_eids (Tensor): If `return_eids` is True, we will return the eid information of the + sample edges. + + Examples: + .. code-block:: python + import paddle + # edges: (3, 0), (7, 0), (0, 1), (9, 1), (1, 2), (4, 3), (2, 4), + # (9, 5), (3, 5), (9, 6), (1, 6), (9, 8), (7, 8) + row = [3, 7, 0, 9, 1, 4, 2, 9, 3, 9, 1, 9, 7] + colptr = [0, 2, 4, 5, 6, 7, 9, 11, 11, 13, 13] + nodes = [0, 8, 1, 2] + sample_size = 2 + row = paddle.to_tensor(row, dtype="int64") + colptr = paddle.to_tensor(colptr, dtype="int64") + nodes = paddle.to_tensor(nodes, dtype="int64") + out_neighbors, out_count = \ + paddle.geometric.sample_neighbors(row, colptr, nodes, + sample_size=sample_size) + + """ + + if return_eids: + if eids is None: + raise ValueError( + f"`eids` should not be None if `return_eids` is True.") + + use_perm_buffer = True if perm_buffer is not None else False + + if _non_static_mode(): + out_neighbors, out_count, out_eids = _legacy_C_ops.graph_sample_neighbors( + row, colptr, input_nodes, eids, perm_buffer, "sample_size", + sample_size, "return_eids", return_eids, "flag_perm_buffer", + use_perm_buffer) + if return_eids: + return out_neighbors, out_count, out_eids + return out_neighbors, out_count + + check_variable_and_dtype(row, "Row", ("int32", "int64"), + "graph_sample_neighbors") + check_variable_and_dtype(colptr, "Col_Ptr", ("int32", "int64"), + "graph_sample_neighbors") + check_variable_and_dtype(input_nodes, "X", ("int32", "int64"), + "graph_sample_neighbors") + if return_eids: + check_variable_and_dtype(eids, "Eids", ("int32", "int64"), + "graph_sample_neighbors") + if use_perm_buffer: + check_variable_and_dtype(perm_buffer, "Perm_Buffer", ("int32", "int64"), + "graph_sample_neighbors") + + helper = LayerHelper("sample_neighbors", **locals()) + out_neighbors = helper.create_variable_for_type_inference(dtype=row.dtype) + out_count = helper.create_variable_for_type_inference(dtype=row.dtype) + out_eids = helper.create_variable_for_type_inference(dtype=row.dtype) + helper.append_op(type="graph_sample_neighbors", + inputs={ + "Row": row, + "Col_Ptr": colptr, + "X": input_nodes, + "Eids": eids if return_eids else None, + "Perm_Buffer": perm_buffer if use_perm_buffer else None + }, + outputs={ + "Out": out_neighbors, + "Out_Count": out_count, + "Out_Eids": out_eids + }, + attrs={ + "sample_size": sample_size, + "return_eids": return_eids, + "flag_perm_buffer": use_perm_buffer + }) + if return_eids: + return out_neighbors, out_count, out_eids + return out_neighbors, out_count diff --git a/python/paddle/incubate/operators/graph_reindex.py b/python/paddle/incubate/operators/graph_reindex.py index fd55752db4de8..e7e940c2750cc 100644 --- a/python/paddle/incubate/operators/graph_reindex.py +++ b/python/paddle/incubate/operators/graph_reindex.py @@ -18,8 +18,13 @@ from paddle.fluid.data_feeder import check_variable_and_dtype from paddle.fluid import core from paddle import _C_ops, _legacy_C_ops +import paddle.utils.deprecated as deprecated +@deprecated(since="2.4.0", + update_to="paddle.geometric.reindex_graph", + level=1, + reason="paddle.incubate.graph_reindex will be removed in future") def graph_reindex(x, neighbors, count, @@ -121,7 +126,7 @@ def graph_reindex(x, if flag_buffer_hashtable: check_variable_and_dtype(value_buffer, "HashTable_Value", ("int32"), "graph_reindex") - check_variable_and_dtype(index_buffer, "HashTable_Value", ("int32"), + check_variable_and_dtype(index_buffer, "HashTable_Index", ("int32"), "graph_reindex") helper = LayerHelper("graph_reindex", **locals()) diff --git a/python/paddle/incubate/operators/graph_sample_neighbors.py b/python/paddle/incubate/operators/graph_sample_neighbors.py index 3e385de7814ec..b230b2a45d58d 100644 --- a/python/paddle/incubate/operators/graph_sample_neighbors.py +++ b/python/paddle/incubate/operators/graph_sample_neighbors.py @@ -18,8 +18,14 @@ from paddle.fluid.data_feeder import check_variable_and_dtype from paddle.fluid import core from paddle import _C_ops, _legacy_C_ops +import paddle.utils.deprecated as deprecated +@deprecated( + since="2.4.0", + update_to="paddle.geometric.sample_neighbors", + level=1, + reason="paddle.incubate.graph_sample_neighbors will be removed in future") def graph_sample_neighbors(row, colptr, input_nodes, diff --git a/python/paddle/incubate/tensor/math.py b/python/paddle/incubate/tensor/math.py index 5b2a9b3c66f16..745df5fccf751 100644 --- a/python/paddle/incubate/tensor/math.py +++ b/python/paddle/incubate/tensor/math.py @@ -16,10 +16,15 @@ from paddle.fluid.data_feeder import check_variable_and_dtype from paddle import _C_ops, _legacy_C_ops from paddle.fluid.framework import _in_legacy_dygraph, in_dygraph_mode +import paddle.utils.deprecated as deprecated __all__ = [] +@deprecated(since="2.4.0", + update_to="paddle.geometric.segment_sum", + level=1, + reason="paddle.incubate.segment_sum will be removed in future") def segment_sum(data, segment_ids, name=None): r""" Segment Sum Operator. @@ -80,6 +85,10 @@ def segment_sum(data, segment_ids, name=None): return out +@deprecated(since="2.4.0", + update_to="paddle.geometric.segment_mean", + level=1, + reason="paddle.incubate.segment_mean will be removed in future") def segment_mean(data, segment_ids, name=None): r""" Segment mean Operator. @@ -142,6 +151,10 @@ def segment_mean(data, segment_ids, name=None): return out +@deprecated(since="2.4.0", + update_to="paddle.geometric.segment_min", + level=1, + reason="paddle.incubate.segment_min will be removed in future") def segment_min(data, segment_ids, name=None): r""" Segment min operator. @@ -204,6 +217,10 @@ def segment_min(data, segment_ids, name=None): return out +@deprecated(since="2.4.0", + update_to="paddle.geometric.segment_max", + level=1, + reason="paddle.incubate.segment_max will be removed in future") def segment_max(data, segment_ids, name=None): r""" Segment max operator. diff --git a/python/setup.py.in b/python/setup.py.in index 66f0575284d8d..2eff5dd7c0ade 100755 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -402,6 +402,7 @@ packages=['paddle', 'paddle.profiler', 'paddle.geometric', 'paddle.geometric.message_passing', + 'paddle.geometric.sampling', ] with open('@PADDLE_SOURCE_DIR@/python/requirements.txt') as f: