diff --git a/paddle/phi/kernels/cpu/graph_send_ue_recv_grad_kernel.cc b/paddle/phi/kernels/cpu/graph_send_ue_recv_grad_kernel.cc index 55b778bf8e502..c3ae8563370f8 100644 --- a/paddle/phi/kernels/cpu/graph_send_ue_recv_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/graph_send_ue_recv_grad_kernel.cc @@ -356,42 +356,42 @@ void GraphSendUERecvGradOpKernelLaunchHelper( const Context& ctx, const DenseTensor& out_grad, const DenseTensor& x, - const DenseTensor& e, + const DenseTensor& y, const DenseTensor& src_index, const DenseTensor& dst_index, const std::string& compute_type, const std::string& pool_type, DenseTensor* x_grad, - DenseTensor* e_grad, + DenseTensor* y_grad, const DenseTensor* dst_count = nullptr, const DenseTensor* out = nullptr) { const int& index_size = dst_index.dims()[0]; ctx.template Alloc(x_grad); T* x_grad_data = x_grad->data(); - ctx.template Alloc(e_grad); - T* e_grad_data = e_grad->data(); + ctx.template Alloc(y_grad); + T* y_grad_data = y_grad->data(); const auto& x_dims = x.dims(); - const auto& e_dims = e.dims(); - int64_t memset_size_x = 1, memset_size_e = 1; + const auto& y_dims = y.dims(); + int64_t memset_size_x = 1, memset_size_y = 1; int64_t slice_size = 1; for (int i = 0; i < x_dims.size(); i++) { memset_size_x *= x_dims[i]; if (i > 0) slice_size *= x_dims[i]; } - for (int i = 0; i < e_dims.size(); i++) { - memset_size_e *= e_dims[i]; + for (int i = 0; i < y_dims.size(); i++) { + memset_size_y *= y_dims[i]; } const size_t& memset_bytes_x = memset_size_x * sizeof(T); - const size_t& memset_bytes_e = memset_size_e * sizeof(T); + const size_t& memset_bytes_y = memset_size_y * sizeof(T); memset(x_grad_data, 0, memset_bytes_x); - memset(e_grad_data, 0, memset_bytes_e); + memset(y_grad_data, 0, memset_bytes_y); if (index_size == 0) return; const T* out_grad_data = out_grad.data(); const T* x_data = x.data(); - const T* e_data = e.data(); + const T* y_data = y.data(); const IndexT* s_index = src_index.data(); const IndexT* d_index = dst_index.data(); @@ -399,10 +399,10 @@ void GraphSendUERecvGradOpKernelLaunchHelper( CalculateXGrad(ctx, out_grad_data, x_data, - e_data, + y_data, out_grad.dims(), x_dims, - e_dims, + y_dims, d_index, s_index, compute_type, @@ -415,29 +415,29 @@ void GraphSendUERecvGradOpKernelLaunchHelper( out); CalculateEGrad(out_grad_data, x_data, - e_data, + y_data, x_dims, - e_dims, + y_dims, s_index, d_index, compute_type, pool_type, index_size, - e_grad_data, + y_grad_data, dst_count); } else if (pool_type == "MIN" || pool_type == "MAX") { CalculateXEGradForMinMax(out_grad_data, x_data, - e_data, + y_data, x_dims, - e_dims, + y_dims, d_index, s_index, compute_type, pool_type, index_size, x_grad_data, - e_grad_data, + y_grad_data, out); } } @@ -445,7 +445,7 @@ void GraphSendUERecvGradOpKernelLaunchHelper( template void GraphSendUERecvGradKernel(const Context& ctx, const DenseTensor& x, - const DenseTensor& e, + const DenseTensor& y, const DenseTensor& src_index, const DenseTensor& dst_index, const paddle::optional& out, @@ -454,20 +454,20 @@ void GraphSendUERecvGradKernel(const Context& ctx, const std::string& compute_type, const std::string& pool_type, DenseTensor* x_grad, - DenseTensor* e_grad) { + DenseTensor* y_grad) { auto index_type = src_index.dtype(); if (index_type == phi::DataType::INT32) { GraphSendUERecvGradOpKernelLaunchHelper( ctx, out_grad, x, - e, + y, src_index, dst_index, compute_type, pool_type, x_grad, - e_grad, + y_grad, dst_count.get_ptr(), out.get_ptr()); } else if (index_type == phi::DataType::INT64) { @@ -475,13 +475,13 @@ void GraphSendUERecvGradKernel(const Context& ctx, ctx, out_grad, x, - e, + y, src_index, dst_index, compute_type, pool_type, x_grad, - e_grad, + y_grad, dst_count.get_ptr(), out.get_ptr()); } diff --git a/paddle/phi/kernels/cpu/graph_send_ue_recv_kernel.cc b/paddle/phi/kernels/cpu/graph_send_ue_recv_kernel.cc index 6f479c7deb3cf..5c3760657be86 100644 --- a/paddle/phi/kernels/cpu/graph_send_ue_recv_kernel.cc +++ b/paddle/phi/kernels/cpu/graph_send_ue_recv_kernel.cc @@ -29,7 +29,7 @@ namespace phi { template void GraphSendUERecvSumCpuKernel(const BroadCastInfo& bcast, const T* x_data, - const T* e_data, + const T* y_data, const IndexT* src_indices, const IndexT* dst_indices, T* output, @@ -43,11 +43,11 @@ void GraphSendUERecvSumCpuKernel(const BroadCastInfo& bcast, IndexT dst = dst_indices[i]; T* out_off = output + dst * bcast.out_len; const T* x_off = x_data + src * bcast.l_len; - const T* e_off = e_data + i * bcast.r_len; + const T* y_off = y_data + i * bcast.r_len; for (int64_t j = 0; j < bcast.out_len; j++) { int64_t x_add = bcast.use_bcast ? bcast.l_offset[j] : j; - int64_t e_add = bcast.use_bcast ? bcast.r_offset[j] : j; - T val = cfunctor(x_off[x_add], e_off[e_add]); + int64_t y_add = bcast.use_bcast ? bcast.r_offset[j] : j; + T val = cfunctor(x_off[x_add], y_off[y_add]); if (val != 0) { #ifdef PADDLE_WITH_MKLML #pragma omp atomic @@ -64,7 +64,7 @@ template void GraphSendUERecvMinMaxCpuKernel(const BroadCastInfo& bcast, const T* x_data, - const T* e_data, + const T* y_data, const IndexT* src_indices, const IndexT* dst_indices, T* output, @@ -80,17 +80,17 @@ void GraphSendUERecvMinMaxCpuKernel(const BroadCastInfo& bcast, IndexT dst = dst_indices[i]; T* out_off = output + dst * bcast.out_len; const T* x_off = x_data + src * bcast.l_len; - const T* e_off = e_data + i * bcast.r_len; + const T* y_off = y_data + i * bcast.r_len; bool in_set = existed_dst.find(dst) != existed_dst.end(); for (int64_t j = 0; j < bcast.out_len; j++) { int64_t x_add = bcast.use_bcast ? bcast.l_offset[j] : j; - int64_t e_add = bcast.use_bcast ? bcast.r_offset[j] : j; - T val = cfunctor(x_off[x_add], e_off[e_add]); + int64_t y_add = bcast.use_bcast ? bcast.r_offset[j] : j; + T val = cfunctor(x_off[x_add], y_off[y_add]); #ifdef PADDLE_WITH_MKLML #pragma omp critical #endif if (!in_set) { - out_off[j] += val; + out_off[j] = val; } else { out_off[j] = pfunctor(out_off[j], val); } @@ -107,7 +107,7 @@ void GraphSendUERecvMinMaxCpuKernel(const BroadCastInfo& bcast, template void GraphSendUERecvOpKernelLaunchHelper(const Context& ctx, const DenseTensor& x, - const DenseTensor& e, + const DenseTensor& y, const DenseTensor& src_index, const DenseTensor& dst_index, const std::string& compute_type, @@ -135,9 +135,9 @@ void GraphSendUERecvOpKernelLaunchHelper(const Context& ctx, memset(out_data, 0, memset_bytes); if (index_size == 0) return; - const auto& bcast_info = phi::CalcBCastInfo(x.dims(), e.dims()); + const auto& bcast_info = phi::CalcBCastInfo(x.dims(), y.dims()); const T* x_data = x.data(); - const T* e_data = e.data(); + const T* y_data = y.data(); const IndexT* s_index = src_index.data(); const IndexT* d_index = dst_index.data(); if (pool_type == "SUM" || pool_type == "MEAN") { @@ -145,7 +145,7 @@ void GraphSendUERecvOpKernelLaunchHelper(const Context& ctx, GraphAddFunctor add_functor; GraphSendUERecvSumCpuKernel>(bcast_info, x_data, - e_data, + y_data, s_index, d_index, out_data, @@ -155,7 +155,7 @@ void GraphSendUERecvOpKernelLaunchHelper(const Context& ctx, GraphMulFunctor mul_functor; GraphSendUERecvSumCpuKernel>(bcast_info, x_data, - e_data, + y_data, s_index, d_index, out_data, @@ -187,7 +187,7 @@ void GraphSendUERecvOpKernelLaunchHelper(const Context& ctx, GraphAddFunctor, GraphMinFunctor>(bcast_info, x_data, - e_data, + y_data, s_index, d_index, out_data, @@ -201,7 +201,7 @@ void GraphSendUERecvOpKernelLaunchHelper(const Context& ctx, GraphMulFunctor, GraphMinFunctor>(bcast_info, x_data, - e_data, + y_data, s_index, d_index, out_data, @@ -218,7 +218,7 @@ void GraphSendUERecvOpKernelLaunchHelper(const Context& ctx, GraphAddFunctor, GraphMaxFunctor>(bcast_info, x_data, - e_data, + y_data, s_index, d_index, out_data, @@ -232,7 +232,7 @@ void GraphSendUERecvOpKernelLaunchHelper(const Context& ctx, GraphMulFunctor, GraphMaxFunctor>(bcast_info, x_data, - e_data, + y_data, s_index, d_index, out_data, @@ -246,7 +246,7 @@ void GraphSendUERecvOpKernelLaunchHelper(const Context& ctx, template void GraphSendUERecvKernel(const Context& ctx, const DenseTensor& x, - const DenseTensor& e, + const DenseTensor& y, const DenseTensor& src_index, const DenseTensor& dst_index, const std::string& compute_type, @@ -259,7 +259,7 @@ void GraphSendUERecvKernel(const Context& ctx, if (index_type == phi::DataType::INT32) { GraphSendUERecvOpKernelLaunchHelper(ctx, x, - e, + y, src_index, dst_index, compute_type, @@ -270,7 +270,7 @@ void GraphSendUERecvKernel(const Context& ctx, } else if (index_type == phi::DataType::INT64) { GraphSendUERecvOpKernelLaunchHelper(ctx, x, - e, + y, src_index, dst_index, compute_type, diff --git a/paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h b/paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h index 473fd04494238..c11f8c123c40f 100644 --- a/paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h +++ b/paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h @@ -16,9 +16,6 @@ #include #include -#include -#include - #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/hostdevice.h" diff --git a/python/paddle/fluid/tests/unittests/test_graph_send_ue_recv_op.py b/python/paddle/fluid/tests/unittests/test_graph_send_ue_recv_op.py index 414cc6c639714..b001b4d56b9c0 100644 --- a/python/paddle/fluid/tests/unittests/test_graph_send_ue_recv_op.py +++ b/python/paddle/fluid/tests/unittests/test_graph_send_ue_recv_op.py @@ -1,5 +1,6 @@ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# +# Copyright 2022 The DGL team. + # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/python/paddle/geometric/message_passing/send_recv.py b/python/paddle/geometric/message_passing/send_recv.py index 7279d9d031f90..f5bf1850fd581 100644 --- a/python/paddle/geometric/message_passing/send_recv.py +++ b/python/paddle/geometric/message_passing/send_recv.py @@ -41,7 +41,7 @@ def send_u_recv(x, Given: - X = [[0, 2, 3], + x = [[0, 2, 3], [1, 4, 5], [2, 6, 7]] @@ -55,7 +55,7 @@ def send_u_recv(x, Then: - Out = [[0, 2, 3], + out = [[0, 2, 3], [2, 8, 10], [1, 4, 5]] @@ -176,7 +176,7 @@ def send_ue_recv(x, This api is mainly used in Graph Learning domain, and the main purpose is to reduce intermediate memory consumption in the process of message passing. Take `x` as the input tensor, we first use `src_index` - to gather the corresponding data, after computing with `y` in different compute types, then use `dst_index` to + to gather the corresponding data, after computing with `y` in different compute types like add/sub/mul/div, then use `dst_index` to update the corresponding position of output tensor in different pooling types, like sum, mean, max, or min. Besides, we can use `out_size` to set necessary output shape. @@ -184,11 +184,11 @@ def send_ue_recv(x, Given: - X = [[0, 2, 3], + x = [[0, 2, 3], [1, 4, 5], [2, 6, 7]] - Y = [1, 1, 1] + y = [1, 1, 1] src_index = [0, 1, 2, 0] @@ -202,7 +202,7 @@ def send_ue_recv(x, Then: - Out = [[1, 3, 4], + out = [[1, 3, 4], [4, 10, 12], [2, 5, 6]] Args: