From b941c68993c732bf8e0e0be26bf4edb28e3fbb94 Mon Sep 17 00:00:00 2001
From: Siming Dai <908660116@qq.com>
Date: Tue, 16 Aug 2022 18:32:15 +0800
Subject: [PATCH] [geometric] Add paddle.geometric.send_uv API (#44848)

* initial commit

* fix op maker bug

* fix mul grad bug

* add unittest

* fix add grad bug, add cpu kernel

* add paddle.geometric.message_passing

* add paddle.geometric.send_uv api, add unittest

* add fp16 judgement

* fix file typo, move compute_type to message_op

* add impl file

* fix unittest timeout time

* add review revise
---
 paddle/phi/api/yaml/api.yaml                  |  10 +
 paddle/phi/api/yaml/backward.yaml             |  11 +
 paddle/phi/infermeta/multiary.cc              |  70 ++++
 paddle/phi/infermeta/multiary.h               |   7 +
 .../cpu/graph_send_ue_recv_grad_kernel.cc     |   2 +-
 .../kernels/cpu/graph_send_ue_recv_kernel.cc  |   2 +-
 .../kernels/cpu/graph_send_uv_grad_kernel.cc  | 260 ++++++++++++++
 .../phi/kernels/cpu/graph_send_uv_kernel.cc   | 131 ++++++++
 .../kernels/gpu/graph_send_ue_recv_funcs.h    |   2 +-
 .../gpu/graph_send_ue_recv_grad_kernel.cu     |   2 +-
 .../kernels/gpu/graph_send_ue_recv_kernel.cu  |   2 +-
 .../kernels/gpu/graph_send_uv_grad_kernel.cu  | 317 ++++++++++++++++++
 .../phi/kernels/gpu/graph_send_uv_kernel.cu   | 172 ++++++++++
 .../phi/kernels/graph_send_uv_grad_kernel.h   |  33 ++
 paddle/phi/kernels/graph_send_uv_kernel.h     |  31 ++
 ...ng_impl.h => graph_message_passing_impl.h} |   0
 .../fluid/tests/unittests/CMakeLists.txt      |   1 +
 .../tests/unittests/test_graph_send_uv_op.py  | 265 +++++++++++++++
 python/paddle/geometric/__init__.py           |   2 +
 .../geometric/message_passing/__init__.py     |   7 +
 .../geometric/message_passing/send_recv.py    | 115 +++++++
 21 files changed, 1437 insertions(+), 5 deletions(-)
 create mode 100644 paddle/phi/kernels/cpu/graph_send_uv_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/cpu/graph_send_uv_kernel.cc
 create mode 100644 paddle/phi/kernels/gpu/graph_send_uv_grad_kernel.cu
 create mode 100644 paddle/phi/kernels/gpu/graph_send_uv_kernel.cu
 create mode 100644 paddle/phi/kernels/graph_send_uv_grad_kernel.h
 create mode 100644 paddle/phi/kernels/graph_send_uv_kernel.h
 rename paddle/phi/kernels/impl/{graph_messaage_passing_impl.h => graph_message_passing_impl.h} (100%)
 create mode 100644 python/paddle/fluid/tests/unittests/test_graph_send_uv_op.py

diff --git a/paddle/phi/api/yaml/api.yaml b/paddle/phi/api/yaml/api.yaml
index 1156206ee4b51..9e4bc2869bc8b 100644
--- a/paddle/phi/api/yaml/api.yaml
+++ b/paddle/phi/api/yaml/api.yaml
@@ -98,6 +98,16 @@
     func : erf
   backward : erf_grad
 
+- api : graph_send_uv
+  args : (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op = "ADD")
+  output : Tensor(out)
+  infer_meta :
+    func : GraphSendUVInferMeta
+  kernel :
+    func : graph_send_uv
+    data_type : x
+  backward : graph_send_uv_grad
+
 - api : lgamma
   args : (Tensor x)
   output : Tensor(out)
diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index 53cdc97a716d7..70fbd33bf3613 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -105,6 +105,17 @@
     func : erf_grad
     data_type : out_grad
 
+- backward_api : graph_send_uv_grad
+  forward : graph_send_uv (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op = "ADD") -> Tensor(out)
+  args: (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, Tensor out_grad, str message_op = "ADD")
+  output : Tensor(x_grad), Tensor(y_grad)
+  infer_meta :
+    func : GeneralBinaryGradInferMeta
+    param : [x, y]
+  kernel :
+    func : graph_send_uv_grad
+    data_type : x
+
 - backward_api : lgamma_grad
   forward : lgamma(Tensor x) -> Tensor(out)
   args : (Tensor x, Tensor out_grad)
diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc
index c0ee3cf02feb1..1657e3cb8adcf 100644
--- a/paddle/phi/infermeta/multiary.cc
+++ b/paddle/phi/infermeta/multiary.cc
@@ -2499,6 +2499,76 @@ void GraphSendUERecvInferMeta(const MetaTensor& x,
   out->set_dims(phi::make_ddim(out_dims_array));
 }
 
+void GraphSendUVInferMeta(const MetaTensor& x,
+                          const MetaTensor& y,
+                          const MetaTensor& src_index,
+                          const MetaTensor& dst_index,
+                          const std::string& message_op,
+                          MetaTensor* out) {
+  auto src_index_dims = src_index.dims();
+  if (src_index_dims.size() == 2) {
+    PADDLE_ENFORCE_EQ(src_index_dims[1],
+                      1,
+                      phi::errors::InvalidArgument(
+                          "The last dim of Src_index should be 1 when it "
+                          "is 2D, but got %d.",
+                          src_index_dims[1]));
+  } else {
+    PADDLE_ENFORCE_EQ(
+        src_index_dims.size(),
+        1,
+        phi::errors::InvalidArgument(
+            "Src_index should be 1D when it is not 2D, but got %d.",
+            src_index_dims.size()));
+  }
+
+  auto dst_index_dims = dst_index.dims();
+  if (dst_index_dims.size() == 2) {
+    PADDLE_ENFORCE_EQ(dst_index_dims[1],
+                      1,
+                      phi::errors::InvalidArgument(
+                          "The last dim of Dst_index should be 1 when it "
+                          "is 2D, but got %d.",
+                          dst_index_dims[1]));
+  } else {
+    PADDLE_ENFORCE_EQ(
+        dst_index_dims.size(),
+        1,
+        phi::errors::InvalidArgument(
+            "Dst_index should be 1D when it is not 2D, but got %d.",
+            dst_index_dims.size()));
+  }
+
+  PADDLE_ENFORCE_EQ(src_index_dims[0],
+                    dst_index_dims[0],
+                    phi::errors::InvalidArgument(
+                        "Src_index and Dst_index should have the same shape."));
+
+  // Infer out's shape according to x and y (need broadcasting condition).
+  out->set_dtype(x.dtype());
+  auto x_dims = x.dims();
+  auto y_dims = y.dims();
+  auto x_dims1 = phi::vectorize<int>(x_dims);
+  auto y_dims1 = phi::vectorize<int>(y_dims);
+  std::vector<int> x_dims2(x_dims1.begin() + 1, x_dims1.end());
+  std::vector<int> y_dims2(y_dims1.begin() + 1, y_dims1.end());
+  int max_dim = std::max(x_dims2.size(), y_dims2.size());
+  int axis = std::abs(static_cast<int>(x_dims2.size() - y_dims2.size()));
+  std::vector<int> x_dims_array(max_dim);
+  std::vector<int> y_dims_array(max_dim);
+  std::vector<int> out_dims_array(max_dim);
+  // Only need to broadcast dimensions other than the 0th dimension.
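+  // e.g. (shapes borrowed from the unit tests below) x: [100, 20],
+  // y: [100, 1], src_index/dst_index: [15] -> the feature dims [20] and [1]
+  // broadcast to [20], and the final out shape is [15, 20].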
+  phi::funcs::GetBroadcastDimsArrays(phi::make_ddim(x_dims2),
+                                     phi::make_ddim(y_dims2),
+                                     x_dims_array.data(),
+                                     y_dims_array.data(),
+                                     out_dims_array.data(),
+                                     max_dim,
+                                     axis);
+  out_dims_array.insert(out_dims_array.begin(), src_index_dims[0]);
+  out->set_dims(phi::make_ddim(out_dims_array));
+}
+
 }  // namespace phi
 
 PD_REGISTER_INFER_META_FN(batch_norm, phi::BatchNormInferMeta);
diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h
index 8d85c9b38c0d5..276a43e3ded40 100644
--- a/paddle/phi/infermeta/multiary.h
+++ b/paddle/phi/infermeta/multiary.h
@@ -431,4 +431,11 @@ void GraphSendUERecvInferMeta(const MetaTensor& x,
                               MetaTensor* out,
                               MetaTensor* dst_count);
 
+void GraphSendUVInferMeta(const MetaTensor& x,
+                          const MetaTensor& y,
+                          const MetaTensor& src_index,
+                          const MetaTensor& dst_index,
+                          const std::string& message_op,
+                          MetaTensor* out);
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/cpu/graph_send_ue_recv_grad_kernel.cc b/paddle/phi/kernels/cpu/graph_send_ue_recv_grad_kernel.cc
index 95fdc6ff0a9cc..c7b1e3e51853b 100644
--- a/paddle/phi/kernels/cpu/graph_send_ue_recv_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/graph_send_ue_recv_grad_kernel.cc
@@ -24,7 +24,7 @@
 #include "paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h"
 #include "paddle/phi/kernels/empty_kernel.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
-#include "paddle/phi/kernels/impl/graph_messaage_passing_impl.h"
+#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
 #include "paddle/phi/kernels/reduce_sum_kernel.h"
 
 namespace phi {
diff --git a/paddle/phi/kernels/cpu/graph_send_ue_recv_kernel.cc b/paddle/phi/kernels/cpu/graph_send_ue_recv_kernel.cc
index 74fca002294db..ab9adc3897170 100644
--- a/paddle/phi/kernels/cpu/graph_send_ue_recv_kernel.cc
+++ b/paddle/phi/kernels/cpu/graph_send_ue_recv_kernel.cc
@@ -22,7 +22,7 @@
 #include "paddle/phi/core/hostdevice.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h"
-#include "paddle/phi/kernels/impl/graph_messaage_passing_impl.h"
+#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
 
 namespace phi {
 
diff --git a/paddle/phi/kernels/cpu/graph_send_uv_grad_kernel.cc b/paddle/phi/kernels/cpu/graph_send_uv_grad_kernel.cc
new file mode 100644
index 0000000000000..4e28acdad3db4
--- /dev/null
+++ b/paddle/phi/kernels/cpu/graph_send_uv_grad_kernel.cc
@@ -0,0 +1,260 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/phi/kernels/graph_send_uv_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/impl/graph_message_passing_impl.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" + +namespace phi { + +template +void CalculateGrad(const Context& ctx, + const T* out_grad, + const IndexT* s_index, + const IndexT* d_index, + const phi::DDim& out_grad_dims, + const phi::DDim& x_grad_dims, + const std::string& message_op, + int64_t index_size, + int64_t slice_size, + T* x_grad, + const DenseTensor& out_grad_tensor, + const DenseTensor& y) { + std::vector reduce_idx; + bool reduce = ReduceGrad(out_grad_dims, x_grad_dims, reduce_idx); + + if (message_op == "ADD") { + if (!reduce) { +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int64_t i = 0; i < index_size; i++) { + IndexT dst = d_index[i]; + T* x_grad_off = x_grad + dst * slice_size; + const T* out_grad_off = out_grad + i * slice_size; + for (int64_t j = 0; j < slice_size; j++) { + if (out_grad_off[j] != 0) { +#ifdef PADDLE_WITH_MKLML +#pragma omp atomic +#endif + x_grad_off[j] += out_grad_off[j]; + } + } + } + } else { + const auto& bcast_info = phi::CalcBCastInfo(out_grad_dims, x_grad_dims); + auto out_grad_dims_1 = phi::vectorize(out_grad_dims); + std::vector out_grad_dims_2(out_grad_dims_1.begin() + 1, + out_grad_dims_1.end()); + out_grad_dims_2.emplace(out_grad_dims_2.begin(), x_grad_dims[0]); + DenseTensor x_grad_v2 = phi::Empty(ctx, out_grad_dims_2); + phi::funcs::SetConstant()(ctx, &x_grad_v2, T(0)); + T* x_grad_v2_data = x_grad_v2.data(); +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int64_t i = 0; i < index_size; i++) { + IndexT dst = d_index[i]; + T* x_grad_off = x_grad_v2_data + dst * bcast_info.out_len; + const T* out_grad_off = out_grad + i * bcast_info.out_len; + for (int64_t j = 0; j < bcast_info.out_len; j++) { + if (out_grad_off[j] != 0) { +#ifdef PADDLE_WITH_MKLML +#pragma omp atomic +#endif + x_grad_off[j] += out_grad_off[j]; + } + } + } + DenseTensor x_grad_out = phi::Sum( + ctx, + x_grad_v2, + reduce_idx, + paddle::experimental::CppTypeToDataType::Type(), + true); + memcpy(x_grad, x_grad_out.data(), x_grad_out.numel() * sizeof(T)); + } + } else if (message_op == "MUL") { + const auto& bcast = phi::CalcBCastInfo(y.dims(), out_grad_dims); + const T* y_data = y.data(); + if (!reduce) { +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int64_t i = 0; i < index_size; i++) { + IndexT src = s_index[i]; + IndexT dst = d_index[i]; + T* x_grad_off = x_grad + dst * bcast.out_len; + const T* y_off = y_data + src * bcast.l_len; + const T* out_grad_off = out_grad + i * bcast.r_len; + for (int64_t j = 0; j < bcast.out_len; j++) { + int64_t y_add = bcast.use_bcast ? bcast.l_offset[j] : j; + int64_t o_add = bcast.use_bcast ? 
bcast.r_offset[j] : j; + T val = y_off[y_add] * out_grad_off[o_add]; + if (val != 0) { +#ifdef PADDLE_WITH_MKLML +#pragma omp atomic +#endif + x_grad_off[j] += val; + } + } + } + } else { + auto out_grad_dims_1 = phi::vectorize(out_grad_dims); + std::vector out_grad_dims_2(out_grad_dims_1.begin() + 1, + out_grad_dims_1.end()); + out_grad_dims_2.emplace(out_grad_dims_2.begin(), x_grad_dims[0]); + DenseTensor x_grad_v2 = phi::Empty(ctx, out_grad_dims_2); + phi::funcs::SetConstant()(ctx, &x_grad_v2, T(0)); + T* x_grad_v2_data = x_grad_v2.data(); +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int64_t i = 0; i < index_size; i++) { + IndexT src = s_index[i]; + IndexT dst = d_index[i]; + T* x_grad_off = x_grad_v2_data + dst * bcast.out_len; + const T* y_off = y_data + src * bcast.l_len; + const T* out_grad_off = out_grad + i * bcast.r_len; + for (int64_t j = 0; j < bcast.out_len; j++) { + int64_t y_add = bcast.use_bcast ? bcast.l_offset[j] : j; + int64_t o_add = bcast.use_bcast ? bcast.r_offset[j] : j; + T val = y_off[y_add] * out_grad_off[o_add]; + if (val != 0) { +#ifdef PADDLE_WITH_MKLML +#pragma omp atomic +#endif + x_grad_off[j] += val; + } + } + } + DenseTensor x_grad_out = phi::Sum( + ctx, + x_grad_v2, + reduce_idx, + paddle::experimental::CppTypeToDataType::Type(), + true); + memcpy(x_grad, x_grad_out.data(), x_grad_out.numel() * sizeof(T)); + } + } +} + +template +void GraphSendUVGradOpKernelLaunchHelper(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + const DenseTensor& src_index, + const DenseTensor& dst_index, + const std::string& message_op, + DenseTensor* x_grad, + DenseTensor* y_grad) { + const int64_t& index_size = dst_index.dims()[0]; + + PADDLE_ENFORCE_GT( + index_size, + 0, + errors::InvalidArgument("The first dimension of src_index or dst_index " + "shoule be greater than 0, but received %d.", + index_size)); + + ctx.template Alloc(x_grad); + T* x_grad_data = x_grad->data(); + ctx.template Alloc(y_grad); + T* y_grad_data = y_grad->data(); + const auto& x_grad_dims = x_grad->dims(); + const auto& y_grad_dims = y_grad->dims(); + int64_t memset_size_x = 1, memset_size_y = 1; + int64_t slice_size_x = 1, slice_size_y = 1; + for (int i = 0; i < x_grad_dims.size(); i++) { + memset_size_x *= x_grad_dims[i]; + if (i > 0) slice_size_x *= x_grad_dims[i]; + } + for (int i = 0; i < y_grad_dims.size(); i++) { + memset_size_y *= y_grad_dims[i]; + if (i > 0) slice_size_y *= y_grad_dims[i]; + } + const size_t& memset_bytes_x = memset_size_x * sizeof(T); + const size_t& memset_bytes_y = memset_size_y * sizeof(T); + memset(x_grad_data, 0, memset_bytes_x); + memset(y_grad_data, 0, memset_bytes_y); + + const T* out_grad_data = out_grad.data(); + const IndexT* s_index = src_index.data(); + const IndexT* d_index = dst_index.data(); + const auto& out_grad_dims = out_grad.dims(); + // Calculate X Grad. + CalculateGrad(ctx, + out_grad_data, + d_index, + s_index, + out_grad_dims, + x_grad_dims, + message_op, + index_size, + slice_size_x, + x_grad_data, + out_grad, + y); + // Calcuate Y Grad. 
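+  // Y's gradient reuses the same routine with the index arrays swapped, so
+  // gradients scatter to the other endpoint; x takes the role of the paired
+  // tensor in the MUL case.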
+  CalculateGrad<Context, T, IndexT>(ctx,
+                                    out_grad_data,
+                                    s_index,
+                                    d_index,
+                                    out_grad_dims,
+                                    y_grad_dims,
+                                    message_op,
+                                    index_size,
+                                    slice_size_y,
+                                    y_grad_data,
+                                    out_grad,
+                                    x);
+}
+
+template <typename T, typename Context>
+void GraphSendUVGradKernel(const Context& ctx,
+                           const DenseTensor& x,
+                           const DenseTensor& y,
+                           const DenseTensor& src_index,
+                           const DenseTensor& dst_index,
+                           const DenseTensor& out_grad,
+                           const std::string& message_op,
+                           DenseTensor* x_grad,
+                           DenseTensor* y_grad) {
+  auto index_type = src_index.dtype();
+  if (index_type == phi::DataType::INT32) {
+    GraphSendUVGradOpKernelLaunchHelper<Context, T, int32_t>(
+        ctx, x, y, out_grad, src_index, dst_index, message_op, x_grad, y_grad);
+  } else if (index_type == phi::DataType::INT64) {
+    GraphSendUVGradOpKernelLaunchHelper<Context, T, int64_t>(
+        ctx, x, y, out_grad, src_index, dst_index, message_op, x_grad, y_grad);
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(graph_send_uv_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::GraphSendUVGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
diff --git a/paddle/phi/kernels/cpu/graph_send_uv_kernel.cc b/paddle/phi/kernels/cpu/graph_send_uv_kernel.cc
new file mode 100644
index 0000000000000..2183eb2a4c593
--- /dev/null
+++ b/paddle/phi/kernels/cpu/graph_send_uv_kernel.cc
@@ -0,0 +1,131 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/graph_send_uv_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/hostdevice.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h"
+#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
+
+namespace phi {
+
+template <typename T, typename IndexT, typename ComputeFunctor>
+void GraphSendUVCpuKernel(const BroadCastInfo& bcast,
+                          const T* x_data,
+                          const T* y_data,
+                          const IndexT* src_indices,
+                          const IndexT* dst_indices,
+                          T* output,
+                          int64_t index_size,
+                          ComputeFunctor cfunctor) {
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+  for (int64_t i = 0; i < index_size; i++) {
+    IndexT src = src_indices[i];
+    IndexT dst = dst_indices[i];
+    T* out_off = output + i * bcast.out_len;
+    const T* x_off = x_data + src * bcast.l_len;
+    const T* y_off = y_data + dst * bcast.r_len;
+    for (int64_t j = 0; j < bcast.out_len; j++) {
+      int64_t x_add = bcast.use_bcast ? bcast.l_offset[j] : j;
+      int64_t y_add = bcast.use_bcast ? bcast.r_offset[j] : j;
+      T val = cfunctor(x_off[x_add], y_off[y_add]);
+      out_off[j] = val;
+    }
+  }
+}
+
+template <typename Context, typename T, typename IndexT>
+void GraphSendUVOpKernelLaunchHelper(const Context& ctx,
+                                     const DenseTensor& x,
+                                     const DenseTensor& y,
+                                     const DenseTensor& src_index,
+                                     const DenseTensor& dst_index,
+                                     const std::string& message_op,
+                                     DenseTensor* out) {
+  const int& index_size = src_index.dims()[0];
+  PADDLE_ENFORCE_GT(
+      index_size,
+      0,
+      errors::InvalidArgument("The first dimension of src_index or dst_index "
+                              "should be greater than 0, but received %d.",
+                              index_size));
+
+  auto out_dims = out->dims();
+  int64_t memset_size = 1;
+  for (int i = 0; i < out_dims.size(); i++) {
+    memset_size *= out_dims[i];
+  }
+  ctx.template Alloc<T>(out);
+  T* out_data = out->data<T>();
+
+  const auto& bcast_info = phi::CalcBCastInfo(x.dims(), y.dims());
+  const T* x_data = x.data<T>();
+  const T* y_data = y.data<T>();
+  const IndexT* s_index = src_index.data<IndexT>();
+  const IndexT* d_index = dst_index.data<IndexT>();
+  if (message_op == "ADD") {
+    GraphAddFunctor<T> add_functor;
+    GraphSendUVCpuKernel<T, IndexT, GraphAddFunctor<T>>(bcast_info,
+                                                        x_data,
+                                                        y_data,
+                                                        s_index,
+                                                        d_index,
+                                                        out_data,
+                                                        index_size,
+                                                        add_functor);
+  } else if (message_op == "MUL") {
+    GraphMulFunctor<T> mul_functor;
+    GraphSendUVCpuKernel<T, IndexT, GraphMulFunctor<T>>(bcast_info,
+                                                        x_data,
+                                                        y_data,
+                                                        s_index,
+                                                        d_index,
+                                                        out_data,
+                                                        index_size,
+                                                        mul_functor);
+  }
+}
+
+template <typename T, typename Context>
+void GraphSendUVKernel(const Context& ctx,
+                       const DenseTensor& x,
+                       const DenseTensor& y,
+                       const DenseTensor& src_index,
+                       const DenseTensor& dst_index,
+                       const std::string& message_op,
+                       DenseTensor* out) {
+  auto index_type = src_index.dtype();
+  if (index_type == phi::DataType::INT32) {
+    GraphSendUVOpKernelLaunchHelper<Context, T, int32_t>(
+        ctx, x, y, src_index, dst_index, message_op, out);
+  } else if (index_type == phi::DataType::INT64) {
+    GraphSendUVOpKernelLaunchHelper<Context, T, int64_t>(
+        ctx, x, y, src_index, dst_index, message_op, out);
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(graph_send_uv,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::GraphSendUVKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
diff --git a/paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h b/paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h
index 49b48b5397538..5ae9393dba0d7 100644
--- a/paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h
+++ b/paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h
@@ -19,7 +19,7 @@
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/hostdevice.h"
-#include "paddle/phi/kernels/impl/graph_messaage_passing_impl.h"
+#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
 
 namespace phi {
 
diff --git a/paddle/phi/kernels/gpu/graph_send_ue_recv_grad_kernel.cu b/paddle/phi/kernels/gpu/graph_send_ue_recv_grad_kernel.cu
index cb3d5591a7be6..c5d5fb7196fb2 100644
--- a/paddle/phi/kernels/gpu/graph_send_ue_recv_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/graph_send_ue_recv_grad_kernel.cu
@@ -21,7 +21,7 @@
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h"
 #include "paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h"
-#include "paddle/phi/kernels/impl/graph_messaage_passing_impl.h"
+#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
 #include "paddle/phi/kernels/reduce_sum_kernel.h"
 
 namespace phi {
diff --git a/paddle/phi/kernels/gpu/graph_send_ue_recv_kernel.cu b/paddle/phi/kernels/gpu/graph_send_ue_recv_kernel.cu
index f339387f0bbfc..7351c562dff9d 100644
--- a/paddle/phi/kernels/gpu/graph_send_ue_recv_kernel.cu
+++ b/paddle/phi/kernels/gpu/graph_send_ue_recv_kernel.cu
@@ -15,7 +15,7 @@
 #include "paddle/phi/kernels/graph_send_ue_recv_kernel.h"
 #include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h"
 #include "paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h"
-#include "paddle/phi/kernels/impl/graph_messaage_passing_impl.h"
+#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
 
 #include <thrust/device_vector.h>
 #include <thrust/fill.h>
diff --git a/paddle/phi/kernels/gpu/graph_send_uv_grad_kernel.cu b/paddle/phi/kernels/gpu/graph_send_uv_grad_kernel.cu
new file mode 100644
index 0000000000000..5b8d7b28dcc29
--- /dev/null
+++ b/paddle/phi/kernels/gpu/graph_send_uv_grad_kernel.cu
@@ -0,0 +1,317 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/graph_send_uv_grad_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/hostdevice.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/empty_kernel.h"
+#include "paddle/phi/kernels/funcs/elementwise_functor.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h"
+#include "paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h"
+#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
+#include "paddle/phi/kernels/reduce_sum_kernel.h"
+
+namespace phi {
+
+template <typename T, typename IndexT>
+__global__ void GraphSendUVGradCUDAKernel(const T* out_grad,
+                                          const IndexT* src_indices,
+                                          const IndexT* dst_indices,
+                                          int64_t index_size,
+                                          int64_t slice_size,
+                                          T* x_grad) {
+  IndexT ty = blockIdx.y * blockDim.y + threadIdx.y;
+  const IndexT stride_y = blockDim.y * gridDim.y;
+  while (ty < index_size) {
+    IndexT src = src_indices[ty];
+    IndexT dst = dst_indices[ty];
+    int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
+    int64_t stride_x = blockDim.x * gridDim.x;
+
+    const T* out_grad_off = out_grad + ty * slice_size;
+    T* x_grad_off = x_grad + dst * slice_size;
+    while (tx < slice_size) {
+      paddle::platform::CudaAtomicAdd(x_grad_off + tx, out_grad_off[tx]);
+      tx += stride_x;
+    }
+    ty += stride_y;
+  }
+}
+
+template <typename Context, typename T, typename IndexT>
+void CalculateGrad(const Context& ctx,
+                   const T* out_grad,
+                   const IndexT* s_index,
+                   const IndexT* d_index,
+                   const phi::DDim& out_grad_dims,
+                   const phi::DDim& x_grad_dims,
+                   const std::string& message_op,
+                   int64_t index_size,
+                   int64_t slice_size,
+                   T* x_grad,
+                   const DenseTensor& out_grad_tensor,
+                   const DenseTensor& y) {
+  std::vector<int64_t> reduce_idx;
+  bool reduce = ReduceGrad(out_grad_dims, x_grad_dims, reduce_idx);
+
+  if (message_op == "ADD") {
+    if (!reduce) {
+      const int ntx = FindNumThreads(slice_size, ctx.GetMaxThreadsPerBlock());
+      const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
+      const int nbx = (slice_size + ntx - 1) / ntx;
+      const int nby = (index_size + nty - 1) / nty;
+      const dim3 grid_tmp(nbx, nby);
+      const dim3 block_tmp(ntx, nty);
+      GraphSendUVGradCUDAKernel<T, IndexT>
+          <<<grid_tmp, block_tmp, 0, ctx.stream()>>>(
+              out_grad, d_index, s_index, index_size, slice_size, x_grad);
+    } else {
+      const auto& bcast_info = phi::CalcBCastInfo(out_grad_dims, x_grad_dims);
+      auto out_grad_dims_1 = phi::vectorize(out_grad_dims);
+      std::vector<int64_t> out_grad_dims_2(out_grad_dims_1.begin() + 1,
+                                           out_grad_dims_1.end());
+      out_grad_dims_2.insert(out_grad_dims_2.begin(), x_grad_dims[0]);
+      DenseTensor x_grad_v2 = phi::Empty<T, Context>(ctx, out_grad_dims_2);
+      phi::funcs::SetConstant<Context, T>()(ctx, &x_grad_v2, T(0));
+      T* x_grad_v2_data = x_grad_v2.data<T>();
+
+      const int ntx =
+          FindNumThreads(bcast_info.out_len, ctx.GetMaxThreadsPerBlock());
+      const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
+      const int nbx = (bcast_info.out_len + ntx - 1) / ntx;
+      const int nby = (index_size + nty - 1) / nty;
+      const dim3 grid_tmp(nbx, nby);
+      const dim3 block_tmp(ntx, nty);
+      GraphSendUVGradCUDAKernel<T, IndexT>
+          <<<grid_tmp, block_tmp, 0, ctx.stream()>>>(out_grad,
+                                                     d_index,
+                                                     s_index,
+                                                     index_size,
+                                                     bcast_info.out_len,
+                                                     x_grad_v2_data);
+
+      // Run reduce sum
+      DenseTensor x_grad_out = phi::Sum<T, Context>(
+          ctx,
+          x_grad_v2,
+          reduce_idx,
+          paddle::experimental::CppTypeToDataType<T>::Type(),
+          true);
+#ifdef PADDLE_WITH_HIP
+      hipMemcpy(x_grad,
+                x_grad_out.data<T>(),
+                x_grad_out.numel() * sizeof(T),
+                hipMemcpyDeviceToDevice);
+#else
+      cudaMemcpy(x_grad,
+                 x_grad_out.data<T>(),
+                 x_grad_out.numel() * sizeof(T),
+                 cudaMemcpyDeviceToDevice);
+#endif
+    }
+  } else if (message_op == "MUL") {
+    const auto& bcast_info = phi::CalcBCastInfo(y.dims(), out_grad_dims);
+    thrust::device_vector<int64_t> l_bcastoff, r_bcastoff;
+    if (bcast_info.use_bcast) {
+      CopyBCastOff(bcast_info, l_bcastoff, r_bcastoff);
+    }
+    int64_t out_len = bcast_info.out_len;
+    const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
+    const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
+    const int nbx = (out_len + ntx - 1) / ntx;
+    const int nby = (index_size + nty - 1) / nty;
+    const dim3 grid_(nbx, nby);
+    const dim3 block_(ntx, nty);
+    funcs::MultiplyFunctor<T> mul_functor;
+    GraphSendUERecvSumCUDAFunctor<T> sum_functor;
+    const T* y_data = y.data<T>();
+    if (!reduce) {
+      GraphSendUERecvCUDAKernel<T,
+                                IndexT,
+                                GraphSendUERecvSumCUDAFunctor<T>,
+                                funcs::MultiplyFunctor<T>>
+          <<<grid_, block_, 0, ctx.stream()>>>(
+              y_data,
+              out_grad,
+              d_index,
+              s_index,
+              thrust::raw_pointer_cast(l_bcastoff.data()),
+              thrust::raw_pointer_cast(r_bcastoff.data()),
+              x_grad,
+              index_size,
+              bcast_info.l_len,
+              bcast_info.r_len,
+              out_len,
+              bcast_info.use_bcast,
+              mul_functor,
+              sum_functor);
+    } else {
+      auto out_grad_dims_1 = phi::vectorize(out_grad_dims);
+      std::vector<int64_t> out_grad_dims_2(out_grad_dims_1.begin() + 1,
+                                           out_grad_dims_1.end());
+      out_grad_dims_2.insert(out_grad_dims_2.begin(), x_grad_dims[0]);
+      DenseTensor x_grad_v2 = phi::Empty<T, Context>(ctx, out_grad_dims_2);
+      phi::funcs::SetConstant<Context, T>()(ctx, &x_grad_v2, T(0));
+      T* x_grad_v2_data = x_grad_v2.data<T>();
+      GraphSendUERecvCUDAKernel<T,
+                                IndexT,
+                                GraphSendUERecvSumCUDAFunctor<T>,
+                                funcs::MultiplyFunctor<T>>
+          <<<grid_, block_, 0, ctx.stream()>>>(
+              y_data,
+              out_grad,
+              d_index,
+              s_index,
+              thrust::raw_pointer_cast(l_bcastoff.data()),
+              thrust::raw_pointer_cast(r_bcastoff.data()),
+              x_grad_v2_data,
+              index_size,
+              bcast_info.l_len,
+              bcast_info.r_len,
+              out_len,
+              bcast_info.use_bcast,
+              mul_functor,
+              sum_functor);
+      // Run reduce_sum
+      DenseTensor x_grad_out = phi::Sum<T, Context>(
+          ctx,
+          x_grad_v2,
+          reduce_idx,
+          paddle::experimental::CppTypeToDataType<T>::Type(),
+          true);
+#ifdef PADDLE_WITH_HIP
+      hipMemcpy(x_grad,
+                x_grad_out.data<T>(),
+                x_grad_out.numel() * sizeof(T),
+                hipMemcpyDeviceToDevice);
+#else
+      cudaMemcpy(x_grad,
+                 x_grad_out.data<T>(),
+                 x_grad_out.numel() * sizeof(T),
+                 cudaMemcpyDeviceToDevice);
+#endif
+    }
+  }
+}
+
+template <typename Context, typename T, typename IndexT>
+void GraphSendUVGradOpCUDAKernelLaunchHelper(const Context& ctx,
+                                             const DenseTensor& x,
+                                             const DenseTensor& y,
+                                             const DenseTensor& out_grad,
+                                             const DenseTensor& src_index,
+                                             const DenseTensor& dst_index,
+                                             const std::string& message_op,
+                                             DenseTensor* x_grad,
+                                             DenseTensor* y_grad) {
+  const int64_t& index_size = dst_index.dims()[0];
+  PADDLE_ENFORCE_GT(
+      index_size,
+      0,
+      errors::InvalidArgument("The first dimension of src_index or dst_index "
+                              "should be greater than 0, but received %d.",
+                              index_size));
+
+  ctx.template Alloc<T>(x_grad);
+  T* x_grad_data = x_grad->data<T>();
+  ctx.template Alloc<T>(y_grad);
+  T* y_grad_data = y_grad->data<T>();
+  const auto& x_grad_dims = x_grad->dims();
+  const auto& y_grad_dims = y_grad->dims();
+  int64_t memset_size_x = 1, memset_size_y = 1;
+  int64_t slice_size_x = 1, slice_size_y = 1;
+  for (int i = 0; i < x_grad_dims.size(); i++) {
+    memset_size_x *= x_grad_dims[i];
+    if (i > 0) slice_size_x *= x_grad_dims[i];
+  }
+  for (int i = 0; i < y_grad_dims.size(); i++) {
+    memset_size_y *= y_grad_dims[i];
+    if (i > 0) slice_size_y *= y_grad_dims[i];
+  }
+  const size_t& memset_bytes_x = memset_size_x * sizeof(T);
+  const size_t& memset_bytes_y = memset_size_y * sizeof(T);
+#ifdef PADDLE_WITH_HIP
+  hipMemset(x_grad_data, 0, memset_bytes_x);
+  hipMemset(y_grad_data, 0, memset_bytes_y);
+#else
+  cudaMemset(x_grad_data, 0, memset_bytes_x);
+  cudaMemset(y_grad_data, 0, memset_bytes_y);
+#endif
+
+  const T* out_grad_data = out_grad.data<T>();
+  const IndexT* s_index = src_index.data<IndexT>();
+  const IndexT* d_index = dst_index.data<IndexT>();
+  // Calculate X grad.
+  const auto& out_grad_dims = out_grad.dims();
+  CalculateGrad<Context, T, IndexT>(ctx,
+                                    out_grad_data,
+                                    s_index,
+                                    d_index,
+                                    out_grad_dims,
+                                    x_grad_dims,
+                                    message_op,
+                                    index_size,
+                                    slice_size_x,
+                                    x_grad_data,
+                                    out_grad,
+                                    y);
+  // Calculate Y grad.
+  CalculateGrad<Context, T, IndexT>(ctx,
+                                    out_grad_data,
+                                    d_index,
+                                    s_index,
+                                    out_grad_dims,
+                                    y_grad_dims,
+                                    message_op,
+                                    index_size,
+                                    slice_size_y,
+                                    y_grad_data,
+                                    out_grad,
+                                    x);
+}
+
+template <typename T, typename Context>
+void GraphSendUVGradKernel(const Context& ctx,
+                           const DenseTensor& x,
+                           const DenseTensor& y,
+                           const DenseTensor& src_index,
+                           const DenseTensor& dst_index,
+                           const DenseTensor& out_grad,
+                           const std::string& message_op,
+                           DenseTensor* x_grad,
+                           DenseTensor* y_grad) {
+  auto index_type = src_index.dtype();
+  if (index_type == phi::DataType::INT32) {
+    GraphSendUVGradOpCUDAKernelLaunchHelper<Context, T, int32_t>(
+        ctx, x, y, out_grad, src_index, dst_index, message_op, x_grad, y_grad);
+  } else if (index_type == phi::DataType::INT64) {
+    GraphSendUVGradOpCUDAKernelLaunchHelper<Context, T, int64_t>(
+        ctx, x, y, out_grad, src_index, dst_index, message_op, x_grad, y_grad);
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(graph_send_uv_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::GraphSendUVGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/gpu/graph_send_uv_kernel.cu b/paddle/phi/kernels/gpu/graph_send_uv_kernel.cu
new file mode 100644
index 0000000000000..f1e4581773f54
--- /dev/null
+++ b/paddle/phi/kernels/gpu/graph_send_uv_kernel.cu
@@ -0,0 +1,172 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/graph_send_uv_kernel.h"
+#include "paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h"
+#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
+
+#include <thrust/device_vector.h>
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/hostdevice.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/elementwise_functor.h"
+
+namespace phi {
+
+template <typename T, typename IndexT, typename ComputeFunctor>
+__global__ void GraphSendUVCUDAKernel(const T* x_data,
+                                      const T* y_data,
+                                      const IndexT* src_indices,
+                                      const IndexT* dst_indices,
+                                      const int64_t* xbcast_off,
+                                      const int64_t* ybcast_off,
+                                      T* output,
+                                      int64_t index_size,
+                                      int64_t x_len,
+                                      int64_t y_len,
+                                      int64_t out_len,
+                                      bool use_bcast,
+                                      ComputeFunctor cfunctor) {
+  IndexT ty = blockIdx.y * blockDim.y + threadIdx.y;
+  const IndexT stride_y = blockDim.y * gridDim.y;
+
+  while (ty < index_size) {
+    IndexT src = src_indices[ty];
+    IndexT dst = dst_indices[ty];
+    int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
+    int64_t stride_x = blockDim.x * gridDim.x;
+
+    const T* x_off = x_data + src * x_len;
+    const T* y_off = y_data + dst * y_len;
+    T* out_off = output + ty * out_len;
+    while (tx < out_len) {
+      int64_t x_add = use_bcast ? xbcast_off[tx] : tx;
+      int64_t y_add = use_bcast ? ybcast_off[tx] : tx;
+      T val = cfunctor(x_off[x_add], y_off[y_add]);
+      out_off[tx] = val;
+      tx += stride_x;
+    }
+    ty += stride_y;
+  }
+}
+
+template <typename Context, typename T, typename IndexT>
+void GraphSendUVOpCUDAKernelLaunchHelper(const Context& ctx,
+                                         const DenseTensor& x,
+                                         const DenseTensor& y,
+                                         const DenseTensor& src_index,
+                                         const DenseTensor& dst_index,
+                                         const std::string& message_op,
+                                         DenseTensor* out) {
+  const int64_t& index_size = src_index.dims()[0];
+  PADDLE_ENFORCE_GT(
+      index_size,
+      0,
+      errors::InvalidArgument("The first dimension of src_index or dst_index "
+                              "should be greater than 0, but received %d.",
+                              index_size));
+
+  auto out_dims = out->dims();
+  int64_t memset_size = 1;
+  for (int i = 0; i < out_dims.size(); i++) {
+    memset_size *= out_dims[i];
+  }
+  ctx.template Alloc<T>(out);
+  T* out_data = out->data<T>();
+
+  const auto& bcast_info = phi::CalcBCastInfo(x.dims(), y.dims());
+  const T* x_data = x.data<T>();
+  const T* y_data = y.data<T>();
+  const IndexT* s_index = src_index.data<IndexT>();
+  const IndexT* d_index = dst_index.data<IndexT>();
+
+  thrust::device_vector<int64_t> x_bcastoff, y_bcastoff;
+  if (bcast_info.use_bcast) {
+    CopyBCastOff(bcast_info, x_bcastoff, y_bcastoff);
+  }
+
+  int64_t out_len = bcast_info.out_len;
+  const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
+  const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
+  const int nbx = (out_len + ntx - 1) / ntx;
+  const int nby = (index_size + nty - 1) / nty;
+  const dim3 grid(nbx, nby);
+  const dim3 block(ntx, nty);
+  if (message_op == "ADD") {
+    funcs::AddFunctor<T> add_functor;
+    GraphSendUVCUDAKernel<T, IndexT, funcs::AddFunctor<T>>
+        <<<grid, block, 0, ctx.stream()>>>(
+            x_data,
+            y_data,
+            s_index,
+            d_index,
+            thrust::raw_pointer_cast(x_bcastoff.data()),
+            thrust::raw_pointer_cast(y_bcastoff.data()),
+            out_data,
+            index_size,
+            bcast_info.l_len,
+            bcast_info.r_len,
+            out_len,
+            bcast_info.use_bcast,
+            add_functor);
+  } else if (message_op == "MUL") {
+    funcs::MultiplyFunctor<T> mul_functor;
+    GraphSendUVCUDAKernel<T, IndexT, funcs::MultiplyFunctor<T>>
+        <<<grid, block, 0, ctx.stream()>>>(
+            x_data,
+            y_data,
+            s_index,
+            d_index,
+            thrust::raw_pointer_cast(x_bcastoff.data()),
+            thrust::raw_pointer_cast(y_bcastoff.data()),
+            out_data,
+            index_size,
+            bcast_info.l_len,
+            bcast_info.r_len,
+            out_len,
+            bcast_info.use_bcast,
+            mul_functor);
+  }
+}
+
+template <typename T, typename Context>
+void GraphSendUVKernel(const Context& ctx,
+                       const DenseTensor& x,
+                       const DenseTensor& y,
+                       const DenseTensor& src_index,
+                       const DenseTensor& dst_index,
+                       const std::string& message_op,
+                       DenseTensor* out) {
+  auto index_type = src_index.dtype();
+  if (index_type == phi::DataType::INT32) {
+    GraphSendUVOpCUDAKernelLaunchHelper<Context, T, int32_t>(
+        ctx, x, y, src_index, dst_index, message_op, out);
+  } else if (index_type == phi::DataType::INT64) {
+    GraphSendUVOpCUDAKernelLaunchHelper<Context, T, int64_t>(
+        ctx, x, y, src_index, dst_index, message_op, out);
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(graph_send_uv,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::GraphSendUVKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/graph_send_uv_grad_kernel.h b/paddle/phi/kernels/graph_send_uv_grad_kernel.h
new file mode 100644
index 0000000000000..fa2285627a4b7
--- /dev/null
+++ b/paddle/phi/kernels/graph_send_uv_grad_kernel.h
@@ -0,0 +1,33 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void GraphSendUVGradKernel(const Context& ctx,
+                           const DenseTensor& x,
+                           const DenseTensor& y,
+                           const DenseTensor& src_index,
+                           const DenseTensor& dst_index,
+                           const DenseTensor& out_grad,
+                           const std::string& message_op,
+                           DenseTensor* x_grad,
+                           DenseTensor* y_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/graph_send_uv_kernel.h b/paddle/phi/kernels/graph_send_uv_kernel.h
new file mode 100644
index 0000000000000..7b723122c1a7f
--- /dev/null
+++ b/paddle/phi/kernels/graph_send_uv_kernel.h
@@ -0,0 +1,31 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void GraphSendUVKernel(const Context& ctx,
+                       const DenseTensor& x,
+                       const DenseTensor& y,
+                       const DenseTensor& src_index,
+                       const DenseTensor& dst_index,
+                       const std::string& message_op,
+                       DenseTensor* out);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/impl/graph_messaage_passing_impl.h b/paddle/phi/kernels/impl/graph_message_passing_impl.h
similarity index 100%
rename from paddle/phi/kernels/impl/graph_messaage_passing_impl.h
rename to paddle/phi/kernels/impl/graph_message_passing_impl.h
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 8754694609dc2..fd90071d0cf74 100755
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -1563,6 +1563,7 @@ set_tests_properties(test_pool3d_api PROPERTIES TIMEOUT 120)
 set_tests_properties(test_cumprod_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_split_program PROPERTIES TIMEOUT 120)
 set_tests_properties(test_graph_send_ue_recv_op PROPERTIES TIMEOUT 60)
+set_tests_properties(test_graph_send_uv_op PROPERTIES TIMEOUT 60)
 
 if(WITH_DISTRIBUTE
    AND WITH_GPU
   AND WITH_NCCL)
diff --git a/python/paddle/fluid/tests/unittests/test_graph_send_uv_op.py b/python/paddle/fluid/tests/unittests/test_graph_send_uv_op.py
new file mode 100644
index 0000000000000..3d0cc3a57c6b3
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_graph_send_uv_op.py
@@ -0,0 +1,265 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+from paddle.fluid.framework import _test_eager_guard
+
+from op_test import OpTest
+
+
+def compute_graph_send_uv(inputs, attributes):
+    x = inputs['x']
+    y = inputs['y']
+    src_index = inputs['src_index']
+    dst_index = inputs['dst_index']
+    message_op = attributes['message_op']
+
+    gather_x = x[src_index]
+    gather_y = y[dst_index]
+
+    # Calculate forward output.
+ if message_op == "ADD": + results = gather_x + gather_y + elif message_op == "MUL": + results = gather_x * gather_y + + return results + + +def graph_send_uv_wrapper(x, y, src_index, dst_index, message_op="add"): + return paddle.geometric.send_uv(x, y, src_index, dst_index, + message_op.lower()) + + +class TestGraphSendUVOp(OpTest): + + def setUp(self): + paddle.enable_static() + self.python_api = graph_send_uv_wrapper + self.python_out_sig = ['out'] + self.op_type = "graph_send_uv" + self.set_config() + self.inputs = { + 'x': self.x, + 'y': self.y, + 'src_index': self.src_index, + 'dst_index': self.dst_index + } + self.attrs = {'message_op': self.message_op} + out = compute_graph_send_uv(self.inputs, self.attrs) + self.outputs = {'out': out} + + def test_check_output(self): + self.check_output(check_eager=True) + + def test_check_grad(self): + self.check_grad(['x', 'y'], 'out', check_eager=True) + + def set_config(self): + self.x = np.random.random((10, 20)).astype("float64") + self.y = np.random.random((10, 20)).astype("float64") + index = np.random.randint(0, 10, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'ADD' + + +class TestCase1(TestGraphSendUVOp): + + def set_config(self): + self.x = np.random.random((10, 20)).astype("float64") + self.y = np.random.random((10, 20)).astype("float64") + index = np.random.randint(0, 10, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'MUL' + + +class TestCase2(TestGraphSendUVOp): + + def set_config(self): + self.x = np.random.random((100, 1)).astype("float64") + self.y = np.random.random((100, 20)).astype("float64") + index = np.random.randint(0, 100, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'ADD' + + +class TestCase3(TestGraphSendUVOp): + + def set_config(self): + self.x = np.random.random((100, 20)).astype("float64") + self.y = np.random.random((100, 1)).astype("float64") + index = np.random.randint(0, 100, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'ADD' + + +class TestCase4(TestGraphSendUVOp): + + def set_config(self): + self.x = np.random.random((100, 1)).astype("float64") + self.y = np.random.random((100, 20)).astype("float64") + index = np.random.randint(0, 100, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'MUL' + + +class TestCase5(TestGraphSendUVOp): + + def set_config(self): + self.x = np.random.random((100, 20)).astype("float64") + self.y = np.random.random((100, 1)).astype("float64") + index = np.random.randint(0, 100, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'MUL' + + +class TestCase6(TestGraphSendUVOp): + + def set_config(self): + self.x = np.random.random((10, 10, 1)).astype("float64") + self.y = np.random.random((10, 10, 10)) + index = np.random.randint(0, 10, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'ADD' + + +class TestCase7(TestGraphSendUVOp): + + def set_config(self): + self.x = np.random.random((10, 10, 1)).astype("float64") + self.y = np.random.random((10, 10, 10)) + index = np.random.randint(0, 10, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'MUL' + + +class API_GeometricSendUVTest(unittest.TestCase): + + def 
test_compute_all_dygraph(self): + paddle.disable_static() + x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32") + y = paddle.to_tensor([[1, 1, 2], [2, 3, 4], [4, 5, 6]], dtype="float32") + src_index = paddle.to_tensor(np.array([0, 1, 2, 0]), dtype="int32") + dst_index = paddle.to_tensor(np.array([1, 2, 1, 0]), dtype="int32") + + res_add = paddle.geometric.send_uv(x, + y, + src_index, + dst_index, + message_op="add") + res_sub = paddle.geometric.send_uv(x, + y, + src_index, + dst_index, + message_op="sub") + res_mul = paddle.geometric.send_uv(x, + y, + src_index, + dst_index, + message_op="mul") + res_div = paddle.geometric.send_uv(x, + y, + src_index, + dst_index, + message_op="div") + res = [res_add, res_sub, res_mul, res_div] + + np_add = np.array([[2, 5, 7], [5, 9, 11], [4, 9, 11], [1, 3, 5]], + dtype="float32") + np_sub = np.array([[-2, -1, -1], [-3, -1, -1], [0, 3, 3], [-1, 1, 1]], + dtype="float32") + np_mul = np.array([[0, 6, 12], [4, 20, 30], [4, 18, 28], [0, 2, 6]], + dtype="float32") + np_div = np.array( + [[0, 2 / 3, 0.75], [0.25, 0.8, 5 / 6], [1, 2, 7 / 4], [0, 2, 1.5]], + dtype="float32") + + for np_res, paddle_res in zip([np_add, np_sub, np_mul, np_div], res): + self.assertTrue( + np.allclose(np_res, paddle_res, atol=1e-6), "two value is\ + {}\n{}, check diff!".format(np_res, paddle_res)) + + def test_compute_all_static(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(name="x", shape=[3, 3], dtype="float32") + y = paddle.static.data(name="y", shape=[3, 3], dtype="float32") + src_index = paddle.static.data(name="src", shape=[4], dtype="int32") + dst_index = paddle.static.data(name="dst", shape=[4], dtype="int32") + res_add = paddle.geometric.send_uv(x, + y, + src_index, + dst_index, + message_op="add") + res_sub = paddle.geometric.send_uv(x, + y, + src_index, + dst_index, + message_op="sub") + res_mul = paddle.geometric.send_uv(x, + y, + src_index, + dst_index, + message_op="mul") + res_div = paddle.geometric.send_uv(x, + y, + src_index, + dst_index, + message_op="div") + + exe = paddle.static.Executor(paddle.CPUPlace()) + data1 = np.array([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32") + data2 = np.array([[1, 1, 2], [2, 3, 4], [4, 5, 6]], dtype="float32") + data3 = np.array([0, 1, 2, 0], dtype="int32") + data4 = np.array([1, 2, 1, 0], dtype="int32") + + np_add = np.array([[2, 5, 7], [5, 9, 11], [4, 9, 11], [1, 3, 5]], + dtype="float32") + np_sub = np.array( + [[-2, -1, -1], [-3, -1, -1], [0, 3, 3], [-1, 1, 1]], + dtype="float32") + np_mul = np.array([[0, 6, 12], [4, 20, 30], [4, 18, 28], [0, 2, 6]], + dtype="float32") + np_div = np.array([[0, 2 / 3, 0.75], [0.25, 0.8, 5 / 6], + [1, 2, 7 / 4], [0, 2, 1.5]], + dtype="float32") + + ret = exe.run(feed={ + 'x': data1, + 'y': data2, + 'src': data3, + 'dst': data4, + }, + fetch_list=[res_add, res_sub, res_mul, res_div]) + for np_res, paddle_res in zip([np_add, np_sub, np_mul, np_div], + ret): + self.assertTrue( + np.allclose(np_res, paddle_res, atol=1e-6), "two value is\ + {}\n{}, check diff!".format(np_res, paddle_res)) + + def test_api_eager_dygraph(self): + with _test_eager_guard(): + self.test_compute_all_dygraph() diff --git a/python/paddle/geometric/__init__.py b/python/paddle/geometric/__init__.py index 65f9335c287c9..2119a257c8ab0 100644 --- a/python/paddle/geometric/__init__.py +++ b/python/paddle/geometric/__init__.py @@ -14,8 +14,10 @@ from .message_passing import send_u_recv # noqa: F401 from .message_passing import send_ue_recv 
# noqa: F401 +from .message_passing import send_uv # noqa: F401 __all__ = [ 'send_u_recv', 'send_ue_recv', + 'send_uv', ] diff --git a/python/paddle/geometric/message_passing/__init__.py b/python/paddle/geometric/message_passing/__init__.py index ea6ad1c99a518..f215e5be74a48 100644 --- a/python/paddle/geometric/message_passing/__init__.py +++ b/python/paddle/geometric/message_passing/__init__.py @@ -14,3 +14,10 @@ from .send_recv import send_u_recv # noqa: F401 from .send_recv import send_ue_recv # noqa: F401 +from .send_recv import send_uv # noqa: F401 + +__all__ = [ + 'send_u_recv', + 'send_ue_recv', + 'send_uv', +] diff --git a/python/paddle/geometric/message_passing/send_recv.py b/python/paddle/geometric/message_passing/send_recv.py index bfe63f1f04d73..de8fd3b005e29 100644 --- a/python/paddle/geometric/message_passing/send_recv.py +++ b/python/paddle/geometric/message_passing/send_recv.py @@ -21,6 +21,8 @@ from .utils import convert_out_size_to_list, get_out_size_tensor_inputs, reshape_lhs_rhs +__all__ = [] + def send_u_recv(x, src_index, @@ -336,3 +338,116 @@ def send_ue_recv(x, }, attrs=attrs) return out + + +def send_uv(x, y, src_index, dst_index, message_op="add", name=None): + """ + + Graph Learning message passing api. + + This api is mainly used in Graph Learning domain, and the main purpose is to reduce intermediate memory + consumption in the process of message passing. Take `x` as the source node feature tensor, take `y` as + the destination node feature tensor. Then we use `src_index` and `dst_index` to gather the corresponding data, + and then compute the edge features in different message_ops like `add`, `sub`, `mul`, `div`. + + .. code-block:: text + + Given: + + x = [[0, 2, 3], + [1, 4, 5], + [2, 6, 7]] + + y = [[0, 1, 2], + [2, 3, 4], + [4, 5, 6]] + + src_index = [0, 1, 2, 0] + + dst_index = [1, 2, 1, 0] + + message_op = "add" + + Then: + + out = [[2, 5, 7], + [5, 9, 11], + [4, 9, 11], + [0, 3, 5]] + + Args: + x (Tensor): The source node feature tensor, and the available data type is float32, float64, int32, int64. And we support float16 in gpu version. + y (Tensor): The destination node feature tensor, and the available data type is float32, float64, int32, int64. And we support float16 in gpu version. + src_index (Tensor): An 1-D tensor, and the available data type is int32, int64. + dst_index (Tensor): An 1-D tensor, and should have the same shape as `src_index`. + The available data type is int32, int64. + message_op (Tensor): Different message ops for x and y, including `add`, `sub`, `mul` and `div`. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + out (Tensor): The output tensor. + + Examples: + + .. 
code-block:: python + + import paddle + + x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32") + y = paddle.to_tensor([[0, 1, 2], [2, 3, 4], [4, 5, 6]], dtype="float32") + indexes = paddle.to_tensor([[0, 1], [1, 2], [2, 1], [0, 0]], dtype="int32") + src_index = indexes[:, 0] + dst_index = indexes[:, 1] + out = paddle.geometric.send_uv(x, y, src_index, dst_index, message_op="add") + # Outputs: [[2., 5., 7.], [5., 9., 11.], [4., 9., 11.], [0., 3., 5.]] + + """ + + if message_op not in ['add', 'sub', 'mul', 'div']: + raise ValueError( + "message_op should be `add`, `sub`, `mul`, `div`, but received %s" % + message_op) + + x, y = reshape_lhs_rhs(x, y) + + if message_op == 'sub': + message_op = 'add' + y = -y + if message_op == 'div': + message_op = 'mul' + y = 1. / y + + if in_dygraph_mode(): + return _C_ops.final_state_graph_send_uv(x, y, src_index, dst_index, + message_op.upper()) + else: + if _in_legacy_dygraph(): + return _C_ops.graph_send_uv(x, y, src_index, dst_index, + "message_op", message_op.upper()) + else: + helper = LayerHelper("send_uv", **locals()) + check_variable_and_dtype( + x, 'x', ['int32', 'int64', 'float32', 'float64', 'float16'], + 'graph_send_uv') + check_variable_and_dtype( + y, 'y', ['int32', 'int64', 'float32', 'float64', 'float16'], + 'graph_send_uv') + check_variable_and_dtype(src_index, 'src_index', ['int32', 'int64'], + 'graph_send_uv') + check_variable_and_dtype(dst_index, 'dst_index', ['int32', 'int64'], + 'graph_send_uv') + out = helper.create_variable_for_type_inference(dtype=x.dtype) + + inputs = { + 'x': x, + 'y': y, + 'src_index': src_index, + 'dst_index': dst_index + } + attrs = {'message_op': message_op.upper()} + helper.append_op(type="graph_send_uv", + inputs=inputs, + attrs=attrs, + outputs={"out": out}) + return out
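
---
As a quick cross-check of the semantics introduced by this patch, the forward
computation of `send_uv` can be reproduced with plain NumPy. This is a minimal
sketch using the values from the `send_uv` docstring example; it is illustrative
only and not part of the patch:

    import numpy as np

    x = np.array([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype=np.float32)
    y = np.array([[0, 1, 2], [2, 3, 4], [4, 5, 6]], dtype=np.float32)
    src_index = np.array([0, 1, 2, 0])
    dst_index = np.array([1, 2, 1, 0])

    # message_op="add": out[i] = x[src_index[i]] + y[dst_index[i]]
    out = x[src_index] + y[dst_index]
    print(out)
    # [[ 2.  5.  7.]
    #  [ 5.  9. 11.]
    #  [ 4.  9. 11.]
    #  [ 0.  3.  5.]]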