diff --git a/paddle/fluid/operators/graph_send_recv_op.cc b/paddle/fluid/operators/graph_send_recv_op.cc index e9ba861c3b88b..b954ecab704b4 100644 --- a/paddle/fluid/operators/graph_send_recv_op.cc +++ b/paddle/fluid/operators/graph_send_recv_op.cc @@ -64,9 +64,9 @@ class GraphSendRecvOpMaker : public framework::OpProtoAndCheckerMaker { .AsDispensable(); AddOutput("Out", "Output tensor of graph_send_recv op."); AddOutput("Dst_count", - "Count tensor of Dst_index, mainly for MEAN pool_type.") + "Count tensor of Dst_index, mainly for MEAN reduce_op.") .AsIntermediate(); - AddAttr("pool_type", + AddAttr("reduce_op", "(string, default 'SUM')" "Define different pool types to receive the result " "tensors of Dst_index.") @@ -81,7 +81,7 @@ class GraphSendRecvOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Graph Learning Send_Recv combine operator. -$Out = Recv(Send(X, Src_index), Dst_index, pool_type)$ +$Out = Recv(Send(X, Src_index), Dst_index, reduce_op)$ This operator is mainly used in Graph Learning domain, and the main purpose is to reduce intermediate memory consumption in the process of message passing. @@ -105,12 +105,12 @@ class GraphSendRecvGradOpMaker : public framework::SingleGradOpMaker { op->SetInput("Dst_index", this->Input("Dst_index")); op->SetInput("X", this->Input("X")); - if (PADDLE_GET_CONST(std::string, this->GetAttr("pool_type")) == "MEAN") { + if (PADDLE_GET_CONST(std::string, this->GetAttr("reduce_op")) == "MEAN") { op->SetInput("Dst_count", this->Output("Dst_count")); } - if (PADDLE_GET_CONST(std::string, this->GetAttr("pool_type")) == "MIN" || - PADDLE_GET_CONST(std::string, this->GetAttr("pool_type")) == "MAX") { + if (PADDLE_GET_CONST(std::string, this->GetAttr("reduce_op")) == "MIN" || + PADDLE_GET_CONST(std::string, this->GetAttr("reduce_op")) == "MAX") { op->SetInput("Out", this->Output("Out")); } diff --git a/paddle/fluid/operators/graph_send_ue_recv_op.cc b/paddle/fluid/operators/graph_send_ue_recv_op.cc new file mode 100644 index 0000000000000..af16609df3ebd --- /dev/null +++ b/paddle/fluid/operators/graph_send_ue_recv_op.cc @@ -0,0 +1,150 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
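The grad op maker above only forwards `Dst_count` when `reduce_op` is MEAN and `Out` when it is MIN or MAX, because those are the only reductions whose backward rule needs them. Below is a hedged, host-side sketch of that rule for intuition only; it assumes scalar per-node features and illustrative names, whereas the real kernels operate on feature slices.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Reference backward for graph_send_recv: route out_grad back to source rows.
void SendRecvGradRef(const std::vector<float>& x,          // forward input
                     const std::vector<float>& out,        // forward output (MIN/MAX only)
                     const std::vector<float>& out_grad,   // dL/dOut
                     const std::vector<int64_t>& src_index,
                     const std::vector<int64_t>& dst_index,
                     const std::vector<int>& dst_count,    // needed for MEAN only
                     const std::string& reduce_op,
                     std::vector<float>* x_grad) {
  x_grad->assign(x.size(), 0.f);
  for (size_t e = 0; e < src_index.size(); ++e) {
    int64_t src = src_index[e], dst = dst_index[e];
    if (reduce_op == "SUM") {
      (*x_grad)[src] += out_grad[dst];
    } else if (reduce_op == "MEAN") {
      (*x_grad)[src] += out_grad[dst] / static_cast<float>(dst_count[dst]);
    } else {  // MIN / MAX: gradient flows only where x attained the reduced value
      (*x_grad)[src] += out_grad[dst] * static_cast<float>(x[src] == out[dst]);
    }
  }
}
```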
+ +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/multiary.h" + +namespace paddle { +namespace operators { + +class GraphSendUERecvOP : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); + } +}; + +class GraphSendUERecvGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + auto in_dims = ctx->GetInputDim("X"); + ctx->SetOutputDim(framework::GradVarName("X"), in_dims); + auto y_dims = ctx->GetInputDim("Y"); + ctx->SetOutputDim(framework::GradVarName("Y"), y_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); + } +}; + +class GraphSendUERecvOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "The input tensor with data type float32, float64, int32, int64."); + AddInput("Y", + "The input edge weight tensor, data type should be same with X"); + AddInput("Src_index", "The source index tensor."); + AddInput("Dst_index", "The destination index tensor."); + AddInput("Out_size", + "(Tensor, optional). The 0th dimension of the output." + "It has a higher priority than Attr(out_size).") + .AsDispensable(); + AddOutput("Out", "Output tensor of graph_send_ue_recv op."); + AddOutput("Dst_count", + "Count tensor of Dst_index, mainly for MEAN reduce_op.") + .AsIntermediate(); + AddAttr("message_op", + "(string, default 'ADD')" + "Define differenct computation types between X and E.") + .SetDefault("ADD") + .InEnum({"ADD", "MUL"}); + AddAttr("reduce_op", + "(string, default 'SUM')" + "Define different pool types to receive the result " + "tensors of Dst_index.") + .SetDefault("SUM") + .InEnum({"SUM", "MEAN", "MIN", "MAX"}); + AddAttr>( + "out_size", + "(vector, default {0})" + "Define the first dimension of Output tensor." + "If set default {0}, then the shape of Out is the same with X.") + .SetDefault({0}); + AddComment(R"DOC( +Graph Learning Send_UE_Recv combine operator. + +$Out = Recv(Compute(Send(X, Src_index), Y, message_op), Dst_index, reduce_op)$ + +This operator is mainly used in Graph Learning domain, and the main purpose is to reduce +intermediate memory consumption in the process of message passing. + +Take `X` as the input tensor, we first use `src_index` to gather corresponding data. +Then the gather data should compute with `Y` in different message_ops, like add, sub, mul, and div, +and get the computation result. Then, use `dst_index` to update the corresponding position of output +tensor in different pooling types, like sum, mean, max, or min. 
+ +)DOC"); + } +}; + +template +class GraphSendUERecvGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("graph_send_ue_recv_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput("Y", this->Input("Y")); + op->SetInput("Src_index", this->Input("Src_index")); + op->SetInput("Dst_index", this->Input("Dst_index")); + + if (PADDLE_GET_CONST(std::string, this->GetAttr("reduce_op")) == "MEAN") { + op->SetInput("Dst_count", this->Output("Dst_count")); + } + + if (PADDLE_GET_CONST(std::string, this->GetAttr("reduce_op")) == "MIN" || + PADDLE_GET_CONST(std::string, this->GetAttr("reduce_op")) == "MAX") { + op->SetInput("Out", this->Output("Out")); + } + + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); + op->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +DECLARE_INFER_SHAPE_FUNCTOR(graph_send_ue_recv, + GraphSendUERecvInferShapeFunctor, + PD_INFER_META(phi::GraphSendUERecvInferMeta)); +REGISTER_OPERATOR(graph_send_ue_recv, + ops::GraphSendUERecvOP, + ops::GraphSendUERecvOpMaker, + ops::GraphSendUERecvGradOpMaker, + ops::GraphSendUERecvGradOpMaker, + GraphSendUERecvInferShapeFunctor); +REGISTER_OPERATOR(graph_send_ue_recv_grad, ops::GraphSendUERecvGradOp); diff --git a/paddle/fluid/platform/device/gpu/gpu_primitives.h b/paddle/fluid/platform/device/gpu/gpu_primitives.h index 2ebaee610f3bd..b99d6de5dbbb4 100644 --- a/paddle/fluid/platform/device/gpu/gpu_primitives.h +++ b/paddle/fluid/platform/device/gpu/gpu_primitives.h @@ -419,6 +419,55 @@ CUDA_ATOMIC_WRAPPER(Max, double) { return __longlong_as_double(old); } +#ifdef PADDLE_CUDA_FP16 +inline static __device__ uint32_t max_to_low_half(uint32_t val, float x) { + float16 low_half; + // The float16 in lower 16bits + low_half.x = static_cast(val & 0xFFFFu); + low_half = static_cast(max(static_cast(low_half), x)); + return (val & 0xFFFF0000u) | low_half.x; +} + +inline static __device__ uint32_t max_to_high_half(uint32_t val, float x) { + float16 high_half; + // The float16 in higher 16bits + high_half.x = static_cast(val >> 16); + high_half = static_cast(max(static_cast(high_half), x)); + return (val & 0xFFFFu) | (static_cast(high_half.x) << 16); +} + +CUDA_ATOMIC_WRAPPER(Max, float16) { + if (*address >= val) { + return *address; + } + uint32_t *address_as_ui = reinterpret_cast( + reinterpret_cast(address) - + (reinterpret_cast(address) & 0x02)); + float val_f = static_cast(val); + uint32_t old = *address_as_ui; + uint32_t assumed; + if (((uintptr_t)address & 0x02) == 0) { + // The float16 value stay at lower 16 bits of the address. + do { + assumed = old; + old = atomicCAS(address_as_ui, assumed, max_to_low_half(assumed, val_f)); + } while (old != assumed); + float16 ret; + ret.x = old & 0xFFFFu; + return ret; + } else { + // The float16 value stay at higher 16 bits of the address. 
+ do { + assumed = old; + old = atomicCAS(address_as_ui, assumed, max_to_high_half(assumed, val_f)); + } while (old != assumed); + float16 ret; + ret.x = old >> 16; + return ret; + } +} +#endif + // For atomicMin USE_CUDA_ATOMIC(Min, int); USE_CUDA_ATOMIC(Min, unsigned int); @@ -503,5 +552,54 @@ CUDA_ATOMIC_WRAPPER(Min, double) { return __longlong_as_double(old); } +#ifdef PADDLE_CUDA_FP16 +inline static __device__ uint32_t min_to_low_half(uint32_t val, float x) { + float16 low_half; + // The float16 in lower 16bits + low_half.x = static_cast(val & 0xFFFFu); + low_half = static_cast(min(static_cast(low_half), x)); + return (val & 0xFFFF0000u) | low_half.x; +} + +inline static __device__ uint32_t min_to_high_half(uint32_t val, float x) { + float16 high_half; + // The float16 in higher 16bits + high_half.x = static_cast(val >> 16); + high_half = static_cast(min(static_cast(high_half), x)); + return (val & 0xFFFFu) | (static_cast(high_half.x) << 16); +} + +CUDA_ATOMIC_WRAPPER(Min, float16) { + if (*address <= val) { + return *address; + } + uint32_t *address_as_ui = reinterpret_cast( + reinterpret_cast(address) - + (reinterpret_cast(address) & 0x02)); + float val_f = static_cast(val); + uint32_t old = *address_as_ui; + uint32_t assumed; + if (((uintptr_t)address & 0x02) == 0) { + // The float16 value stay at lower 16 bits of the address. + do { + assumed = old; + old = atomicCAS(address_as_ui, assumed, min_to_low_half(assumed, val_f)); + } while (old != assumed); + float16 ret; + ret.x = old & 0xFFFFu; + return ret; + } else { + // The float16 value stay at higher 16 bits of the address. + do { + assumed = old; + old = atomicCAS(address_as_ui, assumed, min_to_high_half(assumed, val_f)); + } while (old != assumed); + float16 ret; + ret.x = old >> 16; + return ret; + } +} +#endif + } // namespace platform } // namespace paddle diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index 588e5521e6070..ba0f872cb7449 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -226,6 +226,7 @@ std::map> op_ins_map = { "Mean3", "Var3"}}, {"graph_send_recv", {"X", "Src_index", "Dst_index", "Out_size"}}, + {"graph_send_ue_recv", {"X", "Y", "Src_index", "Dst_index", "Out_size"}}, }; // NOTE(zhiqiu): Like op_ins_map. 
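The float16 atomicMax/atomicMin wrappers above work by locating the aligned 32-bit word that contains the half, then retrying an atomicCAS that rewrites only the relevant 16 bits. Here is a hedged host-side analogue of that retry loop using `std::atomic<uint32_t>`, with an integer payload instead of float16 so the sketch stays self-contained; it is meant to illustrate the technique, not the kernel itself.

```cpp
#include <atomic>
#include <cstdint>

// Atomically take the max into a uint16_t that lives inside a 32-bit word.
// 'high' selects the upper or lower half, mirroring the (address & 0x02) check
// in the CUDA wrappers above.
void AtomicMax16InWord(std::atomic<uint32_t>* word, bool high, uint16_t val) {
  uint32_t old_word = word->load();
  for (;;) {
    uint16_t cur = high ? static_cast<uint16_t>(old_word >> 16)
                        : static_cast<uint16_t>(old_word & 0xFFFFu);
    if (cur >= val) return;  // early-return fast path, as in the wrapper
    uint32_t desired = high ? ((old_word & 0x0000FFFFu) | (uint32_t(val) << 16))
                            : ((old_word & 0xFFFF0000u) | val);
    // On failure, compare_exchange refreshes old_word with the current value,
    // exactly like the atomicCAS(assumed, ...) retry loop in the kernel code.
    if (word->compare_exchange_weak(old_word, desired)) return;
  }
}
```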
diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index d58acfd77e203..edd5699f6d1c3 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -1082,7 +1082,7 @@ func : generate_proposals_v2 - api : graph_send_recv - args : (Tensor x, Tensor src_index, Tensor dst_index, str pool_type = "SUM", IntArray out_size = {0}) + args : (Tensor x, Tensor src_index, Tensor dst_index, str reduce_op = "SUM", IntArray out_size = {0}) output : Tensor(out), Tensor(dst_count) infer_meta : func : GraphSendRecvInferMeta @@ -1092,6 +1092,17 @@ intermediate : dst_count backward : graph_send_recv_grad +- api : graph_send_ue_recv + args : (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op, str reduce_op, IntArray out_size) + output : Tensor(out), Tensor(dst_count) + infer_meta : + func : GraphSendUERecvInferMeta + kernel : + func : graph_send_ue_recv + data_type : x + intermediate : dst_count + backward : graph_send_ue_recv_grad + - api : greater_equal args : (Tensor x, Tensor y, int axis = -1) output : Tensor diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index fdf2321ea38e1..39360232e0e0a 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -941,8 +941,8 @@ func : gelu_grad - backward_api : graph_send_recv_grad - forward : graph_send_recv (Tensor x, Tensor src_index, Tensor dst_index, str pool_type = "SUM", IntArray out_size = {0}) -> Tensor(out), Tensor(dst_count) - args : (Tensor x, Tensor src_index, Tensor dst_index, Tensor out, Tensor dst_count, Tensor out_grad, str pool_type = "SUM") + forward : graph_send_recv (Tensor x, Tensor src_index, Tensor dst_index, str reduce_op = "SUM", IntArray out_size = {0}) -> Tensor(out), Tensor(dst_count) + args : (Tensor x, Tensor src_index, Tensor dst_index, Tensor out, Tensor dst_count, Tensor out_grad, str reduce_op = "SUM") output : Tensor(x_grad) infer_meta : func : GeneralUnaryGradInferMeta @@ -952,6 +952,18 @@ data_type : out_grad optional: out, dst_count +- backward_api : graph_send_ue_recv_grad + forward : graph_send_ue_recv (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op, str reduce_op, IntArray out_size) -> Tensor(out), Tensor(dst_count) + args : (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, Tensor out, Tensor dst_count, Tensor out_grad, str message_op, str reduce_op) + output : Tensor(x_grad), Tensor(y_grad) + infer_meta : + func : GeneralBinaryGradInferMeta + param : [x, y] + kernel : + func : graph_send_ue_recv_grad + data_type : out_grad + optional: out, dst_count + # grid sample - backward_api : grid_sample_grad forward : grid_sample (Tensor x, Tensor grid, str mode, str padding_mode, bool align_corners) -> Tensor(out) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 6e4f2dce35f96..7ccd52bb6ff39 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -20,6 +20,7 @@ limitations under the License. 
*/ #include "paddle/phi/common/scalar.h" #include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/meta_tensor.h" +#include "paddle/phi/kernels/funcs/common_shape.h" #include "paddle/phi/kernels/funcs/concat_funcs.h" namespace phi { @@ -2598,6 +2599,94 @@ void Yolov3LossInferMeta(const MetaTensor& x, gt_match_mask->set_dtype(x.dtype()); } +void GraphSendUERecvInferMeta(const MetaTensor& x, + const MetaTensor& y, + const MetaTensor& src_index, + const MetaTensor& dst_index, + const std::string& message_op, + const std::string& reduce_op, + const IntArray& out_size, + MetaTensor* out, + MetaTensor* dst_count) { + auto src_index_dims = src_index.dims(); + if (src_index_dims.size() == 2) { + PADDLE_ENFORCE_EQ(src_index_dims[1], + 1, + phi::errors::InvalidArgument( + "The last dim of Src_index should be 1 when it " + "is 2D, but we get %d", + src_index_dims[1])); + } else { + PADDLE_ENFORCE_EQ( + src_index_dims.size(), + 1, + phi::errors::InvalidArgument( + "The Src_index should be 1D, when it is not 2D, but we get %d", + src_index_dims.size())); + } + + auto dst_index_dims = dst_index.dims(); + if (dst_index_dims.size() == 2) { + PADDLE_ENFORCE_EQ(dst_index_dims[1], + 1, + phi::errors::InvalidArgument( + "The last dim of Dst_index should be 1 when it " + "is 2D, but we get %d", + dst_index_dims[1])); + } else { + PADDLE_ENFORCE_EQ( + dst_index_dims.size(), + 1, + phi::errors::InvalidArgument("The Dst_index should be 1D, " + "when it is not 2D, but we get %d", + dst_index_dims.size())); + } + + PADDLE_ENFORCE_EQ(src_index_dims[0], + dst_index_dims[0], + phi::errors::InvalidArgument( + "Src_index and Dst_index should have the same shape.")); + + auto y_dims = y.dims(); + PADDLE_ENFORCE_EQ( + y_dims[0], + src_index_dims[0], + phi::errors::InvalidArgument( + "Expect Input Y to have size %d as Src_index on the first dimension, " + "but we get %d", + src_index_dims[0], + y_dims[0])); + + auto x_dims = x.dims(); + if (reduce_op == "MEAN") { + dst_count->set_dims({-1}); + dst_count->set_dtype(DataType::INT32); + } + + // Infer out's shape according to x and e(need broadcasting condition) + out->set_dtype(x.dtype()); + auto x_dims1 = phi::vectorize(x_dims); + auto y_dims1 = phi::vectorize(y_dims); + std::vector x_dims2(x_dims1.begin() + 1, x_dims1.end()); + std::vector y_dims2(y_dims1.begin() + 1, y_dims1.end()); + + int max_dim = std::max(x_dims2.size(), y_dims2.size()); + int axis = std::abs(static_cast(x_dims2.size() - y_dims2.size())); + std::vector x_dims_array(max_dim); + std::vector y_dims_array(max_dim); + std::vector out_dims_array(max_dim); + // Only need to broadcast dimensions other than the 0th dimension. + phi::funcs::GetBroadcastDimsArrays(phi::make_ddim(x_dims2), + phi::make_ddim(y_dims2), + x_dims_array.data(), + y_dims_array.data(), + out_dims_array.data(), + max_dim, + axis); + out_dims_array.insert(out_dims_array.begin(), -1); + out->set_dims(phi::make_ddim(out_dims_array)); +} + } // namespace phi PD_REGISTER_INFER_META_FN(batch_norm, phi::BatchNormInferMeta); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 472d665050bde..660121b844d10 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -14,6 +14,7 @@ limitations under the License. 
*/ #pragma once +#include "paddle/phi/common/int_array.h" #include "paddle/phi/common/scalar.h" #include "paddle/phi/core/meta_tensor.h" namespace phi { @@ -465,4 +466,14 @@ void Yolov3LossInferMeta(const MetaTensor& x, MetaTensor* objectness_mask, MetaTensor* gt_match_mask); +void GraphSendUERecvInferMeta(const MetaTensor& x, + const MetaTensor& y, + const MetaTensor& src_index, + const MetaTensor& dst_index, + const std::string& message_op, + const std::string& reduce_op, + const IntArray& out_size, + MetaTensor* out, + MetaTensor* dst_count); + } // namespace phi diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index a919a955a541a..342c9e4602309 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -411,7 +411,7 @@ void InstanceNormInferMeta(const MetaTensor& x, void GraphSendRecvInferMeta(const MetaTensor& x, const MetaTensor& src_index, const MetaTensor& dst_index, - const std::string& pool_type, + const std::string& reduce_op, const IntArray& out_size, MetaTensor* out, MetaTensor* dst_count) { @@ -460,7 +460,7 @@ void GraphSendRecvInferMeta(const MetaTensor& x, out->set_dims(phi::make_ddim(dims_)); out->set_dtype(x.dtype()); - if (pool_type == "MEAN") { + if (reduce_op == "MEAN") { dst_count->set_dims({-1}); dst_count->set_dtype(DataType::INT32); } diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index 466bd3df5de2d..5314b8f45affe 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -75,7 +75,7 @@ void InstanceNormInferMeta(const MetaTensor& x, void GraphSendRecvInferMeta(const MetaTensor& x, const MetaTensor& src_index, const MetaTensor& dst_index, - const std::string& pool_type, + const std::string& reduce_op, const IntArray& out_size, MetaTensor* out, MetaTensor* dst_count); diff --git a/paddle/phi/kernels/cpu/graph_send_recv_grad_kernel.cc b/paddle/phi/kernels/cpu/graph_send_recv_grad_kernel.cc index ad04bd258e141..d4131a1ffb5e3 100644 --- a/paddle/phi/kernels/cpu/graph_send_recv_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/graph_send_recv_grad_kernel.cc @@ -29,10 +29,10 @@ void GraphSendRecvCpuGradLoop(const int& index_size, const DenseTensor& src, const DenseTensor& input, DenseTensor* dst, - const std::string& pool_type, + const std::string& reduce_op, const int* dst_count = nullptr, const DenseTensor* output = nullptr) { - if (pool_type == "SUM") { + if (reduce_op == "SUM") { Functor functor; for (int i = 0; i < index_size; ++i) { const IndexT& src_idx = s_index[i]; @@ -40,7 +40,7 @@ void GraphSendRecvCpuGradLoop(const int& index_size, ElementwiseInnerOperation( src, dst, src_idx, dst_idx, false, functor); } - } else if (pool_type == "MEAN") { + } else if (reduce_op == "MEAN") { for (int i = 0; i < index_size; ++i) { const IndexT& src_idx = s_index[i]; const IndexT& dst_idx = d_index[i]; @@ -50,7 +50,7 @@ void GraphSendRecvCpuGradLoop(const int& index_size, auto eigen_dst = phi::EigenVector::Flatten(dst_slice); eigen_dst += (eigen_src / static_cast(dst_count[src_idx])); } - } else if (pool_type == "MIN" || pool_type == "MAX") { + } else if (reduce_op == "MIN" || reduce_op == "MAX") { for (int i = 0; i < index_size; ++i) { const IndexT& forward_src_idx = d_index[i]; const IndexT& forward_dst_idx = s_index[i]; @@ -75,7 +75,7 @@ void GraphSendRecvGradOpKernelLaunchHelper( const DenseTensor& x, const DenseTensor& src_index, const DenseTensor& dst_index, - const std::string& pool_type, + const std::string& reduce_op, DenseTensor* x_grad, const DenseTensor* 
dst_count = nullptr, const DenseTensor* out = nullptr) { @@ -94,15 +94,15 @@ void GraphSendRecvGradOpKernelLaunchHelper( const IndexT* s_index = src_index.data(); const IndexT* d_index = dst_index.data(); - if (pool_type == "SUM") { + if (reduce_op == "SUM") { GraphSendRecvCpuGradLoop>( - index_size, d_index, s_index, out_grad, x, x_grad, pool_type); - } else if (pool_type == "MEAN") { + index_size, d_index, s_index, out_grad, x, x_grad, reduce_op); + } else if (reduce_op == "MEAN") { const int* s_count = dst_count->data(); // Functor not used here. GraphSendRecvCpuGradLoop>( - index_size, d_index, s_index, out_grad, x, x_grad, pool_type, s_count); - } else if (pool_type == "MIN" || pool_type == "MAX") { + index_size, d_index, s_index, out_grad, x, x_grad, reduce_op, s_count); + } else if (reduce_op == "MIN" || reduce_op == "MAX") { // Functor not used here. GraphSendRecvCpuGradLoop>(index_size, d_index, @@ -110,7 +110,7 @@ void GraphSendRecvGradOpKernelLaunchHelper( out_grad, x, x_grad, - pool_type, + reduce_op, nullptr, out); } @@ -124,7 +124,7 @@ void GraphSendRecvGradKernel(const Context& ctx, const paddle::optional& out, const paddle::optional& dst_count, const DenseTensor& out_grad, - const std::string& pool_type, + const std::string& reduce_op, DenseTensor* x_grad) { auto index_type = src_index.dtype(); if (index_type == phi::DataType::INT32) { @@ -134,7 +134,7 @@ void GraphSendRecvGradKernel(const Context& ctx, x, src_index, dst_index, - pool_type, + reduce_op, x_grad, dst_count.get_ptr(), out.get_ptr()); @@ -145,7 +145,7 @@ void GraphSendRecvGradKernel(const Context& ctx, x, src_index, dst_index, - pool_type, + reduce_op, x_grad, dst_count.get_ptr(), out.get_ptr()); diff --git a/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc b/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc index d4b9c8c60e3f8..7985a65a20053 100644 --- a/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc +++ b/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc @@ -32,17 +32,17 @@ void GraphSendRecvCpuLoop(const int& input_size, const IndexT* d_index, const DenseTensor& src, DenseTensor* dst, - const std::string& pool_type, + const std::string& reduce_op, int* dst_count = nullptr) { Functor functor; - if (pool_type == "SUM") { + if (reduce_op == "SUM") { for (int i = 0; i < index_size; ++i) { const IndexT& src_idx = s_index[i]; const IndexT& dst_idx = d_index[i]; ElementwiseInnerOperation( src, dst, src_idx, dst_idx, false, functor); } - } else if (pool_type == "MEAN") { + } else if (reduce_op == "MEAN") { for (int i = 0; i < index_size; ++i) { const IndexT& src_idx = s_index[i]; const IndexT& dst_idx = d_index[i]; @@ -59,7 +59,7 @@ void GraphSendRecvCpuLoop(const int& input_size, auto eigen_dst = phi::EigenVector::Flatten(dst_slice); eigen_dst = eigen_dst / static_cast(*(dst_count + i)); } - } else if (pool_type == "MIN" || pool_type == "MAX") { + } else if (reduce_op == "MIN" || reduce_op == "MAX") { std::set existed_dst; for (int i = 0; i < index_size; ++i) { const IndexT& src_idx = s_index[i]; @@ -82,7 +82,7 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx, const DenseTensor& x, const DenseTensor& src_index, const DenseTensor& dst_index, - const std::string& pool_type, + const std::string& reduce_op, int64_t out_size, DenseTensor* out, DenseTensor* dst_count = nullptr) { @@ -117,16 +117,16 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx, const IndexT* s_index = src_index.data(); const IndexT* d_index = dst_index.data(); - if (pool_type == "SUM") { + if (reduce_op == "SUM") { 
GraphSendRecvCpuLoop>( - src_dims[0], index_size, s_index, d_index, x, out, pool_type); - } else if (pool_type == "MIN") { + src_dims[0], index_size, s_index, d_index, x, out, reduce_op); + } else if (reduce_op == "MIN") { GraphSendRecvCpuLoop>( - src_dims[0], index_size, s_index, d_index, x, out, pool_type); - } else if (pool_type == "MAX") { + src_dims[0], index_size, s_index, d_index, x, out, reduce_op); + } else if (reduce_op == "MAX") { GraphSendRecvCpuLoop>( - src_dims[0], index_size, s_index, d_index, x, out, pool_type); - } else if (pool_type == "MEAN") { + src_dims[0], index_size, s_index, d_index, x, out, reduce_op); + } else if (reduce_op == "MEAN") { int64_t input_size = out_size <= 0 ? src_dims[0] : out_size; dst_count->Resize({input_size}); ctx.template Alloc(dst_count); @@ -138,7 +138,7 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx, d_index, x, out, - pool_type, + reduce_op, p_dst_count); } } @@ -148,7 +148,7 @@ void GraphSendRecvKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& src_index, const DenseTensor& dst_index, - const std::string& pool_type, + const std::string& reduce_op, const IntArray& out_size, DenseTensor* out, DenseTensor* dst_count) { @@ -159,7 +159,7 @@ void GraphSendRecvKernel(const Context& ctx, x, src_index, dst_index, - pool_type, + reduce_op, out_size_data[0], out, dst_count); @@ -168,7 +168,7 @@ void GraphSendRecvKernel(const Context& ctx, x, src_index, dst_index, - pool_type, + reduce_op, out_size_data[0], out, dst_count); diff --git a/paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h b/paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h new file mode 100644 index 0000000000000..7647415d8e7cb --- /dev/null +++ b/paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h @@ -0,0 +1,46 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +namespace phi { + +template +struct GraphAddFunctor { + inline T operator()(const T a, const T b) const { return a + b; } +}; + +template +struct GraphMulFunctor { + inline T operator()(const T a, const T b) const { return a * b; } +}; + +template +struct GraphMaxFunctor { + inline T operator()(const T a, const T b) const { return a < b ? b : a; } +}; + +template +struct GraphMinFunctor { + inline T operator()(const T a, const T b) const { return a < b ? a : b; } +}; + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/graph_send_ue_recv_grad_kernel.cc b/paddle/phi/kernels/cpu/graph_send_ue_recv_grad_kernel.cc new file mode 100644 index 0000000000000..95fdc6ff0a9cc --- /dev/null +++ b/paddle/phi/kernels/cpu/graph_send_ue_recv_grad_kernel.cc @@ -0,0 +1,499 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/graph_send_ue_recv_grad_kernel.h" + +#include +#include + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cpu/graph_send_recv_funcs.h" +#include "paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/impl/graph_messaage_passing_impl.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" + +namespace phi { + +template +void CalculateXGrad(const Context& ctx, + const T* out_grad, + const T* x_data, + const T* e_data, + const phi::DDim& out_grad_dims, + const phi::DDim& x_dims, + const phi::DDim& e_dims, + const IndexT* s_index, + const IndexT* d_index, + const std::string& message_op, + const std::string& reduce_op, + int64_t index_size, + T* x_grad, + const DenseTensor& out_grad_tensor, + DenseTensor* x_grad_tensor, + const DenseTensor* dst_count = nullptr, + const DenseTensor* out = nullptr) { + std::vector reduce_idx; + bool reduce = ReduceGrad(out_grad_dims, x_dims, reduce_idx); + + if (reduce_op == "SUM") { + if (message_op == "ADD") { + GraphSendRecvSumFunctor sum_functor; + if (!reduce) { + for (int64_t i = 0; i < index_size; i++) { + IndexT src = s_index[i]; + IndexT dst = d_index[i]; + ElementwiseInnerOperation>( + out_grad_tensor, x_grad_tensor, src, dst, false, sum_functor); + } + } else { + DenseTensor x_grad_v2 = + phi::EmptyLike(ctx, out_grad_tensor); + phi::funcs::SetConstant()(ctx, &x_grad_v2, T(0)); + for (int64_t i = 0; i < index_size; i++) { + IndexT src = s_index[i]; + IndexT dst = d_index[i]; + ElementwiseInnerOperation>( + out_grad_tensor, &x_grad_v2, src, dst, false, sum_functor); + } + DenseTensor x_grad_out = phi::Sum( + ctx, + x_grad_v2, + reduce_idx, + paddle::experimental::CppTypeToDataType::Type(), + true); + memcpy(x_grad, x_grad_out.data(), x_grad_out.numel() * sizeof(T)); + } + } else if (message_op == "MUL") { + const auto& bcast = phi::CalcBCastInfo(out_grad_dims, e_dims); + if (!reduce) { +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int64_t i = 0; i < index_size; i++) { + IndexT src = s_index[i]; + IndexT dst = d_index[i]; + T* x_grad_off = x_grad + dst * bcast.out_len; + const T* out_grad_off = out_grad + src * bcast.l_len; + const T* e_off = e_data + i * bcast.r_len; + for (int j = 0; j < bcast.out_len; j++) { + int64_t o_add = bcast.use_bcast ? bcast.l_offset[j] : j; + int64_t e_add = bcast.use_bcast ? 
bcast.r_offset[j] : j; + T val = out_grad_off[o_add] * e_off[e_add]; + if (val != 0) { +#ifdef PADDLE_WITH_MKLML +#pragma omp atomic +#endif + x_grad_off[j] += val; + } + } + } + } else { + DenseTensor x_grad_v2 = + phi::EmptyLike(ctx, out_grad_tensor); + phi::funcs::SetConstant()(ctx, &x_grad_v2, T(0)); + T* x_grad_v2_data = x_grad_v2.data(); +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int64_t i = 0; i < index_size; i++) { + IndexT src = s_index[i]; + IndexT dst = d_index[i]; + T* x_grad_off = x_grad_v2_data + dst * bcast.out_len; + const T* out_grad_off = out_grad + src * bcast.l_len; + const T* e_off = e_data + i * bcast.r_len; + for (int j = 0; j < bcast.out_len; j++) { + int64_t o_add = bcast.use_bcast ? bcast.l_offset[j] : j; + int64_t e_add = bcast.use_bcast ? bcast.r_offset[j] : j; + T val = out_grad_off[o_add] * e_off[e_add]; + if (val != 0) { +#ifdef PADDLE_WITH_MKLML +#pragma omp atomic +#endif + x_grad_off[j] += val; + } + } + } + DenseTensor x_grad_out = phi::Sum( + ctx, + x_grad_v2, + reduce_idx, + paddle::experimental::CppTypeToDataType::Type(), + true); + memcpy(x_grad, x_grad_out.data(), x_grad_out.numel() * sizeof(T)); + } + } + } else if (reduce_op == "MEAN") { + const int* s_count = dst_count->data(); + if (message_op == "ADD") { + if (!reduce) { + for (int64_t i = 0; i < index_size; i++) { + IndexT src = s_index[i]; + IndexT dst = d_index[i]; + auto out_grad_slice = out_grad_tensor.Slice(src, src + 1); + auto x_grad_slice = x_grad_tensor->Slice(dst, dst + 1); + auto eigen_out_grad = phi::EigenVector::Flatten(out_grad_slice); + auto eigen_x_grad = phi::EigenVector::Flatten(x_grad_slice); + eigen_x_grad += (eigen_out_grad / static_cast(s_count[src])); + } + } else { + DenseTensor x_grad_v2 = + phi::EmptyLike(ctx, out_grad_tensor); + phi::funcs::SetConstant()(ctx, &x_grad_v2, T(0)); + for (int64_t i = 0; i < index_size; i++) { + IndexT src = s_index[i]; + IndexT dst = d_index[i]; + auto out_grad_slice = out_grad_tensor.Slice(src, src + 1); + auto x_grad_slice = x_grad_v2.Slice(dst, dst + 1); + auto eigen_out_grad = phi::EigenVector::Flatten(out_grad_slice); + auto eigen_x_grad = phi::EigenVector::Flatten(x_grad_slice); + eigen_x_grad += (eigen_out_grad / static_cast(s_count[src])); + } + DenseTensor x_grad_out = phi::Sum( + ctx, + x_grad_v2, + reduce_idx, + paddle::experimental::CppTypeToDataType::Type(), + true); + memcpy(x_grad, x_grad_out.data(), x_grad_out.numel() * sizeof(T)); + } + } else if (message_op == "MUL") { + const auto& bcast = phi::CalcBCastInfo(out_grad_dims, e_dims); + if (!reduce) { +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int64_t i = 0; i < index_size; i++) { + IndexT src = s_index[i]; + IndexT dst = d_index[i]; + const T* out_grad_off = out_grad + src * bcast.l_len; + const T* e_off = e_data + i * bcast.r_len; + T* x_grad_off = x_grad + dst * bcast.out_len; + for (int64_t j = 0; j < bcast.out_len; j++) { + int64_t o_add = bcast.use_bcast ? bcast.l_offset[j] : j; + int64_t e_add = bcast.use_bcast ? 
bcast.r_offset[j] : j; + T val = out_grad_off[o_add] * e_off[e_add]; +#ifdef PADDLE_WITH_MKLML +#pragma omp atomic +#endif + x_grad_off[j] += (val / s_count[src]); + } + } + } else { + DenseTensor x_grad_v2 = + phi::EmptyLike(ctx, out_grad_tensor); + phi::funcs::SetConstant()(ctx, &x_grad_v2, T(0)); + T* x_grad_v2_data = x_grad_v2.data(); +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int64_t i = 0; i < index_size; i++) { + IndexT src = s_index[i]; + IndexT dst = d_index[i]; + const T* out_grad_off = out_grad + src * bcast.l_len; + const T* e_off = e_data + i * bcast.r_len; + T* x_grad_off = x_grad_v2_data + dst * bcast.out_len; + for (int64_t j = 0; j < bcast.out_len; j++) { + int64_t o_add = bcast.use_bcast ? bcast.l_offset[j] : j; + int64_t e_add = bcast.use_bcast ? bcast.r_offset[j] : j; + T val = out_grad_off[o_add] * e_off[e_add]; +#ifdef PADDLE_WITH_MKLML +#pragma omp atomic +#endif + x_grad_off[j] += (val / s_count[src]); + } + } + DenseTensor x_grad_out = phi::Sum( + ctx, + x_grad_v2, + reduce_idx, + paddle::experimental::CppTypeToDataType::Type(), + true); + memcpy(x_grad, x_grad_out.data(), x_grad_out.numel() * sizeof(T)); + } + } + } +} + +template +void CalculateEGrad(const T* out_grad_data, + const T* x_data, + const T* e_data, + const phi::DDim& x_dims, + const phi::DDim& e_dims, + const IndexT* s_index, + const IndexT* d_index, + const std::string& message_op, + const std::string& reduce_op, + int64_t index_size, + T* e_grad, + const DenseTensor* dst_count = nullptr) { + const auto& bcast = phi::CalcBCastInfo(x_dims, e_dims); + if (reduce_op == "SUM") { +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int64_t i = 0; i < index_size; i++) { + IndexT src = s_index[i]; + IndexT dst = d_index[i]; + const T* x_off = x_data + src * bcast.l_len; + const T* out_grad_off = out_grad_data + dst * bcast.out_len; + T* e_grad_off = e_grad + i * bcast.r_len; + for (int64_t j = 0; j < bcast.out_len; j++) { + int64_t x_add = bcast.use_bcast ? bcast.l_offset[j] : j; + int64_t e_add = bcast.use_bcast ? bcast.r_offset[j] : j; + if (message_op == "ADD") { +#ifdef PADDLE_WITH_MKLML +#pragma omp atomic +#endif + e_grad_off[e_add] += out_grad_off[j]; + } else if (message_op == "MUL") { +#ifdef PADDLE_WITH_MKLML +#pragma omp atomic +#endif + e_grad_off[e_add] += (out_grad_off[j] * x_off[x_add]); + } + } + } + } else if (reduce_op == "MEAN") { + const int* s_count = dst_count->data(); +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int64_t i = 0; i < index_size; i++) { + IndexT src = s_index[i]; + IndexT dst = d_index[i]; + const T* x_off = x_data + src * bcast.l_len; + const T* out_grad_off = out_grad_data + dst * bcast.out_len; + T* e_grad_off = e_grad + i * bcast.r_len; + for (int64_t j = 0; j < bcast.out_len; j++) { + int64_t x_add = bcast.use_bcast ? bcast.l_offset[j] : j; + int64_t e_add = bcast.use_bcast ? 
bcast.r_offset[j] : j; + if (message_op == "ADD") { +#ifdef PADDLE_WITH_MKLML +#pragma omp atomic +#endif + e_grad_off[e_add] += (out_grad_off[j] / s_count[dst]); + } else if (message_op == "MUL") { +#ifdef PADDLE_WITH_MKLML +#pragma omp atomic +#endif + e_grad_off[e_add] += (out_grad_off[j] * x_off[x_add] / s_count[dst]); + } + } + } + } +} + +template +void CalculateXEGradForMinMax(const T* out_grad, + const T* x_data, + const T* e_data, + const phi::DDim& x_dims, + const phi::DDim& e_dims, + const IndexT* s_index, + const IndexT* d_index, + const std::string& message_op, + const std::string& reduce_op, + int64_t index_size, + T* x_grad, + T* e_grad, + const DenseTensor* out = nullptr) { + const T* out_data = out->data(); + const auto& bcast = phi::CalcBCastInfo(x_dims, e_dims); +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int64_t i = 0; i < index_size; i++) { + IndexT src = s_index[i]; + IndexT dst = d_index[i]; + const T* x_off = x_data + dst * bcast.l_len; + const T* e_off = e_data + i * bcast.r_len; + const T* out_off = out_data + src * bcast.out_len; + const T* out_grad_off = out_grad + src * bcast.out_len; + T* x_grad_off = x_grad + dst * bcast.l_len; + T* e_grad_off = e_grad + i * bcast.r_len; + for (int64_t j = 0; j < bcast.out_len; j++) { + int64_t x_add = bcast.use_bcast ? bcast.l_offset[j] : j; + int64_t e_add = bcast.use_bcast ? bcast.r_offset[j] : j; + if (message_op == "ADD") { + T val = x_off[x_add] + e_off[e_add]; +#ifdef PADDLE_WITH_MKLML +#pragma omp critical +#endif + x_grad_off[x_add] += (out_grad_off[j] * (val == out_off[j])); + e_grad_off[e_add] += (out_grad_off[j] * (val == out_off[j])); + } else if (message_op == "MUL") { + T val = x_off[x_add] * e_off[e_add]; +#ifdef PADDLE_WITH_MKLML +#pragma omp critical +#endif + x_grad_off[x_add] += + (out_grad_off[j] * (val == out_off[j]) * e_off[e_add]); + e_grad_off[e_add] += + (out_grad_off[j] * (val == out_off[j]) * x_off[x_add]); + } + } + } +} + +template +void GraphSendUERecvGradOpKernelLaunchHelper( + const Context& ctx, + const DenseTensor& out_grad, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& src_index, + const DenseTensor& dst_index, + const std::string& message_op, + const std::string& reduce_op, + DenseTensor* x_grad, + DenseTensor* y_grad, + const DenseTensor* dst_count = nullptr, + const DenseTensor* out = nullptr) { + const int& index_size = dst_index.dims()[0]; + + ctx.template Alloc(x_grad); + T* x_grad_data = x_grad->data(); + ctx.template Alloc(y_grad); + T* y_grad_data = y_grad->data(); + const auto& x_dims = x.dims(); + const auto& y_dims = y.dims(); + int64_t memset_size_x = 1, memset_size_y = 1; + int64_t slice_size = 1; + for (int i = 0; i < x_dims.size(); i++) { + memset_size_x *= x_dims[i]; + if (i > 0) slice_size *= x_dims[i]; + } + for (int i = 0; i < y_dims.size(); i++) { + memset_size_y *= y_dims[i]; + } + const size_t& memset_bytes_x = memset_size_x * sizeof(T); + const size_t& memset_bytes_y = memset_size_y * sizeof(T); + memset(x_grad_data, 0, memset_bytes_x); + memset(y_grad_data, 0, memset_bytes_y); + + if (index_size == 0) return; + + const T* out_grad_data = out_grad.data(); + const T* x_data = x.data(); + const T* y_data = y.data(); + const IndexT* s_index = src_index.data(); + const IndexT* d_index = dst_index.data(); + + if (reduce_op == "SUM" || reduce_op == "MEAN") { + CalculateXGrad(ctx, + out_grad_data, + x_data, + y_data, + out_grad.dims(), + x_dims, + y_dims, + d_index, + s_index, + message_op, + reduce_op, + index_size, + 
x_grad_data, + out_grad, + x_grad, + dst_count, + out); + CalculateEGrad(out_grad_data, + x_data, + y_data, + x_dims, + y_dims, + s_index, + d_index, + message_op, + reduce_op, + index_size, + y_grad_data, + dst_count); + } else if (reduce_op == "MIN" || reduce_op == "MAX") { + CalculateXEGradForMinMax(out_grad_data, + x_data, + y_data, + x_dims, + y_dims, + d_index, + s_index, + message_op, + reduce_op, + index_size, + x_grad_data, + y_grad_data, + out); + } +} + +template +void GraphSendUERecvGradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& src_index, + const DenseTensor& dst_index, + const paddle::optional& out, + const paddle::optional& dst_count, + const DenseTensor& out_grad, + const std::string& message_op, + const std::string& reduce_op, + DenseTensor* x_grad, + DenseTensor* y_grad) { + auto index_type = src_index.dtype(); + if (index_type == phi::DataType::INT32) { + GraphSendUERecvGradOpKernelLaunchHelper( + ctx, + out_grad, + x, + y, + src_index, + dst_index, + message_op, + reduce_op, + x_grad, + y_grad, + dst_count.get_ptr(), + out.get_ptr()); + } else if (index_type == phi::DataType::INT64) { + GraphSendUERecvGradOpKernelLaunchHelper( + ctx, + out_grad, + x, + y, + src_index, + dst_index, + message_op, + reduce_op, + x_grad, + y_grad, + dst_count.get_ptr(), + out.get_ptr()); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(graph_send_ue_recv_grad, + CPU, + ALL_LAYOUT, + phi::GraphSendUERecvGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/graph_send_ue_recv_kernel.cc b/paddle/phi/kernels/cpu/graph_send_ue_recv_kernel.cc new file mode 100644 index 0000000000000..74fca002294db --- /dev/null +++ b/paddle/phi/kernels/cpu/graph_send_ue_recv_kernel.cc @@ -0,0 +1,293 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/graph_send_ue_recv_kernel.h" + +#include +#include +#include + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h" +#include "paddle/phi/kernels/impl/graph_messaage_passing_impl.h" + +namespace phi { + +template +void GraphSendUERecvSumCpuKernel(const BroadCastInfo& bcast, + const T* x_data, + const T* y_data, + const IndexT* src_indices, + const IndexT* dst_indices, + T* output, + int64_t index_size, + ComputeFunctor cfunctor) { +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int64_t i = 0; i < index_size; i++) { + IndexT src = src_indices[i]; + IndexT dst = dst_indices[i]; + T* out_off = output + dst * bcast.out_len; + const T* x_off = x_data + src * bcast.l_len; + const T* y_off = y_data + i * bcast.r_len; + for (int64_t j = 0; j < bcast.out_len; j++) { + int64_t x_add = bcast.use_bcast ? bcast.l_offset[j] : j; + int64_t y_add = bcast.use_bcast ? 
bcast.r_offset[j] : j; + T val = cfunctor(x_off[x_add], y_off[y_add]); + if (val != 0) { +#ifdef PADDLE_WITH_MKLML +#pragma omp atomic +#endif + out_off[j] += val; + } + } + } +} + +template +void GraphSendUERecvMinMaxCpuKernel(const BroadCastInfo& bcast, + const T* x_data, + const T* y_data, + const IndexT* src_indices, + const IndexT* dst_indices, + T* output, + int64_t index_size, + ComputeFunctor cfunctor, + CmpFunctor pfunctor) { + std::set existed_dst; +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int64_t i = 0; i < index_size; i++) { + IndexT src = src_indices[i]; + IndexT dst = dst_indices[i]; + T* out_off = output + dst * bcast.out_len; + const T* x_off = x_data + src * bcast.l_len; + const T* y_off = y_data + i * bcast.r_len; + bool in_set = existed_dst.find(dst) != existed_dst.end(); + for (int64_t j = 0; j < bcast.out_len; j++) { + int64_t x_add = bcast.use_bcast ? bcast.l_offset[j] : j; + int64_t y_add = bcast.use_bcast ? bcast.r_offset[j] : j; + T val = cfunctor(x_off[x_add], y_off[y_add]); +#ifdef PADDLE_WITH_MKLML +#pragma omp critical +#endif + if (!in_set) { + out_off[j] = val; + } else { + out_off[j] = pfunctor(out_off[j], val); + } + } +#ifdef PADDLE_WITH_MKLML +#pragma omp critical +#endif + if (!in_set) { + existed_dst.emplace(dst); + } + } +} + +template +void GraphSendUERecvOpKernelLaunchHelper(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& src_index, + const DenseTensor& dst_index, + const std::string& message_op, + const std::string& reduce_op, + int64_t out_size, + DenseTensor* out, + DenseTensor* dst_count = nullptr) { + const int& index_size = src_index.dims()[0]; + auto out_dims = out->dims(); + int64_t memset_size = 1; + std::vector dims_ = phi::vectorize(out_dims); + if (out_size <= 0) { + dims_[0] = x.dims()[0]; + } else { + dims_[0] = out_size; + } + out->Resize(phi::make_ddim(dims_)); + for (size_t i = 0; i < dims_.size(); i++) { + memset_size *= dims_[i]; + } + + ctx.template Alloc(out); + T* out_data = out->data(); + const size_t& memset_bytes = memset_size * sizeof(T); + memset(out_data, 0, memset_bytes); + + if (index_size == 0) return; + const auto& bcast_info = phi::CalcBCastInfo(x.dims(), y.dims()); + const T* x_data = x.data(); + const T* y_data = y.data(); + const IndexT* s_index = src_index.data(); + const IndexT* d_index = dst_index.data(); + if (reduce_op == "SUM" || reduce_op == "MEAN") { + if (message_op == "ADD") { + GraphAddFunctor add_functor; + GraphSendUERecvSumCpuKernel>(bcast_info, + x_data, + y_data, + s_index, + d_index, + out_data, + index_size, + add_functor); + } else if (message_op == "MUL") { + GraphMulFunctor mul_functor; + GraphSendUERecvSumCpuKernel>(bcast_info, + x_data, + y_data, + s_index, + d_index, + out_data, + index_size, + mul_functor); + } + if (reduce_op == "MEAN") { + int64_t input_size = out_size <= 0 ? 
x.dims()[0] : out_size; + dst_count->Resize({input_size}); + int* dst_count_data = ctx.template Alloc(dst_count); + memset(dst_count_data, 0, input_size * sizeof(int)); + for (int i = 0; i < index_size; i++) { + IndexT dst_idx = d_index[i]; + dst_count_data[dst_idx] += 1; + } + for (int i = 0; i < input_size; i++) { + if (dst_count_data[i] == 0) continue; + auto out_slice = out->Slice(i, i + 1); + auto eigen_out = phi::EigenVector::Flatten(out_slice); + eigen_out = eigen_out / static_cast(dst_count_data[i]); + } + } + } else if (reduce_op == "MIN") { + GraphMinFunctor min_functor; + if (message_op == "ADD") { + GraphAddFunctor add_functor; + GraphSendUERecvMinMaxCpuKernel, + GraphMinFunctor>(bcast_info, + x_data, + y_data, + s_index, + d_index, + out_data, + index_size, + add_functor, + min_functor); + } else if (message_op == "MUL") { + GraphMulFunctor mul_functor; + GraphSendUERecvMinMaxCpuKernel, + GraphMinFunctor>(bcast_info, + x_data, + y_data, + s_index, + d_index, + out_data, + index_size, + mul_functor, + min_functor); + } + } else if (reduce_op == "MAX") { + GraphMaxFunctor max_functor; + if (message_op == "ADD") { + GraphAddFunctor add_functor; + GraphSendUERecvMinMaxCpuKernel, + GraphMaxFunctor>(bcast_info, + x_data, + y_data, + s_index, + d_index, + out_data, + index_size, + add_functor, + max_functor); + } else if (message_op == "MUL") { + GraphMulFunctor mul_functor; + GraphSendUERecvMinMaxCpuKernel, + GraphMaxFunctor>(bcast_info, + x_data, + y_data, + s_index, + d_index, + out_data, + index_size, + mul_functor, + max_functor); + } + } +} + +template +void GraphSendUERecvKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& src_index, + const DenseTensor& dst_index, + const std::string& message_op, + const std::string& reduce_op, + const IntArray& out_size, + DenseTensor* out, + DenseTensor* dst_count) { + auto index_type = src_index.dtype(); + auto& out_size_data = out_size.GetData(); + if (index_type == phi::DataType::INT32) { + GraphSendUERecvOpKernelLaunchHelper(ctx, + x, + y, + src_index, + dst_index, + message_op, + reduce_op, + out_size_data[0], + out, + dst_count); + } else if (index_type == phi::DataType::INT64) { + GraphSendUERecvOpKernelLaunchHelper(ctx, + x, + y, + src_index, + dst_index, + message_op, + reduce_op, + out_size_data[0], + out, + dst_count); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(graph_send_ue_recv, + CPU, + ALL_LAYOUT, + phi::GraphSendUERecvKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/graph_send_recv_funcs.h b/paddle/phi/kernels/gpu/graph_send_recv_funcs.h index 4be92ae18629c..e352c50bdc283 100644 --- a/paddle/phi/kernels/gpu/graph_send_recv_funcs.h +++ b/paddle/phi/kernels/gpu/graph_send_recv_funcs.h @@ -119,7 +119,7 @@ __global__ void ManipulateMeanCUDAKernel(T* output, CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) { int64_t c_index = i / slice_size; if (*(count + c_index) > 1) { - *(output + i) = *(output + i) / *(count + c_index); + *(output + i) = *(output + i) / static_cast(*(count + c_index)); } } } @@ -140,8 +140,8 @@ __global__ void ManipulateMeanGradCUDAKernel(const T* params, IndexT dst_i = dst_indices[indices_i]; int64_t in_i = src_i * slice_size + slice_i; int64_t out_i = dst_i * slice_size + slice_i; - paddle::platform::CudaAtomicAdd(output + out_i, - *(params + in_i) / dst_count[src_i]); + paddle::platform::CudaAtomicAdd( + output + out_i, *(params + in_i) / static_cast(dst_count[src_i])); } } @@ -164,7 +164,8 @@ __global__ void 
ManipulateMinMaxGradCUDAKernel(const T* params, int64_t out_i = dst_i * slice_size + slice_i; paddle::platform::CudaAtomicAdd( output + out_i, - *(params + in_i) * (*(ptr_input + out_i) == *(ptr_output + in_i))); + *(params + in_i) * + static_cast(*(ptr_input + out_i) == *(ptr_output + in_i))); } } diff --git a/paddle/phi/kernels/gpu/graph_send_recv_grad_kernel.cu b/paddle/phi/kernels/gpu/graph_send_recv_grad_kernel.cu index e78fb7892ed7d..d058ee63c3d2f 100644 --- a/paddle/phi/kernels/gpu/graph_send_recv_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_send_recv_grad_kernel.cu @@ -31,7 +31,7 @@ void GraphSendRecvGradOpCUDAKernelLaunchHelper( const DenseTensor& x, const DenseTensor& src_index, const DenseTensor& dst_index, - const std::string& pool_type, + const std::string& reduce_op, DenseTensor* x_grad, const DenseTensor* dst_count = nullptr, const DenseTensor* out = nullptr) { @@ -73,16 +73,16 @@ void GraphSendRecvGradOpCUDAKernelLaunchHelper( int64_t grid_tmp = (n + block - 1) / block; int64_t grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx; int64_t input_size = src_dims[0]; - if (pool_type == "SUM") { + if (reduce_op == "SUM") { GraphSendRecvSumCUDAFunctor functor; GraphSendRecvCUDAKernel> <<>>( p_src, d_index, s_index, p_output, index_size, slice_size, functor); - } else if (pool_type == "MEAN") { + } else if (reduce_op == "MEAN") { const int32_t* s_count = dst_count->data(); ManipulateMeanGradCUDAKernel<<>>( p_src, d_index, s_index, p_output, index_size, slice_size, s_count); - } else if (pool_type == "MAX" || pool_type == "MIN") { + } else if (reduce_op == "MAX" || reduce_op == "MIN") { const T* ptr_input = x.data(); const T* ptr_output = out->data(); ManipulateMinMaxGradCUDAKernel @@ -105,7 +105,7 @@ void GraphSendRecvGradKernel(const Context& ctx, const paddle::optional& out, const paddle::optional& dst_count, const DenseTensor& out_grad, - const std::string& pool_type, + const std::string& reduce_op, DenseTensor* x_grad) { auto index_type = src_index.dtype(); if (index_type == phi::DataType::INT32) { @@ -115,7 +115,7 @@ void GraphSendRecvGradKernel(const Context& ctx, x, src_index, dst_index, - pool_type, + reduce_op, x_grad, dst_count.get_ptr(), out.get_ptr()); @@ -126,7 +126,7 @@ void GraphSendRecvGradKernel(const Context& ctx, x, src_index, dst_index, - pool_type, + reduce_op, x_grad, dst_count.get_ptr(), out.get_ptr()); @@ -142,4 +142,5 @@ PD_REGISTER_KERNEL(graph_send_recv_grad, float, double, int, - int64_t) {} + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu index 4dc2794d9c949..055d4888e3f56 100644 --- a/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu @@ -32,7 +32,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, const DenseTensor& x, const DenseTensor& src_index, const DenseTensor& dst_index, - const std::string& pool_type, + const std::string& reduce_op, int64_t out_size, DenseTensor* out, DenseTensor* dst_count = nullptr) { @@ -59,19 +59,19 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, ctx.template Alloc(out); T* p_output = out->data(); const size_t& memset_bytes = memset_size * sizeof(T); - if (pool_type == "SUM" || pool_type == "MEAN") { + if (reduce_op == "SUM" || reduce_op == "MEAN") { #ifdef PADDLE_WITH_HIP hipMemset(p_output, 0, memset_bytes); #else cudaMemset(p_output, 0, memset_bytes); #endif - } else if (pool_type == "MAX") { + } else if (reduce_op == 
"MAX") { thrust::device_ptr p_output_ptr(p_output); thrust::fill(thrust::device, p_output_ptr, p_output_ptr + memset_size, std::numeric_limits::lowest()); - } else if (pool_type == "MIN") { + } else if (reduce_op == "MIN") { thrust::device_ptr p_output_ptr(p_output); thrust::fill(thrust::device, p_output_ptr, @@ -99,12 +99,12 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, int64_t grid_tmp = (n + block - 1) / block; int64_t grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx; int64_t input_size = out_size <= 0 ? src_dims[0] : out_size; - if (pool_type == "SUM") { + if (reduce_op == "SUM") { GraphSendRecvSumCUDAFunctor functor; GraphSendRecvCUDAKernel> <<>>( p_src, s_index, d_index, p_output, index_size, slice_size, functor); - } else if (pool_type == "MAX") { + } else if (reduce_op == "MAX") { GraphSendRecvMaxCUDAFunctor functor; GraphSendRecvCUDAKernel> <<>>( @@ -115,7 +115,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, grid_max_tmp < max_grid_dimx ? grid_max_tmp : max_grid_dimx; InputResetMaxCUDAKernel<<>>( p_output, input_size, slice_size); - } else if (pool_type == "MIN") { + } else if (reduce_op == "MIN") { GraphSendRecvMinCUDAFunctor functor; GraphSendRecvCUDAKernel> <<>>( @@ -126,7 +126,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, grid_min_tmp < max_grid_dimx ? grid_min_tmp : max_grid_dimx; InputResetMinCUDAKernel<<>>( p_output, input_size, slice_size); - } else if (pool_type == "MEAN") { + } else if (reduce_op == "MEAN") { GraphSendRecvSumCUDAFunctor functor; GraphSendRecvCUDAKernel> <<>>( @@ -158,7 +158,7 @@ void GraphSendRecvKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& src_index, const DenseTensor& dst_index, - const std::string& pool_type, + const std::string& reduce_op, const IntArray& out_size, DenseTensor* out, DenseTensor* dst_count) { @@ -169,7 +169,7 @@ void GraphSendRecvKernel(const Context& ctx, x, src_index, dst_index, - pool_type, + reduce_op, out_size_data[0], out, dst_count); @@ -178,7 +178,7 @@ void GraphSendRecvKernel(const Context& ctx, x, src_index, dst_index, - pool_type, + reduce_op, out_size_data[0], out, dst_count); @@ -194,4 +194,5 @@ PD_REGISTER_KERNEL(graph_send_recv, float, double, int, - int64_t) {} + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h b/paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h new file mode 100644 index 0000000000000..49b48b5397538 --- /dev/null +++ b/paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h @@ -0,0 +1,400 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include +#include + +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/kernels/impl/graph_messaage_passing_impl.h" + +namespace phi { + +inline void CopyBCastOff(const BroadCastInfo& bcast_info, + thrust::device_vector& l_bcastoff, + thrust::device_vector& r_bcastoff) { + l_bcastoff.resize(bcast_info.out_len); + r_bcastoff.resize(bcast_info.out_len); +#ifdef PADDLE_WITH_HIP + hipMemcpy(thrust::raw_pointer_cast(l_bcastoff.data()), + bcast_info.l_offset.data(), + sizeof(int64_t) * bcast_info.out_len, + hipMemcpyHostToDevice); + hipMemcpy(thrust::raw_pointer_cast(r_bcastoff.data()), + bcast_info.r_offset.data(), + sizeof(int64_t) * bcast_info.out_len, + hipMemcpyHostToDevice); +#else + cudaMemcpy(thrust::raw_pointer_cast(l_bcastoff.data()), + bcast_info.l_offset.data(), + sizeof(int64_t) * bcast_info.out_len, + cudaMemcpyHostToDevice); + cudaMemcpy(thrust::raw_pointer_cast(r_bcastoff.data()), + bcast_info.r_offset.data(), + sizeof(int64_t) * bcast_info.out_len, + cudaMemcpyHostToDevice); +#endif +} + +inline int FindNumThreads(int dim, int max_num_threads) { + PADDLE_ENFORCE_GE(dim, + 0, + phi::errors::PreconditionNotMet( + "Required dim >= 0, but received dim = %d", dim)); + int res = max_num_threads; + if (dim == 0) res = 1; + while (res > dim) { + res = res >> 1; + } + res = res <= 32 ? 32 : res; + return res; +} + +template +struct GraphSendUERecvSumCUDAFunctor { + DEVICE inline void operator()(T* output, T val) { + paddle::platform::CudaAtomicAdd(output, val); + } +}; + +template +struct GraphSendUERecvMaxCUDAFunctor { + DEVICE inline void operator()(T* output, T val) { + paddle::platform::CudaAtomicMax(output, val); + } +}; + +template +struct GraphSendUERecvMinCUDAFunctor { + DEVICE inline void operator()(T* output, T val) { + paddle::platform::CudaAtomicMin(output, val); + } +}; + +template +__global__ void GraphSendUERecvCUDAKernel(const T* x_data, + const T* e_data, + const IndexT* src_indices, + const IndexT* dst_indices, + const int64_t* xbcast_off, + const int64_t* ebcast_off, + T* output, + int64_t index_size, + int64_t x_len, + int64_t e_len, + int64_t out_len, + bool use_bcast, + ComputeFunctor cfunctor, + ReduceFunctor rfunctor) { + IndexT ty = blockIdx.y * blockDim.y + threadIdx.y; + const IndexT stride_y = blockDim.y * gridDim.y; + + while (ty < index_size) { + IndexT src = src_indices[ty]; + IndexT dst = dst_indices[ty]; + int64_t tx = blockIdx.x * blockDim.x + threadIdx.x; + int64_t stride_x = blockDim.x * gridDim.x; + + const T* x_off = x_data + src * x_len; + const T* e_off = e_data + ty * e_len; + T* out_off = output + dst * out_len; + while (tx < out_len) { + int64_t x_add = use_bcast ? xbcast_off[tx] : tx; + int64_t e_add = use_bcast ? ebcast_off[tx] : tx; + T val = cfunctor(x_off[x_add], e_off[e_add]); + rfunctor(out_off + tx, val); + tx += stride_x; + } + ty += stride_y; + } +} + +// x_grad: for backward mean with mul. 
+template +__global__ void ManipulateMeanGradCUDAKernelForMulX(const T* out_grad_data, + const T* e_data, + const IndexT* src_indices, + const IndexT* dst_indices, + const int* dst_count, + const int64_t* l_bcastoff, + const int64_t* r_bcastoff, + T* x_grad, + int64_t index_size, + int64_t l_len, + int64_t r_len, + int64_t out_len, + bool use_bcast) { + IndexT ty = blockIdx.y * blockDim.y + threadIdx.y; + const IndexT stride_y = blockDim.y * gridDim.y; + + while (ty < index_size) { + IndexT src = src_indices[ty]; + IndexT dst = dst_indices[ty]; + int64_t tx = blockIdx.x * blockDim.x + threadIdx.x; + int64_t stride_x = blockDim.x * gridDim.x; + + const T* out_grad_off = out_grad_data + src * l_len; + const T* e_off = e_data + ty * r_len; + T* x_grad_off = x_grad + dst * out_len; + while (tx < out_len) { + int64_t o_add = use_bcast ? l_bcastoff[tx] : tx; + int64_t e_add = use_bcast ? r_bcastoff[tx] : tx; + T val = out_grad_off[o_add] * e_off[e_add]; + paddle::platform::CudaAtomicAdd(x_grad_off + tx, + val / static_cast(dst_count[src])); + tx += stride_x; + } + ty += stride_y; + } +} + +// e_grad: backward sum for add. +template +__global__ void ManipulateSumGradCUDAKernelForAddE(const T* out_grad_data, + const IndexT* dst_indices, + const int64_t* r_bcastoff, + T* e_grad, + int64_t index_size, + int64_t r_len, + int64_t out_len, + bool use_bcast) { + IndexT ty = blockIdx.y * blockDim.y + threadIdx.y; + const IndexT stride_y = blockDim.y * gridDim.y; + + while (ty < index_size) { + IndexT dst = dst_indices[ty]; + int64_t tx = blockIdx.x * blockDim.x + threadIdx.x; + int64_t stride_x = blockDim.x * gridDim.x; + + T* e_grad_off = e_grad + ty * r_len; + const T* out_grad_off = out_grad_data + dst * out_len; + while (tx < out_len) { + int64_t e_add = use_bcast ? r_bcastoff[tx] : tx; + paddle::platform::CudaAtomicAdd(e_grad_off + e_add, out_grad_off[tx]); + tx += stride_x; + } + ty += stride_y; + } +} + +// e_grad: backward sum for mul. +template +__global__ void ManipulateSumGradCUDAKernelForMulE(const T* x_data, + const T* out_grad_data, + const IndexT* src_indices, + const IndexT* dst_indices, + const int64_t* l_bcastoff, + const int64_t* r_bcastoff, + T* e_grad, + int64_t index_size, + int64_t l_len, + int64_t r_len, + int64_t out_len, + bool use_bcast) { + IndexT ty = blockIdx.y * blockDim.y + threadIdx.y; + const IndexT stride_y = blockDim.y * gridDim.y; + + while (ty < index_size) { + IndexT src = src_indices[ty]; + IndexT dst = dst_indices[ty]; + int64_t tx = blockIdx.x * blockDim.x + threadIdx.x; + int64_t stride_x = blockDim.x * gridDim.x; + + const T* x_off = x_data + src * l_len; + T* e_grad_off = e_grad + ty * r_len; + const T* out_grad_off = out_grad_data + dst * out_len; + while (tx < out_len) { + int64_t x_add = use_bcast ? l_bcastoff[tx] : tx; + int64_t e_add = use_bcast ? 
r_bcastoff[tx] : tx; + paddle::platform::CudaAtomicAdd(e_grad_off + e_add, + out_grad_off[tx] * x_off[x_add]); + tx += stride_x; + } + ty += stride_y; + } +} + +// e_grad: backward mean for add +template +__global__ void ManipulateMeanGradCUDAKernelForAddE(const T* out_grad_data, + const IndexT* dst_indices, + const int* dst_count, + const int64_t* r_bcastoff, + T* e_grad, + int64_t index_size, + int64_t r_len, + int64_t out_len, + bool use_bcast) { + IndexT ty = blockIdx.y * blockDim.y + threadIdx.y; + const IndexT stride_y = blockDim.y * gridDim.y; + + while (ty < index_size) { + IndexT dst = dst_indices[ty]; + int64_t tx = blockIdx.x * blockDim.x + threadIdx.x; + int64_t stride_x = blockDim.x * gridDim.x; + + T* e_grad_off = e_grad + ty * r_len; + const T* out_grad_off = out_grad_data + dst * out_len; + while (tx < out_len) { + int64_t e_add = use_bcast ? r_bcastoff[tx] : tx; + paddle::platform::CudaAtomicAdd( + e_grad_off + e_add, + out_grad_off[tx] / static_cast(dst_count[dst])); + tx += stride_x; + } + ty += stride_y; + } +} + +// e_grad: backward mean for mul. +template +__global__ void ManipulateMeanGradCUDAKernelForMulE(const T* x_data, + const T* out_grad_data, + const IndexT* src_indices, + const IndexT* dst_indices, + const int* dst_count, + const int64_t* l_bcastoff, + const int64_t* r_bcastoff, + T* e_grad, + int64_t index_size, + int64_t l_len, + int64_t r_len, + int64_t out_len, + bool use_bcast) { + IndexT ty = blockIdx.y * blockDim.y + threadIdx.y; + const IndexT stride_y = blockDim.y * gridDim.y; + + while (ty < index_size) { + IndexT src = src_indices[ty]; + IndexT dst = dst_indices[ty]; + int64_t tx = blockIdx.x * blockDim.x + threadIdx.x; + int64_t stride_x = blockDim.x * gridDim.x; + + const T* x_off = x_data + src * l_len; + T* e_grad_off = e_grad + ty * r_len; + const T* out_grad_off = out_grad_data + dst * out_len; + while (tx < out_len) { + int64_t x_add = use_bcast ? l_bcastoff[tx] : tx; + int64_t e_add = use_bcast ? r_bcastoff[tx] : tx; + paddle::platform::CudaAtomicAdd( + e_grad_off + e_add, + out_grad_off[tx] * x_off[x_add] / static_cast(dst_count[dst])); + tx += stride_x; + } + ty += stride_y; + } +} + +// x_grad, e_grad: backward min and max for add. +template +__global__ void ManipulateMinMaxGradCUDAKernelForAdd(const T* x_data, + const T* e_data, + const T* out, + const T* out_grad, + const IndexT* src_indices, + const IndexT* dst_indices, + const int64_t* xbcast_off, + const int64_t* ebcast_off, + T* x_grad, + T* e_grad, + int64_t index_size, + int64_t x_len, + int64_t e_len, + int64_t out_len, + bool use_bcast) { + IndexT ty = blockIdx.y * blockDim.y + threadIdx.y; + const IndexT stride_y = blockDim.y * gridDim.y; + + while (ty < index_size) { + IndexT src = src_indices[ty]; + IndexT dst = dst_indices[ty]; + int64_t tx = blockIdx.x * blockDim.x + threadIdx.x; + int64_t stride_x = blockDim.x * gridDim.x; + + const T* x_off = x_data + dst * x_len; + const T* e_off = e_data + ty * e_len; + const T* out_off = out + src * out_len; + const T* out_grad_off = out_grad + src * out_len; + T* x_grad_off = x_grad + dst * x_len; + T* e_grad_off = e_grad + ty * e_len; + while (tx < out_len) { + int64_t x_add = use_bcast ? xbcast_off[tx] : tx; + int64_t e_add = use_bcast ? 
ebcast_off[tx] : tx; + T val = x_off[x_add] + e_off[e_add]; + paddle::platform::CudaAtomicAdd( + x_grad_off + x_add, + out_grad_off[tx] * static_cast(val == out_off[tx])); + paddle::platform::CudaAtomicAdd( + e_grad_off + e_add, + out_grad_off[tx] * static_cast(val == out_off[tx])); + tx += stride_x; + } + ty += stride_y; + } +} + +// x_grad, e_grad: backward min and max for mul. +template +__global__ void ManipulateMinMaxGradCUDAKernelForMul(const T* x_data, + const T* e_data, + const T* out, + const T* out_grad, + const IndexT* src_indices, + const IndexT* dst_indices, + const int64_t* xbcast_off, + const int64_t* ebcast_off, + T* x_grad, + T* e_grad, + int64_t index_size, + int64_t x_len, + int64_t e_len, + int64_t out_len, + bool use_bcast) { + IndexT ty = blockIdx.y * blockDim.y + threadIdx.y; + const IndexT stride_y = blockDim.y * gridDim.y; + + while (ty < index_size) { + IndexT src = src_indices[ty]; + IndexT dst = dst_indices[ty]; + int64_t tx = blockIdx.x * blockDim.x + threadIdx.x; + int64_t stride_x = blockDim.x * gridDim.x; + + const T* x_off = x_data + dst * x_len; + const T* e_off = e_data + ty * e_len; + const T* out_off = out + src * out_len; + const T* out_grad_off = out_grad + src * out_len; + T* x_grad_off = x_grad + dst * x_len; + T* e_grad_off = e_grad + ty * e_len; + while (tx < out_len) { + int64_t x_add = use_bcast ? xbcast_off[tx] : tx; + int64_t e_add = use_bcast ? ebcast_off[tx] : tx; + T val = x_off[x_add] * e_off[e_add]; + paddle::platform::CudaAtomicAdd( + x_grad_off + x_add, + out_grad_off[tx] * static_cast(val == out_off[tx]) * e_off[e_add]); + paddle::platform::CudaAtomicAdd( + e_grad_off + e_add, + out_grad_off[tx] * static_cast(val == out_off[tx]) * x_off[x_add]); + tx += stride_x; + } + ty += stride_y; + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/graph_send_ue_recv_grad_kernel.cu b/paddle/phi/kernels/gpu/graph_send_ue_recv_grad_kernel.cu new file mode 100644 index 0000000000000..cb3d5591a7be6 --- /dev/null +++ b/paddle/phi/kernels/gpu/graph_send_ue_recv_grad_kernel.cu @@ -0,0 +1,613 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
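Editor's note: for the SUM reduction with the ADD message, the backward implemented in the new file below reduces to routing out_grad back along the edges (when broadcasting made the shapes differ, x_grad is additionally reduced over the broadcast axes via phi::Sum). A rough NumPy sketch that ignores broadcasting; the helper name is mine:

import numpy as np

def send_ue_recv_sum_add_grad_reference(out_grad, src_index, dst_index,
                                         x_shape, e_shape):
    # d(out[dst] += x[src] + e[i]) / dx[src] == 1 and / de[i] == 1, so the
    # gradient is out_grad scattered back through the edge list.
    x_grad = np.zeros(x_shape, dtype=out_grad.dtype)
    e_grad = np.zeros(e_shape, dtype=out_grad.dtype)
    for i, (src, dst) in enumerate(zip(src_index, dst_index)):
        x_grad[src] += out_grad[dst]
        e_grad[i] += out_grad[dst]
    return x_grad, e_grad
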
+ +#include "paddle/phi/kernels/graph_send_ue_recv_grad_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/elementwise_functor.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h" +#include "paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h" +#include "paddle/phi/kernels/impl/graph_messaage_passing_impl.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" + +namespace phi { + +template +void CalculateXEGradForMinMax(const Context& ctx, + const T* out_grad, + const T* x_data, + const T* e_data, + const phi::DDim& x_dims, + const phi::DDim& e_dims, + const IndexT* s_index, + const IndexT* d_index, + const std::string& message_op, + const std::string& reduce_op, + int64_t index_size, + T* x_grad, + T* e_grad, + const DenseTensor* out = nullptr) { + const T* out_data = out->data(); + const auto& bcast_info = phi::CalcBCastInfo(x_dims, e_dims); + thrust::device_vector l_bcastoff, r_bcastoff; + if (bcast_info.use_bcast) { + CopyBCastOff(bcast_info, l_bcastoff, r_bcastoff); + } + + int64_t out_len = bcast_info.out_len; + const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock()); + const int nty = ctx.GetMaxThreadsPerBlock() / ntx; + const int nbx = (out_len + ntx - 1) / ntx; + const int nby = (index_size + nty - 1) / nty; + const dim3 grid(nbx, nby); + const dim3 block(ntx, nty); + + if (message_op == "ADD") { + ManipulateMinMaxGradCUDAKernelForAdd + <<>>( + x_data, + e_data, + out_data, + out_grad, + d_index, + s_index, + thrust::raw_pointer_cast(l_bcastoff.data()), + thrust::raw_pointer_cast(r_bcastoff.data()), + x_grad, + e_grad, + index_size, + bcast_info.l_len, + bcast_info.r_len, + out_len, + bcast_info.use_bcast); + } else if (message_op == "MUL") { + ManipulateMinMaxGradCUDAKernelForMul + <<>>( + x_data, + e_data, + out_data, + out_grad, + d_index, + s_index, + thrust::raw_pointer_cast(l_bcastoff.data()), + thrust::raw_pointer_cast(r_bcastoff.data()), + x_grad, + e_grad, + index_size, + bcast_info.l_len, + bcast_info.r_len, + out_len, + bcast_info.use_bcast); + } +} + +template +void CalculateXGrad(const Context& ctx, + const T* out_grad, + const T* x_data, + const T* e_data, + const phi::DDim& out_grad_dims, + const phi::DDim& x_dims, + const phi::DDim& e_dims, + const IndexT* s_index, + const IndexT* d_index, + const std::string& message_op, + const std::string& reduce_op, + int64_t index_size, + int64_t slice_size, + T* x_grad, + const DenseTensor& out_grad_tensor, + const DenseTensor* dst_count = nullptr, + const DenseTensor* out = nullptr) { +#ifdef PADDLE_WITH_HIP + int block = 256; +#else + int block = 1024; +#endif + int64_t n = slice_size * index_size; + int max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0]; + int64_t grid_tmp = (n + block - 1) / block; + int64_t grid = grid_tmp < max_grid_dimx ? 
grid_tmp : max_grid_dimx; + std::vector reduce_idx; + bool reduce = ReduceGrad(out_grad_dims, x_dims, reduce_idx); + if (reduce_op == "SUM") { + if (message_op == "ADD") { + GraphSendRecvSumCUDAFunctor functor; + if (!reduce) { + GraphSendRecvCUDAKernel> + <<>>(out_grad, + d_index, + s_index, + x_grad, + index_size, + slice_size, + functor); + } else { + const auto& bcast_info = phi::CalcBCastInfo(out_grad_dims, e_dims); + DenseTensor x_grad_v2 = + phi::EmptyLike(ctx, out_grad_tensor); + phi::funcs::SetConstant()(ctx, &x_grad_v2, T(0)); + T* x_grad_v2_data = x_grad_v2.data(); + GraphSendRecvCUDAKernel> + <<>>(out_grad, + d_index, + s_index, + x_grad_v2_data, + index_size, + bcast_info.out_len, + functor); + // Run reduce_sum + DenseTensor x_grad_out = phi::Sum( + ctx, + x_grad_v2, + reduce_idx, + paddle::experimental::CppTypeToDataType::Type(), + true); +#ifdef PADDLE_WITH_HIP + hipMemcpy(x_grad, + x_grad_out.data(), + x_grad_out.numel() * sizeof(T), + hipMemcpyDeviceToDevice); +#else + cudaMemcpy(x_grad, + x_grad_out.data(), + x_grad_out.numel() * sizeof(T), + cudaMemcpyDeviceToDevice); +#endif + } + } else if (message_op == "MUL") { + const auto& bcast_info = phi::CalcBCastInfo(out_grad_dims, e_dims); + thrust::device_vector l_bcastoff, r_bcastoff; + if (bcast_info.use_bcast) { + CopyBCastOff(bcast_info, l_bcastoff, r_bcastoff); + } + int64_t out_len = bcast_info.out_len; + const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock()); + const int nty = ctx.GetMaxThreadsPerBlock() / ntx; + const int nbx = (out_len + ntx - 1) / ntx; + const int nby = (index_size + nty - 1) / nty; + const dim3 grid_(nbx, nby); + const dim3 block_(ntx, nty); + funcs::MultiplyFunctor mul_functor; + GraphSendUERecvSumCUDAFunctor sum_functor; + if (!reduce) { + GraphSendUERecvCUDAKernel, + funcs::MultiplyFunctor> + <<>>( + out_grad, + e_data, + d_index, + s_index, + thrust::raw_pointer_cast(l_bcastoff.data()), + thrust::raw_pointer_cast(r_bcastoff.data()), + x_grad, + index_size, + bcast_info.l_len, + bcast_info.r_len, + out_len, + bcast_info.use_bcast, + mul_functor, + sum_functor); + } else { + DenseTensor x_grad_v2 = + phi::EmptyLike(ctx, out_grad_tensor); + phi::funcs::SetConstant()(ctx, &x_grad_v2, T(0)); + T* x_grad_v2_data = x_grad_v2.data(); + GraphSendUERecvCUDAKernel, + funcs::MultiplyFunctor> + <<>>( + out_grad, + e_data, + d_index, + s_index, + thrust::raw_pointer_cast(l_bcastoff.data()), + thrust::raw_pointer_cast(r_bcastoff.data()), + x_grad_v2_data, + index_size, + bcast_info.l_len, + bcast_info.r_len, + out_len, + bcast_info.use_bcast, + mul_functor, + sum_functor); + DenseTensor x_grad_out = phi::Sum( + ctx, + x_grad_v2, + reduce_idx, + paddle::experimental::CppTypeToDataType::Type(), + true); +#ifdef PADDLE_WITH_HIP + hipMemcpy(x_grad, + x_grad_out.data(), + x_grad_out.numel() * sizeof(T), + hipMemcpyDeviceToDevice); +#else + cudaMemcpy(x_grad, + x_grad_out.data(), + x_grad_out.numel() * sizeof(T), + cudaMemcpyDeviceToDevice); +#endif + } + } + } else if (reduce_op == "MEAN") { + const int* s_count = dst_count->data(); + if (message_op == "ADD") { + if (!reduce) { + ManipulateMeanGradCUDAKernel + <<>>(out_grad, + d_index, + s_index, + x_grad, + index_size, + slice_size, + s_count); + } else { + const auto& bcast_info = phi::CalcBCastInfo(out_grad_dims, e_dims); + DenseTensor x_grad_v2 = + phi::EmptyLike(ctx, out_grad_tensor); + phi::funcs::SetConstant()(ctx, &x_grad_v2, T(0)); + T* x_grad_v2_data = x_grad_v2.data(); + ManipulateMeanGradCUDAKernel + <<>>(out_grad, + d_index, + s_index, 
+ x_grad_v2_data, + index_size, + bcast_info.out_len, + s_count); + // Run reduce_sum + DenseTensor x_grad_out = phi::Sum( + ctx, + x_grad_v2, + reduce_idx, + paddle::experimental::CppTypeToDataType::Type(), + true); +#ifdef PADDLE_WITH_HIP + hipMemcpy(x_grad, + x_grad_out.data(), + x_grad_out.numel() * sizeof(T), + hipMemcpyDeviceToDevice); +#else + cudaMemcpy(x_grad, + x_grad_out.data(), + x_grad_out.numel() * sizeof(T), + cudaMemcpyDeviceToDevice); +#endif + } + } else if (message_op == "MUL") { + const auto& bcast_info = phi::CalcBCastInfo(out_grad_dims, e_dims); + thrust::device_vector l_bcastoff, r_bcastoff; + if (bcast_info.use_bcast) { + CopyBCastOff(bcast_info, l_bcastoff, r_bcastoff); + } + int64_t out_len = bcast_info.out_len; + const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock()); + const int nty = ctx.GetMaxThreadsPerBlock() / ntx; + const int nbx = (out_len + ntx - 1) / ntx; + const int nby = (index_size + nty - 1) / nty; + const dim3 grid_(nbx, nby); + const dim3 block_(ntx, nty); + if (!reduce) { + ManipulateMeanGradCUDAKernelForMulX + <<>>( + out_grad, + e_data, + d_index, + s_index, + s_count, + thrust::raw_pointer_cast(l_bcastoff.data()), + thrust::raw_pointer_cast(r_bcastoff.data()), + x_grad, + index_size, + bcast_info.l_len, + bcast_info.r_len, + out_len, + bcast_info.use_bcast); + } else { + DenseTensor x_grad_v2 = + phi::EmptyLike(ctx, out_grad_tensor); + phi::funcs::SetConstant()(ctx, &x_grad_v2, T(0)); + T* x_grad_v2_data = x_grad_v2.data(); + ManipulateMeanGradCUDAKernelForMulX + <<>>( + out_grad, + e_data, + d_index, + s_index, + s_count, + thrust::raw_pointer_cast(l_bcastoff.data()), + thrust::raw_pointer_cast(r_bcastoff.data()), + x_grad_v2_data, + index_size, + bcast_info.l_len, + bcast_info.r_len, + out_len, + bcast_info.use_bcast); + // Run reduce_sum + DenseTensor x_grad_out = phi::Sum( + ctx, + x_grad_v2, + reduce_idx, + paddle::experimental::CppTypeToDataType::Type(), + true); + // TODO(daisiming): Whether use x_grad instead. 
+#ifdef PADDLE_WITH_HIP + hipMemcpy(x_grad, + x_grad_out.data(), + x_grad_out.numel() * sizeof(T), + hipMemcpyDeviceToDevice); +#else + cudaMemcpy(x_grad, + x_grad_out.data(), + x_grad_out.numel() * sizeof(T), + cudaMemcpyDeviceToDevice); +#endif + } + } + } +} + +template +void CalculateEGrad(const Context& ctx, + const T* out_grad, + const T* x_data, + const T* e_data, + const phi::DDim& x_dims, + const phi::DDim& e_dims, + const IndexT* s_index, + const IndexT* d_index, + const std::string& message_op, + const std::string& reduce_op, + int64_t index_size, + T* e_grad, + const DenseTensor* dst_count = nullptr) { + const auto& bcast_info = phi::CalcBCastInfo(x_dims, e_dims); + thrust::device_vector l_bcastoff, r_bcastoff; + if (bcast_info.use_bcast) { + CopyBCastOff(bcast_info, l_bcastoff, r_bcastoff); + } + int64_t out_len = bcast_info.out_len; + const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock()); + const int nty = ctx.GetMaxThreadsPerBlock() / ntx; + const int nbx = (out_len + ntx - 1) / ntx; + const int nby = (index_size + nty - 1) / nty; + const dim3 grid(nbx, nby); + const dim3 block(ntx, nty); + if (reduce_op == "SUM") { + if (message_op == "ADD") { + ManipulateSumGradCUDAKernelForAddE + <<>>( + out_grad, + d_index, + thrust::raw_pointer_cast(r_bcastoff.data()), + e_grad, + index_size, + bcast_info.r_len, + out_len, + bcast_info.use_bcast); + } else if (message_op == "MUL") { + ManipulateSumGradCUDAKernelForMulE + <<>>( + x_data, + out_grad, + s_index, + d_index, + thrust::raw_pointer_cast(l_bcastoff.data()), + thrust::raw_pointer_cast(r_bcastoff.data()), + e_grad, + index_size, + bcast_info.l_len, + bcast_info.r_len, + out_len, + bcast_info.use_bcast); + } + } else if (reduce_op == "MEAN") { + const int* s_count = dst_count->data(); + if (message_op == "ADD") { + ManipulateMeanGradCUDAKernelForAddE + <<>>( + out_grad, + d_index, + s_count, + thrust::raw_pointer_cast(r_bcastoff.data()), + e_grad, + index_size, + bcast_info.r_len, + out_len, + bcast_info.use_bcast); + } else if (message_op == "MUL") { + ManipulateMeanGradCUDAKernelForMulE + <<>>( + x_data, + out_grad, + s_index, + d_index, + s_count, + thrust::raw_pointer_cast(l_bcastoff.data()), + thrust::raw_pointer_cast(r_bcastoff.data()), + e_grad, + index_size, + bcast_info.l_len, + bcast_info.r_len, + out_len, + bcast_info.use_bcast); + } + } +} + +template +void GraphSendUERecvGradOpCUDAKernelLaunchHelper( + const Context& ctx, + const DenseTensor& out_grad, + const DenseTensor& x, + const DenseTensor& e, + const DenseTensor& src_index, + const DenseTensor& dst_index, + const std::string& message_op, + const std::string& reduce_op, + DenseTensor* x_grad, + DenseTensor* e_grad, + const DenseTensor* dst_count = nullptr, + const DenseTensor* out = nullptr) { + const int& index_size = dst_index.dims()[0]; + + ctx.template Alloc(x_grad); + T* x_grad_data = x_grad->data(); + ctx.template Alloc(e_grad); + T* e_grad_data = e_grad->data(); + const auto& x_dims = x.dims(); + const auto& e_dims = e.dims(); + int64_t memset_size_x = 1, memset_size_e = 1; + int64_t slice_size = 1; + for (int i = 0; i < x_dims.size(); i++) { + memset_size_x *= x_dims[i]; + if (i > 0) slice_size *= x_dims[i]; + } + for (int i = 0; i < e_dims.size(); i++) { + memset_size_e *= e_dims[i]; + } + const size_t& memset_bytes_x = memset_size_x * sizeof(T); + const size_t& memset_bytes_e = memset_size_e * sizeof(T); +#ifdef PADDLE_WITH_HIP + hipMemset(x_grad_data, 0, memset_bytes_x); + hipMemset(e_grad_data, 0, memset_bytes_e); +#else + 
cudaMemset(x_grad_data, 0, memset_bytes_x); + cudaMemset(e_grad_data, 0, memset_bytes_e); +#endif + + if (index_size == 0) return; + + const T* out_grad_data = out_grad.data(); + const T* x_data = x.data(); + const T* e_data = e.data(); + const IndexT* s_index = src_index.data(); + const IndexT* d_index = dst_index.data(); + + if (reduce_op == "SUM" || reduce_op == "MEAN") { + CalculateXGrad(ctx, + out_grad_data, + x_data, + e_data, + out_grad.dims(), + x_dims, + e_dims, + s_index, + d_index, + message_op, + reduce_op, + index_size, + slice_size, + x_grad_data, + out_grad, + dst_count, + out); + CalculateEGrad(ctx, + out_grad_data, + x_data, + e_data, + x_dims, + e_dims, + s_index, + d_index, + message_op, + reduce_op, + index_size, + e_grad_data, + dst_count); + } else if (reduce_op == "MIN" || reduce_op == "MAX") { + CalculateXEGradForMinMax(ctx, + out_grad_data, + x_data, + e_data, + x_dims, + e_dims, + s_index, + d_index, + message_op, + reduce_op, + index_size, + x_grad_data, + e_grad_data, + out); + } +} + +template +void GraphSendUERecvGradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& src_index, + const DenseTensor& dst_index, + const paddle::optional& out, + const paddle::optional& dst_count, + const DenseTensor& out_grad, + const std::string& message_op, + const std::string& reduce_op, + DenseTensor* x_grad, + DenseTensor* y_grad) { + auto index_type = src_index.dtype(); + if (index_type == phi::DataType::INT32) { + GraphSendUERecvGradOpCUDAKernelLaunchHelper( + ctx, + out_grad, + x, + y, + src_index, + dst_index, + message_op, + reduce_op, + x_grad, + y_grad, + dst_count.get_ptr(), + out.get_ptr()); + } else if (index_type == phi::DataType::INT64) { + GraphSendUERecvGradOpCUDAKernelLaunchHelper( + ctx, + out_grad, + x, + y, + src_index, + dst_index, + message_op, + reduce_op, + x_grad, + y_grad, + dst_count.get_ptr(), + out.get_ptr()); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(graph_send_ue_recv_grad, + GPU, + ALL_LAYOUT, + phi::GraphSendUERecvGradKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/graph_send_ue_recv_kernel.cu b/paddle/phi/kernels/gpu/graph_send_ue_recv_kernel.cu new file mode 100644 index 0000000000000..f339387f0bbfc --- /dev/null +++ b/paddle/phi/kernels/gpu/graph_send_ue_recv_kernel.cu @@ -0,0 +1,333 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
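Editor's note: the forward kernel added below is what paddle.geometric.send_ue_recv (see the Python wrapper in the new unit test) dispatches to on GPU. A small usage sketch with made-up values; it assumes a Paddle build that already contains this PR:

import paddle

x = paddle.to_tensor([[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]])
y = paddle.to_tensor([[1.], [1.], [2.], [2.]])       # per-edge weight, broadcast over features
src = paddle.to_tensor([0, 1, 2, 0], dtype="int64")
dst = paddle.to_tensor([1, 2, 1, 0], dtype="int64")

out = paddle.geometric.send_ue_recv(x, y, src, dst,
                                    message_op="mul", reduce_op="sum")
# out[1] == x[0] * 1 + x[2] * 2, out[2] == x[1] * 1, out[0] == x[0] * 2
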
+ +#include "paddle/phi/kernels/graph_send_ue_recv_kernel.h" +#include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h" +#include "paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h" +#include "paddle/phi/kernels/impl/graph_messaage_passing_impl.h" + +#include +#include +#include +#include + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/elementwise_functor.h" + +namespace phi { + +template +void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& ctx, + const DenseTensor& x, + const DenseTensor& e, + const DenseTensor& src_index, + const DenseTensor& dst_index, + const std::string& message_op, + const std::string& reduce_op, + int64_t out_size, + DenseTensor* out, + DenseTensor* dst_count = nullptr) { + const int& index_size = src_index.dims()[0]; + auto out_dims = out->dims(); + int64_t memset_size = 1; + std::vector dims_ = phi::vectorize(out_dims); + if (out_size <= 0) { + dims_[0] = x.dims()[0]; + } else { + dims_[0] = out_size; + } + out->Resize(phi::make_ddim(dims_)); + for (size_t i = 0; i < dims_.size(); i++) { + memset_size *= dims_[i]; + } + + ctx.template Alloc(out); + T* out_data = out->data(); + const size_t& memset_bytes = memset_size * sizeof(T); + if (reduce_op == "SUM" || reduce_op == "MEAN") { +#ifdef PADDLE_WITH_HIP + hipMemset(out_data, 0, memset_bytes); +#else + cudaMemset(out_data, 0, memset_bytes); +#endif + } else if (reduce_op == "MAX") { + thrust::device_ptr out_data_ptr(out_data); + thrust::fill(thrust::device, + out_data_ptr, + out_data_ptr + memset_size, + std::numeric_limits::lowest()); + + } else if (reduce_op == "MIN") { + thrust::device_ptr out_data_ptr(out_data); + thrust::fill(thrust::device, + out_data_ptr, + out_data_ptr + memset_size, + std::numeric_limits::max()); + } + + if (index_size == 0) return; + + const auto& bcast_info = phi::CalcBCastInfo(x.dims(), e.dims()); + const T* x_data = x.data(); + const T* e_data = e.data(); + const IndexT* s_index = src_index.data(); + const IndexT* d_index = dst_index.data(); + + thrust::device_vector x_bcastoff, e_bcastoff; + if (bcast_info.use_bcast) { + CopyBCastOff(bcast_info, x_bcastoff, e_bcastoff); + } + + int64_t out_len = bcast_info.out_len; + const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock()); + const int nty = ctx.GetMaxThreadsPerBlock() / ntx; + const int nbx = (out_len + ntx - 1) / ntx; + const int nby = (index_size + nty - 1) / nty; + const dim3 grid(nbx, nby); + const dim3 block(ntx, nty); + int64_t input_size = x.dims()[0]; +#ifdef PADDLE_WITH_HIP + int block_ = 256; +#else + int block_ = 1024; +#endif + if (reduce_op == "SUM" || reduce_op == "MEAN") { + GraphSendUERecvSumCUDAFunctor sum_functor; + if (message_op == "ADD") { + funcs::AddFunctor add_funtor; + GraphSendUERecvCUDAKernel, + funcs::AddFunctor> + <<>>( + x_data, + e_data, + s_index, + d_index, + thrust::raw_pointer_cast(x_bcastoff.data()), + thrust::raw_pointer_cast(e_bcastoff.data()), + out_data, + index_size, + bcast_info.l_len, + bcast_info.r_len, + out_len, + bcast_info.use_bcast, + add_funtor, + sum_functor); + } else if (message_op == "MUL") { + funcs::MultiplyFunctor mul_functor; + GraphSendUERecvCUDAKernel, + funcs::MultiplyFunctor> + <<>>( + x_data, + e_data, + s_index, + d_index, + thrust::raw_pointer_cast(x_bcastoff.data()), + thrust::raw_pointer_cast(e_bcastoff.data()), + out_data, + index_size, + bcast_info.l_len, + bcast_info.r_len, + out_len, + bcast_info.use_bcast, + 
mul_functor, + sum_functor); + } + if (reduce_op == "MEAN") { + input_size = out_size <= 0 ? x.dims()[0] : out_size; + dst_count->Resize({input_size}); + ctx.template Alloc(dst_count); + int* dst_count_data = dst_count->data(); +#ifdef PADDLE_WITH_HIP + hipMemset(dst_count_data, 0, input_size * sizeof(int)); +#else + cudaMemset(dst_count_data, 0, input_size * sizeof(int)); +#endif + int64_t grid_count = (index_size + block_ - 1) / block_; + ComputeCountCUDAKernel + <<>>( + dst_count_data, d_index, index_size); + + int64_t grid_mean = (input_size * out_len + block_ - 1) / block_; + int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0]; + int64_t grid_mean_ = + grid_mean < max_grid_dimx ? grid_mean : max_grid_dimx; + ManipulateMeanCUDAKernel<<>>( + out_data, dst_count_data, input_size, out_len); + } + } else if (reduce_op == "MAX") { + GraphSendUERecvMaxCUDAFunctor max_functor; + if (message_op == "ADD") { + funcs::AddFunctor add_funtor; + GraphSendUERecvCUDAKernel, + funcs::AddFunctor> + <<>>( + x_data, + e_data, + s_index, + d_index, + thrust::raw_pointer_cast(x_bcastoff.data()), + thrust::raw_pointer_cast(e_bcastoff.data()), + out_data, + index_size, + bcast_info.l_len, + bcast_info.r_len, + out_len, + bcast_info.use_bcast, + add_funtor, + max_functor); + } else if (message_op == "MUL") { + funcs::MultiplyFunctor mul_functor; + GraphSendUERecvCUDAKernel, + funcs::MultiplyFunctor> + <<>>( + x_data, + e_data, + s_index, + d_index, + thrust::raw_pointer_cast(x_bcastoff.data()), + thrust::raw_pointer_cast(e_bcastoff.data()), + out_data, + index_size, + bcast_info.l_len, + bcast_info.r_len, + out_len, + bcast_info.use_bcast, + mul_functor, + max_functor); + } + if (out_size > 0) { + input_size = out_size; + } + int64_t grid_max = (input_size * out_len + block_ - 1) / block_; + int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0]; + int64_t grid_max_ = grid_max < max_grid_dimx ? grid_max : max_grid_dimx; + InputResetMaxCUDAKernel + <<>>(out_data, input_size, out_len); + } else if (reduce_op == "MIN") { + GraphSendUERecvMinCUDAFunctor min_functor; + if (message_op == "ADD") { + funcs::AddFunctor add_funtor; + GraphSendUERecvCUDAKernel, + funcs::AddFunctor> + <<>>( + x_data, + e_data, + s_index, + d_index, + thrust::raw_pointer_cast(x_bcastoff.data()), + thrust::raw_pointer_cast(e_bcastoff.data()), + out_data, + index_size, + bcast_info.l_len, + bcast_info.r_len, + out_len, + bcast_info.use_bcast, + add_funtor, + min_functor); + } else if (message_op == "MUL") { + funcs::MultiplyFunctor mul_functor; + GraphSendUERecvCUDAKernel, + funcs::MultiplyFunctor> + <<>>( + x_data, + e_data, + s_index, + d_index, + thrust::raw_pointer_cast(x_bcastoff.data()), + thrust::raw_pointer_cast(e_bcastoff.data()), + out_data, + index_size, + bcast_info.l_len, + bcast_info.r_len, + out_len, + bcast_info.use_bcast, + mul_functor, + min_functor); + } + if (out_size > 0) { + input_size = out_size; + } + int64_t grid_min = (input_size * out_len + block_ - 1) / block_; + int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0]; + int64_t grid_min_ = grid_min < max_grid_dimx ? 
grid_min : max_grid_dimx; + InputResetMinCUDAKernel + <<>>(out_data, input_size, out_len); + } +} + +template +void GraphSendUERecvKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& src_index, + const DenseTensor& dst_index, + const std::string& message_op, + const std::string& reduce_op, + const IntArray& out_size, + DenseTensor* out, + DenseTensor* dst_count) { + auto index_type = src_index.dtype(); + auto& out_size_data = out_size.GetData(); + if (index_type == phi::DataType::INT32) { + GraphSendUERecvOpCUDAKernelLaunchHelper( + ctx, + x, + y, + src_index, + dst_index, + message_op, + reduce_op, + out_size_data[0], + out, + dst_count); + } else if (index_type == phi::DataType::INT64) { + GraphSendUERecvOpCUDAKernelLaunchHelper( + ctx, + x, + y, + src_index, + dst_index, + message_op, + reduce_op, + out_size_data[0], + out, + dst_count); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(graph_send_ue_recv, + GPU, + ALL_LAYOUT, + phi::GraphSendUERecvKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/graph_send_recv_grad_kernel.h b/paddle/phi/kernels/graph_send_recv_grad_kernel.h index 1379e0f542a72..1b618c6fede21 100644 --- a/paddle/phi/kernels/graph_send_recv_grad_kernel.h +++ b/paddle/phi/kernels/graph_send_recv_grad_kernel.h @@ -29,6 +29,6 @@ void GraphSendRecvGradKernel(const Context& ctx, const paddle::optional& out, const paddle::optional& dst_count, const DenseTensor& out_grad, - const std::string& pool_type, + const std::string& reduce_op, DenseTensor* x_grad); } // namespace phi diff --git a/paddle/phi/kernels/graph_send_recv_kernel.h b/paddle/phi/kernels/graph_send_recv_kernel.h index cd625c92b93ea..023e86064ff51 100644 --- a/paddle/phi/kernels/graph_send_recv_kernel.h +++ b/paddle/phi/kernels/graph_send_recv_kernel.h @@ -26,7 +26,7 @@ void GraphSendRecvKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& src_index, const DenseTensor& dst_index, - const std::string& pool_type, + const std::string& reduce_op, const IntArray& out_size, DenseTensor* out, DenseTensor* dst_count); diff --git a/paddle/phi/kernels/graph_send_ue_recv_grad_kernel.h b/paddle/phi/kernels/graph_send_ue_recv_grad_kernel.h new file mode 100644 index 0000000000000..74050d126259d --- /dev/null +++ b/paddle/phi/kernels/graph_send_ue_recv_grad_kernel.h @@ -0,0 +1,36 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
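Editor's note: the MIN/MAX backward declared through the header below only propagates gradient along edges whose message reproduces the pooled output (see ManipulateMinMaxGradCUDAKernelForAdd/Mul earlier in this diff). A non-broadcast NumPy sketch of the ADD case; the function name is mine and is illustrative only:

import numpy as np

def send_ue_recv_minmax_add_grad_reference(x, e, out, out_grad,
                                           src_index, dst_index):
    x_grad = np.zeros_like(x)
    e_grad = np.zeros_like(e)
    for i, (src, dst) in enumerate(zip(src_index, dst_index)):
        # Gradient flows only where x[src] + e[i] equals the pooled output.
        mask = (x[src] + e[i] == out[dst]).astype(x.dtype)
        x_grad[src] += out_grad[dst] * mask
        e_grad[i] += out_grad[dst] * mask
    return x_grad, e_grad
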
+ +#pragma once + +#include +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/utils/optional.h" + +namespace phi { + +template +void GraphSendUERecvGradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& src_index, + const DenseTensor& dst_index, + const paddle::optional& out, + const paddle::optional& dst_count, + const DenseTensor& out_grad, + const std::string& message_op, + const std::string& reduce_op, + DenseTensor* x_grad, + DenseTensor* y_grad); +} // namespace phi diff --git a/paddle/phi/kernels/graph_send_ue_recv_kernel.h b/paddle/phi/kernels/graph_send_ue_recv_kernel.h new file mode 100644 index 0000000000000..a308a78800f3a --- /dev/null +++ b/paddle/phi/kernels/graph_send_ue_recv_kernel.h @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void GraphSendUERecvKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& src_index, + const DenseTensor& dst_index, + const std::string& message_op, + const std::string& reduce_op, + const IntArray& out_size, + DenseTensor* out, + DenseTensor* dst_count); + +} // namespace phi diff --git a/paddle/phi/kernels/impl/graph_messaage_passing_impl.h b/paddle/phi/kernels/impl/graph_messaage_passing_impl.h new file mode 100644 index 0000000000000..dc1477e77227b --- /dev/null +++ b/paddle/phi/kernels/impl/graph_messaage_passing_impl.h @@ -0,0 +1,140 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright The DGL team. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/phi/kernels/funcs/common_shape.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +namespace phi { + +struct BroadCastInfo { + bool use_bcast; + // l_offset[i] indicates the start position of tensor lhs that required to + // compute the i-th element in output, so as r_offset[i]. 
+ std::vector l_offset, r_offset; + int64_t l_len, r_len, out_len, reduce_size; +}; + +inline bool UseBroadCast(const phi::DDim& l_dims, const phi::DDim& r_dims) { + if (l_dims.size() != r_dims.size()) { + return true; + } + for (int i = 1; i < l_dims.size(); i++) { + if (l_dims[i] != r_dims[i]) { + return true; + } + } + return false; +} + +inline BroadCastInfo CalcBCastInfo(const phi::DDim& l_dims, + const phi::DDim& r_dims) { + BroadCastInfo binfo; + binfo.use_bcast = UseBroadCast(l_dims, r_dims); + binfo.l_len = 1; + binfo.r_len = 1; + for (int i = 1; i < l_dims.size(); i++) { + binfo.l_len *= l_dims[i]; + } + for (int i = 1; i < r_dims.size(); i++) { + binfo.r_len *= r_dims[i]; + } + // TODO(daisiming): Whether to add dot. + binfo.reduce_size = 1; + if (binfo.use_bcast) { + const int max_dim = std::max(l_dims.size(), r_dims.size()) - 1; + int stride_l = 1, stride_r = 1; + binfo.l_offset.emplace_back(0); + binfo.r_offset.emplace_back(0); + int out_len = 1; + for (int i = 0; i < max_dim; i++) { + // Iterate the axis from back to front. + const int dl = + (l_dims.size() - 1 - i < 1) ? 1 : l_dims[l_dims.size() - 1 - i]; + const int dr = + (r_dims.size() - 1 - i < 1) ? 1 : r_dims[r_dims.size() - 1 - i]; + for (int j = 1; j < std::max(dl, dr); j++) { + for (int k = 0; k < out_len; k++) { + binfo.l_offset.emplace_back(binfo.l_offset[k] + + j * (j < dl) * stride_l); + binfo.r_offset.emplace_back(binfo.r_offset[k] + + j * (j < dr) * stride_r); + } + } + out_len *= std::max(dl, dr); + stride_l *= dl; + stride_r *= dr; + } + binfo.out_len = out_len; + } else { + binfo.out_len = binfo.l_len; + } + return binfo; +} + +inline std::vector InferBroadcastShape(const phi::DDim& x_dims, + const phi::DDim& e_dims, + const std::string& type = "x") { + auto x_dims1 = phi::vectorize(x_dims); + auto e_dims1 = phi::vectorize(e_dims); + std::vector x_dims2(x_dims1.begin() + 1, x_dims1.end()); + std::vector e_dims2(e_dims1.begin() + 1, e_dims1.end()); + int max_dim = std::max(x_dims2.size(), e_dims2.size()); + int axis = std::abs(static_cast(x_dims2.size() - e_dims2.size())); + std::vector x_dims_array(max_dim); + std::vector e_dims_array(max_dim); + std::vector out_dims_array(max_dim); + // Only need to broadcast dimensions other than the 0th dimension. + phi::funcs::GetBroadcastDimsArrays(phi::make_ddim(x_dims2), + phi::make_ddim(e_dims2), + x_dims_array.data(), + e_dims_array.data(), + out_dims_array.data(), + max_dim, + axis); + if (type == "x") { + out_dims_array.insert(out_dims_array.begin(), x_dims[0]); + } else { + out_dims_array.insert(out_dims_array.begin(), e_dims[0]); + } + return out_dims_array; +} + +inline bool ReduceGrad(const phi::DDim& out_grad_dims, + const phi::DDim& x_dims, + std::vector& axis) { + // We must ensure the ndim of out_grad and x are the same. + bool reduce = false; + for (int i = 1; i < out_grad_dims.size(); i++) { + if (out_grad_dims[i] != x_dims[i]) { + reduce = true; + break; + } + } + if (!reduce) return false; + + // Get reduce axis. 
+ for (int i = 1; i < out_grad_dims.size(); i++) { + if (out_grad_dims[i] - x_dims[i] != 0) { + axis.emplace_back(i); + } + } + return true; +} + +} // namespace phi diff --git a/paddle/phi/ops/compat/graph_send_recv_sig.cc b/paddle/phi/ops/compat/graph_send_recv_sig.cc index c8c15619d5d39..0ca1a3fae0230 100644 --- a/paddle/phi/ops/compat/graph_send_recv_sig.cc +++ b/paddle/phi/ops/compat/graph_send_recv_sig.cc @@ -21,12 +21,12 @@ KernelSignature GraphSendRecvOpArgumentMapping( if (ctx.HasInput("Out_size")) { return KernelSignature("graph_send_recv", {"X", "Src_index", "Dst_index"}, - {"pool_type", "Out_size"}, + {"reduce_op", "Out_size"}, {"Out", "Dst_count"}); } else { return KernelSignature("graph_send_recv", {"X", "Src_index", "Dst_index"}, - {"pool_type", "out_size"}, + {"reduce_op", "out_size"}, {"Out", "Dst_count"}); } } @@ -36,7 +36,7 @@ KernelSignature GraphSendRecvGradOpArgumentMapping( return KernelSignature( "graph_send_recv_grad", {"X", "Src_index", "Dst_index", "Out", "Dst_count", "Out@GRAD"}, - {"pool_type"}, + {"reduce_op"}, {"X@GRAD"}); } diff --git a/paddle/phi/ops/compat/graph_send_ue_recv_sig.cc b/paddle/phi/ops/compat/graph_send_ue_recv_sig.cc new file mode 100644 index 0000000000000..0b2ddcc07e1bb --- /dev/null +++ b/paddle/phi/ops/compat/graph_send_ue_recv_sig.cc @@ -0,0 +1,49 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature GraphSendUERecvOpArgumentMapping( + const ArgumentMappingContext& ctx) { + if (ctx.HasInput("Out_size")) { + return KernelSignature("graph_send_ue_recv", + {"X", "Y", "Src_index", "Dst_index"}, + {"message_op", "reduce_op", "Out_size"}, + {"Out", "Dst_count"}); + } else { + return KernelSignature("graph_send_ue_recv", + {"X", "Y", "Src_index", "Dst_index"}, + {"message_op", "reduce_op", "out_size"}, + {"Out", "Dst_count"}); + } +} + +KernelSignature GraphSendUERecvGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "graph_send_ue_recv_grad", + {"X", "Y", "Src_index", "Dst_index", "Out", "Dst_count", "Out@GRAD"}, + {"message_op", "reduce_op"}, + {"X@GRAD", "Y@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(graph_send_ue_recv, + phi::GraphSendUERecvOpArgumentMapping); + +PD_REGISTER_ARG_MAPPING_FN(graph_send_ue_recv_grad, + phi::GraphSendUERecvGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 55d265770a4dc..9f899552e6987 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -1562,6 +1562,7 @@ set_tests_properties(test_reader_reset PROPERTIES TIMEOUT 120) set_tests_properties(test_pool3d_api PROPERTIES TIMEOUT 120) set_tests_properties(test_cumprod_op PROPERTIES TIMEOUT 120) set_tests_properties(test_split_program PROPERTIES TIMEOUT 120) +set_tests_properties(test_graph_send_ue_recv_op PROPERTIES TIMEOUT 60) if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) diff --git a/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py b/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py index 73c1525519066..81fcf06167e13 100644 --- a/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py +++ b/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py @@ -25,11 +25,11 @@ def graph_send_recv_wrapper(x, src_index, dst_index, - pool_type="sum", + reduce_op="sum", out_size=None, name=None): return paddle.geometric.send_u_recv(x, src_index, dst_index, - pool_type.lower(), out_size, name) + reduce_op.lower(), out_size, name) class TestGraphSendRecvMaxOp(OpTest): @@ -46,7 +46,7 @@ def setUp(self): self.inputs = {'X': x, 'Src_index': src_index, 'Dst_index': dst_index} - self.attrs = {'pool_type': 'MAX'} + self.attrs = {'reduce_op': 'MAX'} out, self.gradient = compute_graph_send_recv_for_min_max( self.inputs, self.attrs) @@ -76,7 +76,7 @@ def setUp(self): self.inputs = {'X': x, 'Src_index': src_index, 'Dst_index': dst_index} - self.attrs = {'pool_type': 'MIN'} + self.attrs = {'reduce_op': 'MIN'} out, self.gradient = compute_graph_send_recv_for_min_max( self.inputs, self.attrs) @@ -107,7 +107,7 @@ def setUp(self): self.inputs = {'X': x, 'Src_index': src_index, 'Dst_index': dst_index} - self.attrs = {'pool_type': 'SUM'} + self.attrs = {'reduce_op': 'SUM'} out, _ = compute_graph_send_recv_for_sum_mean(self.inputs, self.attrs) @@ -134,7 +134,7 @@ def setUp(self): self.inputs = {'X': x, 'Src_index': src_index, 'Dst_index': dst_index} - self.attrs = {'pool_type': 'MEAN'} + self.attrs = {'reduce_op': 'MEAN'} out, dst_count = compute_graph_send_recv_for_sum_mean( self.inputs, self.attrs) @@ -153,15 +153,15 @@ def compute_graph_send_recv_for_sum_mean(inputs, attributes): src_index = inputs['Src_index'] dst_index = inputs['Dst_index'] - pool_type = attributes['pool_type'] + reduce_op = 
attributes['reduce_op'] gather_x = x[src_index] target_shape = list(x.shape) results = np.zeros(target_shape, dtype=x.dtype) - if pool_type == 'SUM': + if reduce_op == 'SUM': for index, s_id in enumerate(dst_index): results[s_id, :] += gather_x[index, :] - elif pool_type == 'MEAN': + elif reduce_op == 'MEAN': count = np.zeros(target_shape[0], dtype=np.int32) for index, s_id in enumerate(dst_index): results[s_id, :] += gather_x[index, :] @@ -169,7 +169,7 @@ def compute_graph_send_recv_for_sum_mean(inputs, attributes): results = results / count.reshape([-1, 1]) results[np.isnan(results)] = 0 else: - raise ValueError("Invalid pool_type, only SUM, MEAN supported!") + raise ValueError("Invalid reduce_op, only SUM, MEAN supported!") count = np.zeros(target_shape[0], dtype=np.int32) for index, s_id in enumerate(dst_index): @@ -183,7 +183,7 @@ def compute_graph_send_recv_for_min_max(inputs, attributes): src_index = inputs['Src_index'] dst_index = inputs['Dst_index'] - pool_type = attributes['pool_type'] + reduce_op = attributes['reduce_op'] gather_x = x[src_index] target_shape = list(x.shape) @@ -191,7 +191,7 @@ def compute_graph_send_recv_for_min_max(inputs, attributes): gradient = np.zeros_like(x) # Calculate forward output - if pool_type == "MAX": + if reduce_op == "MAX": first_set = set() for index, s_id in enumerate(dst_index): if s_id not in first_set: @@ -200,7 +200,7 @@ def compute_graph_send_recv_for_min_max(inputs, attributes): else: results[s_id, :] = np.maximum(results[s_id, :], gather_x[index, :]) - elif pool_type == "MIN": + elif reduce_op == "MIN": first_set = set() for index, s_id in enumerate(dst_index): if s_id not in first_set: @@ -210,7 +210,7 @@ def compute_graph_send_recv_for_min_max(inputs, attributes): results[s_id, :] = np.minimum(results[s_id, :], gather_x[index, :]) else: - raise ValueError("Invalid pool_type, only MAX, MIN supported!") + raise ValueError("Invalid reduce_op, only MAX, MIN supported!") # Calculate backward gradient index_size = len(src_index) diff --git a/python/paddle/fluid/tests/unittests/test_graph_send_ue_recv_op.py b/python/paddle/fluid/tests/unittests/test_graph_send_ue_recv_op.py new file mode 100644 index 0000000000000..e8b5bdc7bb8f8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_graph_send_ue_recv_op.py @@ -0,0 +1,981 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2022 The DGL team. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
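Editor's note: the helpers defined in the new test file below (get_broadcast_shape, BroadCastInfo) broadcast only over the trailing feature dimensions. A concrete example of the mapping they produce, assuming row shapes (2, 3) for x and (2, 1) for y:

import numpy as np

x_row = np.arange(6).reshape(2, 3)       # flattened length lhs_len == 6
y_row = np.array([[10.0], [20.0]])       # flattened length rhs_len == 2
out_row = (x_row + y_row).reshape(-1)    # out_len == 6

# BroadCastInfo([10, 2, 3], [20, 2, 1]) encodes the same mapping as tables:
#   lhs_offset == [0, 1, 2, 3, 4, 5]
#   rhs_offset == [0, 0, 0, 1, 1, 1]     # the size-1 axis of y is reused
# i.e. out_row[k] combines x_row.flat[lhs_offset[k]] with y_row.flat[rhs_offset[k]].
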
+ +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.framework import _test_eager_guard + +from op_test import OpTest + + +def get_broadcast_shape(shp1, shp2): + pad_shp1, pad_shp2 = shp1, shp2 + if len(shp1) > len(shp2): + pad_shp2 = [ + 1, + ] * (len(shp1) - len(shp2)) + shp2 + elif len(shp1) < len(shp2): + pad_shp1 = [ + 1, + ] * (len(shp2) - len(shp1)) + shp1 + for d1, d2 in zip(pad_shp1, pad_shp2): + if d1 != d2 and d1 != 1 and d2 != 1: + raise ValueError + rst = [max(d1, d2) for d1, d2 in zip(pad_shp1, pad_shp2)] + return rst + + +class BroadCastInfo(object): + + def __init__(self, x_shape, y_shape): + self.x_shape = x_shape + self.y_shape = y_shape + + self.calculate_bcastinfo() + + def use_bcast(self): + if len(self.x_shape) != len(self.y_shape): + return True + for i in range(1, len(self.x_shape)): + if self.x_shape[i] != self.y_shape[i]: + return True + return False + + def calculate_bcastinfo(self): + lhs_len = 1 + rhs_len = 1 + for i in range(1, len(self.x_shape)): + lhs_len *= self.x_shape[i] + for i in range(1, len(self.y_shape)): + rhs_len *= self.y_shape[i] + use_b = self.use_bcast() + + if use_b: + max_ndim = max(len(self.x_shape), len(self.y_shape)) - 1 + out_len = 1 + stride_l = stride_r = 1 + lhs_offset = [0] + rhs_offset = [0] + for j in range(0, max_ndim): + dl = 1 if (len(self.x_shape) - 1 - j) < 1 \ + else self.x_shape[len(self.x_shape) - 1 - j] + dr = 1 if (len(self.y_shape) - 1 - j) < 1 \ + else self.y_shape[len(self.y_shape) - 1 - j] + for i in range(1, max(dl, dr)): + for k in range(0, out_len): + lhs_offset.append(lhs_offset[k] + i * + (i < dl) * stride_l) + rhs_offset.append(rhs_offset[k] + i * + (i < dr) * stride_r) + + out_len *= max(dl, dr) + stride_l *= dl + stride_r *= dr + else: + out_len = rhs_len + + self.use_broadcast = use_b + self.out_len = out_len + self.lhs_len = lhs_len + self.rhs_len = rhs_len + if use_b: + self.lhs_offset = lhs_offset + self.rhs_offset = rhs_offset + + +def compute_graph_send_ue_recv_for_sum(inputs, attributes): + x = inputs['X'] + y = inputs['Y'] + src_index = inputs['Src_index'] + dst_index = inputs['Dst_index'] + message_op = attributes['message_op'] + + gather_x = x[src_index] + out_shp = [ + x.shape[0], + ] + get_broadcast_shape(x.shape[1:], y.shape[1:]) + results = np.zeros(out_shp, dtype=x.dtype) + + # Calculate forward output. + if message_op == 'ADD': + x_compute_y = gather_x + y + elif message_op == 'MUL': + x_compute_y = gather_x * y + for index, s_id in enumerate(dst_index): + results[s_id, :] += x_compute_y[index, :] + return results + + +def compute_graph_send_ue_recv_for_mean(inputs, attributes): + x = inputs['X'] + y = inputs['Y'] + src_index = inputs['Src_index'] + dst_index = inputs['Dst_index'] + message_op = attributes['message_op'] + + gather_x = x[src_index] + out_shp = [ + x.shape[0], + ] + get_broadcast_shape(x.shape[1:], y.shape[1:]) + results = np.zeros(out_shp, dtype=x.dtype) + + # Calculate forward output. 
+ if message_op == 'ADD': + x_compute_y = gather_x + y + elif message_op == 'MUL': + x_compute_y = gather_x * y + count = np.zeros(out_shp[0], dtype=np.int32) + for index, s_id in enumerate(dst_index): + results[s_id, :] += x_compute_y[index, :] + count[s_id] += 1 + count_shape = [out_shp[0]] + count_shape.extend([1] * len(out_shp[1:])) + results = results / count.reshape(count_shape) + results[np.isnan(results)] = 0 + return results, count + + +def compute_graph_send_ue_recv_for_max_min(inputs, attributes): + x = inputs['X'] + y = inputs['Y'] + src_index = inputs['Src_index'] + dst_index = inputs['Dst_index'] + message_op = attributes['message_op'] + reduce_op = attributes['reduce_op'] + + gather_x = x[src_index] + out_shp = [ + x.shape[0], + ] + get_broadcast_shape(x.shape[1:], y.shape[1:]) + results = np.zeros(out_shp, dtype=x.dtype) + + # Calculate forward output. + if message_op == 'ADD': + x_compute_y = gather_x + y + elif message_op == 'MUL': + x_compute_y = gather_x * y + + first_set = set() + if reduce_op == 'MAX': + for index, s_id in enumerate(dst_index): + if s_id not in first_set: + results[s_id, :] += x_compute_y[index, :] + first_set.add(s_id) + else: + results[s_id, :] = np.maximum(results[s_id, :], + x_compute_y[index, :]) + elif reduce_op == 'MIN': + for index, s_id in enumerate(dst_index): + if s_id not in first_set: + results[s_id, :] += x_compute_y[index, :] + first_set.add(s_id) + else: + results[s_id, :] = np.minimum(results[s_id, :], + x_compute_y[index, :]) + else: + raise ValueError("Invalid reduce_op, only MAX, MIN supported!") + + # Calculate backward gradient. + x_gradient = np.zeros_like(x) + y_gradient = np.zeros_like(y) + bcast_info = BroadCastInfo(x.shape, y.shape) + use_broadcast = bcast_info.use_broadcast + for i in range(len(src_index)): + forward_src_idx = src_index[i] + forward_dst_idx = dst_index[i] + x_off = x[forward_src_idx] + y_off = y[i] + out_off = results[forward_dst_idx] + x_grad_off = x_gradient[forward_src_idx] + y_grad_off = y_gradient[i] + for j in range(bcast_info.out_len): + x_add = bcast_info.lhs_offset[j] if use_broadcast else j + y_add = bcast_info.rhs_offset[j] if use_broadcast else j + if message_op == 'ADD': + if len(x_off.shape) == 1 and len(y_off.shape) == 1: + val = x_off[x_add] + y_off[y_add] + x_grad_off[x_add] += 1 * (val == out_off[j]) + y_grad_off[y_add] += 1 * (val == out_off[j]) + else: + # For simplicity, we only check the situation of x_off.shape=2 + x_add_0 = int(x_add / x_off.shape[1]) + x_add_1 = int(x_add % x_off.shape[1]) + y_add_0 = int(y_add / y_off.shape[1]) + y_add_1 = int(y_add % y_off.shape[1]) + out_add_0 = int(j / out_off.shape[1]) + out_add_1 = int(j % out_off.shape[1]) + val = x_off[x_add_0][x_add_1] + y_off[y_add_0][y_add_1] + x_grad_off[x_add_0][x_add_1] += 1 * ( + val == out_off[out_add_0][out_add_1]) + y_grad_off[y_add_0][y_add_1] += 1 * ( + val == out_off[out_add_0][out_add_1]) + elif message_op == 'MUL': + if len(x_off.shape) == 1 and len(y_off.shape) == 1: + val = x_off[x_add] * y_off[y_add] + x_grad_off[x_add] += 1 * (val == out_off[j]) * y_off[y_add] + y_grad_off[y_add] += 1 * (val == out_off[j]) * x_off[x_add] + else: + # For simplicity, we only check the situation of x_off.shape=2 + x_add_0 = int(x_add / x_off.shape[1]) + x_add_1 = int(x_add % x_off.shape[1]) + y_add_0 = int(y_add / y_off.shape[1]) + y_add_1 = int(y_add % y_off.shape[1]) + out_add_0 = int(j / out_off.shape[1]) + out_add_1 = int(j % out_off.shape[1]) + val = x_off[x_add_0][x_add_1] * y_off[y_add_0][y_add_1] + 
x_grad_off[x_add_0][x_add_1] += 1 * ( + val == out_off[out_add_0][out_add_1] + ) * y_off[y_add_0][y_add_1] + y_grad_off[y_add_0][y_add_1] += 1 * ( + val == out_off[out_add_0][out_add_1] + ) * x_off[x_add_0][x_add_1] + + gradients = [x_gradient / results.size, y_gradient / results.size] + + return results, gradients + + +def graph_send_ue_recv_wrapper(x, + y, + src_index, + dst_index, + message_op="add", + reduce_op="sum", + out_size=None, + name=None): + return paddle.geometric.send_ue_recv(x, y, src_index, dst_index, + message_op.lower(), reduce_op.lower(), + out_size, name) + + +class TestGraphSendUERecvSumOp(OpTest): + + def setUp(self): + paddle.enable_static() + self.python_api = graph_send_ue_recv_wrapper + self.python_out_sig = ["Out"] + self.op_type = "graph_send_ue_recv" + self.set_config() + self.inputs = { + 'X': self.x, + 'Y': self.y, + 'Src_index': self.src_index, + 'Dst_index': self.dst_index + } + self.attrs = {'message_op': self.message_op, 'reduce_op': 'SUM'} + + out = compute_graph_send_ue_recv_for_sum(self.inputs, self.attrs) + + self.outputs = {'Out': out} + + def set_config(self): + self.x = np.random.random((10, 20)).astype("float64") + self.y = np.random.random((15, 20)).astype("float64") + index = np.random.randint(0, 10, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'ADD' + + def test_check_output(self): + self.check_output(check_eager=True) + + def test_check_grad(self): + self.check_grad(['X', 'Y'], 'Out', check_eager=True) + + +class TestSumCase1(TestGraphSendUERecvSumOp): + + def set_config(self): + self.x = np.random.random((10, 20)).astype("float64") + self.y = np.random.random((15, 20)).astype("float64") + index = np.random.randint(0, 10, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'MUL' + + +class TestSumCase2(TestGraphSendUERecvSumOp): + + def set_config(self): + self.x = np.random.random((10, 20)).astype("float64") + self.y = np.random.random((150, 1)).astype("float64") + index = np.random.randint(0, 10, (150, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'ADD' + + +class TestSumCase3(TestGraphSendUERecvSumOp): + + def set_config(self): + self.x = np.random.random((10, 20)).astype("float64") + self.y = np.random.random((150, 1)).astype("float64") + index = np.random.randint(0, 10, (150, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'MUL' + + +class TestSumCase4(TestGraphSendUERecvSumOp): + + def set_config(self): + self.x = np.random.random((10, 8, 5)).astype("float64") + self.y = np.random.random((15, 8, 1)).astype("float64") + index = np.random.randint(0, 10, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'ADD' + + +class TestSumCase5(TestGraphSendUERecvSumOp): + + def set_config(self): + self.x = np.random.random((10, 8, 5)).astype("float64") + self.y = np.random.random((15, 8, 1)).astype("float64") + index = np.random.randint(0, 10, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'MUL' + + +class TestSumCase6(TestGraphSendUERecvSumOp): + + def set_config(self): + self.x = np.random.random((100, 1)).astype("float64") + self.y = np.random.random((15, 20)).astype("float64") + index = np.random.randint(0, 100, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + 
self.message_op = 'ADD' + + +class TestSumCase7(TestGraphSendUERecvSumOp): + + def set_config(self): + self.x = np.random.random((100, 1)).astype("float64") + self.y = np.random.random((15, 20)).astype("float64") + index = np.random.randint(0, 100, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'MUL' + + +class TestGraphSendUERecvMeanOp(OpTest): + + def setUp(self): + paddle.enable_static() + self.python_api = graph_send_ue_recv_wrapper + self.python_out_sig = ["Out"] + self.op_type = "graph_send_ue_recv" + self.set_config() + self.inputs = { + 'X': self.x, + 'Y': self.y, + 'Src_index': self.src_index, + 'Dst_index': self.dst_index + } + self.attrs = {'message_op': self.message_op, 'reduce_op': 'MEAN'} + + out, dst_count = compute_graph_send_ue_recv_for_mean( + self.inputs, self.attrs) + + self.outputs = {'Out': out, 'Dst_count': dst_count} + + def set_config(self): + self.x = np.random.random((10, 20)).astype("float64") + self.y = np.random.random((15, 20)).astype("float64") + index = np.random.randint(0, 10, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'ADD' + + def test_check_output(self): + self.check_output(check_eager=True) + + def test_check_grad(self): + self.check_grad(['X', 'Y'], 'Out', check_eager=True) + + +class TestMeanCase1(TestGraphSendUERecvMeanOp): + + def set_config(self): + self.x = np.random.random((10, 20)).astype("float64") + self.y = np.random.random((15, 20)).astype("float64") + index = np.random.randint(0, 10, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'MUL' + + +class TestMeanCase2(TestGraphSendUERecvMeanOp): + + def set_config(self): + self.x = np.random.random((10, 20)).astype("float64") + self.y = np.random.random((150, 1)).astype("float64") + index = np.random.randint(0, 10, (150, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'ADD' + + +class TestMeanCase3(TestGraphSendUERecvMeanOp): + + def set_config(self): + self.x = np.random.random((10, 20)).astype("float64") + self.y = np.random.random((150, 1)).astype("float64") + index = np.random.randint(0, 10, (150, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'MUL' + + +class TestMeanCase4(TestGraphSendUERecvMeanOp): + + def set_config(self): + self.x = np.random.random((10, 8, 5)).astype("float64") + self.y = np.random.random((15, 8, 1)).astype("float64") + index = np.random.randint(0, 10, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'ADD' + + +class TestMeanCase5(TestGraphSendUERecvMeanOp): + + def set_config(self): + self.x = np.random.random((10, 8, 5)).astype("float64") + self.y = np.random.random((15, 8, 1)).astype("float64") + index = np.random.randint(0, 10, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'MUL' + + +class TestMeanCase6(TestGraphSendUERecvMeanOp): + + def set_config(self): + self.x = np.random.random((100, 1)).astype("float64") + self.y = np.random.random((15, 20)).astype("float64") + index = np.random.randint(0, 100, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'ADD' + + +class TestMeanCase7(TestGraphSendUERecvMeanOp): + + def set_config(self): + self.x = np.random.random((100, 1)).astype("float64") + self.y = 
np.random.random((15, 20)).astype("float64") + index = np.random.randint(0, 100, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'MUL' + + +class TestGraphSendUERecvMaxOp(OpTest): + + def setUp(self): + paddle.enable_static() + self.python_api = graph_send_ue_recv_wrapper + self.python_out_sig = ["Out"] + self.op_type = "graph_send_ue_recv" + self.set_config() + self.inputs = { + 'X': self.x, + 'Y': self.y, + 'Src_index': self.src_index, + 'Dst_index': self.dst_index + } + self.attrs = {'message_op': self.message_op, 'reduce_op': 'MAX'} + + out, self.gradients = compute_graph_send_ue_recv_for_max_min( + self.inputs, self.attrs) + + self.outputs = {'Out': out} + + def set_config(self): + self.x = np.random.random((10, 20)).astype("float64") + self.y = np.random.random((15, 20)).astype("float64") + index = np.random.randint(0, 10, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'ADD' + + def test_check_output(self): + self.check_output(check_eager=True) + + def test_check_grad(self): + self.check_grad(['X', 'Y'], + 'Out', + user_defined_grads=self.gradients, + check_eager=True) + + +class TestMaxCase1(TestGraphSendUERecvMaxOp): + + def set_config(self): + self.x = np.random.random((10, 20)).astype("float64") + self.y = np.random.random((15, 20)).astype("float64") + index = np.random.randint(0, 10, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'MUL' + + +class TestMaxCase2(TestGraphSendUERecvMaxOp): + + def set_config(self): + self.x = np.random.random((10, 20)).astype("float64") + self.y = np.random.random((150, 1)).astype("float64") + index = np.random.randint(0, 10, (150, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'ADD' + + +class TestMaxCase3(TestGraphSendUERecvMaxOp): + + def set_config(self): + self.x = np.random.random((10, 20)).astype("float64") + self.y = np.random.random((150, 1)).astype("float64") + index = np.random.randint(0, 10, (150, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'MUL' + + +class TestMaxCase4(TestGraphSendUERecvMaxOp): + + def set_config(self): + self.x = np.random.random((10, 8, 5)).astype("float64") + self.y = np.random.random((15, 8, 1)).astype("float64") + index = np.random.randint(0, 10, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'ADD' + + +class TestMaxCase5(TestGraphSendUERecvMaxOp): + + def set_config(self): + self.x = np.random.random((10, 8, 5)).astype("float64") + self.y = np.random.random((15, 8, 1)).astype("float64") + index = np.random.randint(0, 10, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'MUL' + + +class TestMaxCase6(TestGraphSendUERecvMaxOp): + + def set_config(self): + self.x = np.random.random((100, 1)).astype("float64") + self.y = np.random.random((15, 20)).astype("float64") + index = np.random.randint(0, 100, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'ADD' + + +class TestMaxCase7(TestGraphSendUERecvMaxOp): + + def set_config(self): + self.x = np.random.random((100, 1)).astype("float64") + self.y = np.random.random((15, 20)).astype("float64") + index = np.random.randint(0, 100, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = 
index[:, 1] + self.message_op = 'MUL' + + +class TestGraphSendUERecvMinOp(OpTest): + + def setUp(self): + paddle.enable_static() + self.python_api = graph_send_ue_recv_wrapper + self.python_out_sig = ["Out"] + self.op_type = "graph_send_ue_recv" + self.set_config() + self.inputs = { + 'X': self.x, + 'Y': self.y, + 'Src_index': self.src_index, + 'Dst_index': self.dst_index + } + self.attrs = {'message_op': self.message_op, 'reduce_op': 'MIN'} + + out, self.gradients = compute_graph_send_ue_recv_for_max_min( + self.inputs, self.attrs) + + self.outputs = {'Out': out} + + def set_config(self): + self.x = np.random.random((10, 20)).astype("float64") + self.y = np.random.random((15, 20)).astype("float64") + index = np.random.randint(0, 10, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'ADD' + + def test_check_output(self): + self.check_output(check_eager=True) + + def test_check_grad(self): + self.check_grad(['X', 'Y'], + 'Out', + user_defined_grads=self.gradients, + check_eager=True) + + +class TestMinCase1(TestGraphSendUERecvMinOp): + + def set_config(self): + self.x = np.random.random((10, 20)).astype("float64") + self.y = np.random.random((15, 20)).astype("float64") + index = np.random.randint(0, 10, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'MUL' + + +class TestMinCase2(TestGraphSendUERecvMinOp): + + def set_config(self): + self.x = np.random.random((10, 20)).astype("float64") + self.y = np.random.random((150, 1)).astype("float64") + index = np.random.randint(0, 10, (150, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'ADD' + + +class TestMinCase3(TestGraphSendUERecvMinOp): + + def set_config(self): + self.x = np.random.random((10, 20)).astype("float64") + self.y = np.random.random((150, 1)).astype("float64") + index = np.random.randint(0, 10, (150, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'MUL' + + +class TestMinCase4(TestGraphSendUERecvMinOp): + + def set_config(self): + self.x = np.random.random((10, 8, 5)).astype("float64") + self.y = np.random.random((15, 8, 1)).astype("float64") + index = np.random.randint(0, 10, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'ADD' + + +class TestMinCase5(TestGraphSendUERecvMinOp): + + def set_config(self): + self.x = np.random.random((10, 8, 5)).astype("float64") + self.y = np.random.random((15, 8, 1)).astype("float64") + index = np.random.randint(0, 10, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'MUL' + + +class TestMinCase6(TestGraphSendUERecvMinOp): + + def set_config(self): + self.x = np.random.random((100, 1)).astype("float64") + self.y = np.random.random((15, 20)).astype("float64") + index = np.random.randint(0, 100, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'ADD' + + +class TestMinCase7(TestGraphSendUERecvMinOp): + + def set_config(self): + self.x = np.random.random((100, 1)).astype("float64") + self.y = np.random.random((15, 20)).astype("float64") + index = np.random.randint(0, 100, (15, 2)).astype(np.int64) + self.src_index = index[:, 0] + self.dst_index = index[:, 1] + self.message_op = 'MUL' + + +class API_GeometricSendUERecvTest(unittest.TestCase): + + def test_compute_all_with_sum(self): + 
paddle.disable_static() + x = paddle.to_tensor(np.array([[0, 2, 3], [1, 4, 5], [2, 6, 7]]), + dtype="float32") + y = paddle.ones(shape=[4, 1], dtype="float32") + src_index = paddle.to_tensor(np.array([0, 1, 2, 0]), dtype="int32") + dst_index = paddle.to_tensor(np.array([1, 2, 1, 0]), dtype="int32") + + res_add = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, + "add", "sum") + res_sub = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, + "sub", "sum") + res_mul = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, + "mul", "sum") + res_div = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, + "div", "sum") + res = [res_add, res_sub, res_mul, res_div] + + np_add = np.array([[1, 3, 4], [4, 10, 12], [2, 5, 6]], dtype="float32") + np_sub = np.array([[-1, 1, 2], [0, 6, 8], [0, 3, 4]], dtype="float32") + np_mul = np.array([[0, 2, 3], [2, 8, 10], [1, 4, 5]], dtype="float32") + np_div = np.array([[0, 2, 3], [2, 8, 10], [1, 4, 5]], dtype="float32") + + for np_res, paddle_res in zip([np_add, np_sub, np_mul, np_div], res): + self.assertTrue( + np.allclose(np_res, paddle_res, atol=1e-6), "two value is\ + {}\n{}, check diff!".format(np_res, paddle_res)) + + def test_compute_all_with_mean(self): + paddle.disable_static() + x = paddle.to_tensor(np.array([[0, 2, 3], [1, 4, 5], [2, 6, 7]]), + dtype="float32") + y = paddle.ones(shape=[4, 1], dtype="float32") + src_index = paddle.to_tensor(np.array([0, 1, 2, 0]), dtype="int32") + dst_index = paddle.to_tensor(np.array([1, 2, 1, 0]), dtype="int32") + + res_add = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, + "add", "mean") + res_sub = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, + "sub", "mean") + res_mul = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, + "mul", "mean") + res_div = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, + "div", "mean") + res = [res_add, res_sub, res_mul, res_div] + + np_add = np.array([[1, 3, 4], [2, 5, 6], [2, 5, 6]], dtype="float32") + np_sub = np.array([[-1, 1, 2], [0, 3, 4], [0, 3, 4]], dtype="float32") + np_mul = np.array([[0, 2, 3], [1, 4, 5], [1, 4, 5]], dtype="float32") + np_div = np.array([[0, 2, 3], [1, 4, 5], [1, 4, 5]], dtype="float32") + + for np_res, paddle_res in zip([np_add, np_sub, np_mul, np_div], res): + self.assertTrue( + np.allclose(np_res, paddle_res, atol=1e-6), "two value is\ + {}\n{}, check diff!".format(np_res, paddle_res)) + + def test_compute_all_with_max(self): + paddle.disable_static() + x = paddle.to_tensor(np.array([[0, 2, 3], [1, 4, 5], [2, 6, 7]]), + dtype="float32") + y = paddle.ones(shape=[4, 1], dtype="float32") + src_index = paddle.to_tensor(np.array([0, 1, 2, 0]), dtype="int32") + dst_index = paddle.to_tensor(np.array([1, 2, 1, 0]), dtype="int32") + + res_add = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, + "add", "max") + res_sub = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, + "sub", "max") + res_mul = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, + "mul", "max") + res_div = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, + "div", "max") + res = [res_add, res_sub, res_mul, res_div] + + np_add = np.array([[1, 3, 4], [3, 7, 8], [2, 5, 6]], dtype="float32") + np_sub = np.array([[-1, 1, 2], [1, 5, 6], [0, 3, 4]], dtype="float32") + np_mul = np.array([[0, 2, 3], [2, 6, 7], [1, 4, 5]], dtype="float32") + np_div = np.array([[0, 2, 3], [2, 6, 7], [1, 4, 5]], dtype="float32") + + self.assertTrue(np.allclose(np_sub, res_sub, atol=1e-6)) + for np_res, paddle_res in zip([np_add, 
np_sub, np_mul, np_div], res): + self.assertTrue( + np.allclose(np_res, paddle_res, atol=1e-6), "two value is\ + {}\n{}, check diff!".format(np_res, paddle_res)) + + def test_compute_all_with_max_fp16(self): + paddle.disable_static() + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + x = paddle.to_tensor(np.array([[0, 2, 3], [1, 4, 5], [2, 6, + 7]]), + dtype="float16") + y = paddle.ones(shape=[4, 1], dtype="float16") + src_index = paddle.to_tensor(np.array([0, 1, 2, 0]), + dtype="int32") + dst_index = paddle.to_tensor(np.array([1, 2, 1, 0]), + dtype="int32") + + res_add = paddle.geometric.send_ue_recv(x, y, src_index, + dst_index, "add", "max") + res_sub = paddle.geometric.send_ue_recv(x, y, src_index, + dst_index, "sub", "max") + res_mul = paddle.geometric.send_ue_recv(x, y, src_index, + dst_index, "mul", "max") + res_div = paddle.geometric.send_ue_recv(x, y, src_index, + dst_index, "div", "max") + res = [res_add, res_sub, res_mul, res_div] + + np_add = np.array([[1, 3, 4], [3, 7, 8], [2, 5, 6]], + dtype="float16") + np_sub = np.array([[-1, 1, 2], [1, 5, 6], [0, 3, 4]], + dtype="float16") + np_mul = np.array([[0, 2, 3], [2, 6, 7], [1, 4, 5]], + dtype="float16") + np_div = np.array([[0, 2, 3], [2, 6, 7], [1, 4, 5]], + dtype="float16") + + self.assertTrue(np.allclose(np_sub, res_sub, atol=1e-6)) + for np_res, paddle_res in zip([np_add, np_sub, np_mul, np_div], + res): + self.assertTrue( + np.allclose(np_res, paddle_res, atol=1e-6), + "two value is\ + {}\n{}, check diff!".format(np_res, paddle_res)) + + def test_compute_all_with_min(self): + paddle.disable_static() + x = paddle.to_tensor(np.array([[0, 2, 3], [1, 4, 5], [2, 6, 7]]), + dtype="float32") + y = paddle.ones(shape=[4, 1], dtype="float32") + src_index = paddle.to_tensor(np.array([0, 1, 2, 0]), dtype="int32") + dst_index = paddle.to_tensor(np.array([1, 2, 1, 0]), dtype="int32") + + res_add = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, + "add", "min") + res_sub = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, + "sub", "min") + res_mul = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, + "mul", "min") + res_div = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, + "div", "min") + res = [res_add, res_sub, res_mul, res_div] + + np_add = np.array([[1, 3, 4], [1, 3, 4], [2, 5, 6]], dtype="float32") + np_sub = np.array([[-1, 1, 2], [-1, 1, 2], [0, 3, 4]], dtype="float32") + np_mul = np.array([[0, 2, 3], [0, 2, 3], [1, 4, 5]], dtype="float32") + np_div = np.array([[0, 2, 3], [0, 2, 3], [1, 4, 5]], dtype="float32") + + for np_res, paddle_res in zip([np_add, np_sub, np_mul, np_div], res): + self.assertTrue( + np.allclose(np_res, paddle_res, atol=1e-6), "two value is\ + {}\n{}, check diff!".format(np_res, paddle_res)) + + def test_compute_all_with_min_fp16(self): + paddle.disable_static() + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + x = paddle.to_tensor(np.array([[0, 2, 3], [1, 4, 5], [2, 6, + 7]]), + dtype="float16") + y = paddle.ones(shape=[4, 1], dtype="float16") + src_index = paddle.to_tensor(np.array([0, 1, 2, 0]), + dtype="int32") + dst_index = paddle.to_tensor(np.array([1, 2, 1, 0]), + dtype="int32") + res_add = paddle.geometric.send_ue_recv(x, y, src_index, + dst_index, "add", "min") + res_sub = paddle.geometric.send_ue_recv(x, y, src_index, + dst_index, "sub", "min") + res_mul = paddle.geometric.send_ue_recv(x, y, src_index, + dst_index, "mul", "min") + res_div = 
paddle.geometric.send_ue_recv(x, y, src_index, + dst_index, "div", "min") + res = [res_add, res_sub, res_mul, res_div] + + np_add = np.array([[1, 3, 4], [1, 3, 4], [2, 5, 6]], + dtype="float16") + np_sub = np.array([[-1, 1, 2], [-1, 1, 2], [0, 3, 4]], + dtype="float16") + np_mul = np.array([[0, 2, 3], [0, 2, 3], [1, 4, 5]], + dtype="float16") + np_div = np.array([[0, 2, 3], [0, 2, 3], [1, 4, 5]], + dtype="float16") + + for np_res, paddle_res in zip([np_add, np_sub, np_mul, np_div], + res): + self.assertTrue( + np.allclose(np_res, paddle_res, atol=1e-6), + "two value is\ + {}\n{}, check diff!".format(np_res, paddle_res)) + + def test_reshape_lhs_rhs(self): + paddle.disable_static() + x = paddle.to_tensor(np.array([[0, 2, 3], [1, 4, 5], [2, 6, 7]]), + dtype="float32") + x = x.reshape(shape=[3, 3, 1]) + y = paddle.ones([4, 1], dtype="float32") + src_index = paddle.to_tensor(np.array([0, 1, 2, 0]), dtype="int32") + dst_index = paddle.to_tensor(np.array([1, 2, 1, 0]), dtype="int32") + res_add = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, + "add", "min") + np_add = np.array([[1, 3, 4], [1, 3, 4], [2, 5, 6]], + dtype="float16").reshape([3, 3, 1]) + self.assertTrue( + np.allclose(np_add, res_add, atol=1e-6), "two value is\ + {}\n{}, check diff!".format(np_add, res_add)) + + def test_out_size_tensor_static(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(name="x", shape=[3, 3], dtype="float32") + y = paddle.static.data(name="y", shape=[3], dtype="float32") + src_index = paddle.static.data(name="src", shape=[3], dtype="int32") + dst_index = paddle.static.data(name="dst", shape=[3], dtype="int32") + out_size = paddle.static.data(name="out_size", + shape=[1], + dtype="int32") + + res_sum = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, + "mul", "sum", out_size) + + exe = paddle.static.Executor(paddle.CPUPlace()) + data1 = np.array([[0, 2, 3], [1, 4, 5], [2, 6, 6]], dtype="float32") + data2 = np.array([1, 2, 3], dtype="float32") + data3 = np.array([0, 0, 1], dtype="int32") + data4 = np.array([0, 1, 1], dtype="int32") + data5 = np.array([2], dtype="int32") + + np_sum = np.array([[0, 2, 3], [3, 16, 21]], dtype="float32") + + ret = exe.run(feed={ + 'x': data1, + 'y': data2, + 'src': data3, + 'dst': data4, + 'out_size': data5, + }, + fetch_list=[res_sum]) + self.assertTrue( + np.allclose(np_sum, ret[0], atol=1e-6), "two value is\ + {}\n{}, check diff!".format(np_sum, ret[0])) + + def test_api_eager_dygraph(self): + with _test_eager_guard(): + self.test_compute_all_with_sum() + self.test_compute_all_with_mean() + self.test_compute_all_with_max() + self.test_compute_all_with_max_fp16() + self.test_compute_all_with_min() + self.test_compute_all_with_min_fp16() + self.test_reshape_lhs_rhs() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/geometric/__init__.py b/python/paddle/geometric/__init__.py index 9e59062a7cc6a..65f9335c287c9 100644 --- a/python/paddle/geometric/__init__.py +++ b/python/paddle/geometric/__init__.py @@ -13,7 +13,9 @@ # limitations under the License. 
from .message_passing import send_u_recv # noqa: F401 +from .message_passing import send_ue_recv # noqa: F401 __all__ = [ 'send_u_recv', + 'send_ue_recv', ] diff --git a/python/paddle/geometric/message_passing/__init__.py b/python/paddle/geometric/message_passing/__init__.py index d9580e658650a..ea6ad1c99a518 100644 --- a/python/paddle/geometric/message_passing/__init__.py +++ b/python/paddle/geometric/message_passing/__init__.py @@ -13,3 +13,4 @@ # limitations under the License. from .send_recv import send_u_recv # noqa: F401 +from .send_recv import send_ue_recv # noqa: F401 diff --git a/python/paddle/geometric/message_passing/send_recv.py b/python/paddle/geometric/message_passing/send_recv.py index 87379730a2a60..bfe63f1f04d73 100644 --- a/python/paddle/geometric/message_passing/send_recv.py +++ b/python/paddle/geometric/message_passing/send_recv.py @@ -19,13 +19,13 @@ from paddle.fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype from paddle import _C_ops -from .utils import convert_out_size_to_list, get_out_size_tensor_inputs +from .utils import convert_out_size_to_list, get_out_size_tensor_inputs, reshape_lhs_rhs def send_u_recv(x, src_index, dst_index, - pool_type="sum", + reduce_op="sum", out_size=None, name=None): """ @@ -35,13 +35,13 @@ def send_u_recv(x, This api is mainly used in Graph Learning domain, and the main purpose is to reduce intermediate memory consumption in the process of message passing. Take `x` as the input tensor, we first use `src_index` to gather the corresponding data, and then use `dst_index` to update the corresponding position of output tensor - in different pooling types, like sum, mean, max, or min. Besides, we can use `out_size` to set necessary output shape. + in different reduce ops, like sum, mean, max, or min. Besides, we can use `out_size` to set necessary output shape. .. code-block:: text Given: - X = [[0, 2, 3], + x = [[0, 2, 3], [1, 4, 5], [2, 6, 7]] @@ -49,22 +49,23 @@ def send_u_recv(x, dst_index = [1, 2, 1, 0] - pool_type = "sum" + reduce_op = "sum" out_size = None Then: - Out = [[0, 2, 3], + out = [[0, 2, 3], [2, 8, 10], [1, 4, 5]] Args: x (Tensor): The input tensor, and the available data type is float32, float64, int32, int64. + And we support float16 in gpu version. src_index (Tensor): An 1-D tensor, and the available data type is int32, int64. dst_index (Tensor): An 1-D tensor, and should have the same shape as `src_index`. The available data type is int32, int64. - pool_type (str): Different pooling types, including `sum`, `mean`, `max`, `min`. + reduce_op (str): Different reduce ops, including `sum`, `mean`, `max`, `min`. Default value is `sum`. out_size (int|Tensor|None): We can set `out_size` to get necessary output shape. If not set or out_size is smaller or equal to 0, then this input will not be used. 
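For context on the keyword rename in this hunk, a minimal comparison sketch (tensor values are illustrative, and this assumes the legacy op is exposed as `paddle.incubate.graph_send_recv`): the legacy API keeps its `pool_type` keyword while mapping it onto the renamed `reduce_op` attribute, so both calls below should return the same result.

    import paddle

    x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
    src_index = paddle.to_tensor([0, 1, 2, 0], dtype="int32")
    dst_index = paddle.to_tensor([1, 2, 1, 0], dtype="int32")

    # Legacy API: still accepts `pool_type`, now forwarded as the `reduce_op` attribute.
    legacy_out = paddle.incubate.graph_send_recv(x, src_index, dst_index, pool_type="sum")
    # New API: the keyword itself is renamed to `reduce_op`.
    new_out = paddle.geometric.send_u_recv(x, src_index, dst_index, reduce_op="sum")
    # Both: [[0., 2., 3.], [2., 8., 10.], [1., 4., 5.]]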
@@ -88,7 +89,7 @@ def send_u_recv(x, indexes = paddle.to_tensor([[0, 1], [1, 2], [2, 1], [0, 0]], dtype="int32") src_index = indexes[:, 0] dst_index = indexes[:, 1] - out = paddle.geometric.send_u_recv(x, src_index, dst_index, pool_type="sum") + out = paddle.geometric.send_u_recv(x, src_index, dst_index, reduce_op="sum") # Outputs: [[0., 2., 3.], [2., 8., 10.], [1., 4., 5.]] x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32") @@ -96,39 +97,40 @@ def send_u_recv(x, src_index = indexes[:, 0] dst_index = indexes[:, 1] out_size = paddle.max(dst_index) + 1 - out = paddle.geometric.send_u_recv(x, src_index, dst_index, pool_type="sum", out_size=out_size) + out = paddle.geometric.send_u_recv(x, src_index, dst_index, reduce_op="sum", out_size=out_size) # Outputs: [[0., 2., 3.], [[2., 8., 10.]]] x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32") indexes = paddle.to_tensor([[0, 1], [2, 1], [0, 0]], dtype="int32") src_index = indexes[:, 0] dst_index = indexes[:, 1] - out = paddle.geometric.send_u_recv(x, src_index, dst_index, pool_type="sum") + out = paddle.geometric.send_u_recv(x, src_index, dst_index, reduce_op="sum") # Outputs: [[0., 2., 3.], [2., 8., 10.], [0., 0., 0.]] """ - if pool_type not in ["sum", "mean", "max", "min"]: + if reduce_op not in ["sum", "mean", "max", "min"]: raise ValueError( - "pool_type should be `sum`, `mean`, `max` or `min`, but received %s" - % pool_type) + "reduce_op should be `sum`, `mean`, `max` or `min`, but received %s" + % reduce_op) # TODO(daisiming): Should we add judgement for out_size: max(dst_index) + 1. if _in_legacy_dygraph(): out_size = convert_out_size_to_list(out_size) out, tmp = _C_ops.graph_send_recv(x, src_index, - dst_index, None, 'pool_type', - pool_type.upper(), 'out_size', + dst_index, None, 'reduce_op', + reduce_op.upper(), 'out_size', out_size) return out if in_dygraph_mode(): out_size = convert_out_size_to_list(out_size) return _C_ops.final_state_graph_send_recv(x, src_index, dst_index, - pool_type.upper(), out_size) + reduce_op.upper(), out_size) - check_variable_and_dtype(x, "X", ("float32", "float64", "int32", "int64"), - "graph_send_recv") + check_variable_and_dtype( + x, "X", ("float32", "float64", "int32", "int64", "float16"), + "graph_send_recv") check_variable_and_dtype(src_index, "Src_index", ("int32", "int64"), "graph_send_recv") check_variable_and_dtype(dst_index, "Dst_index", ("int32", "int64"), @@ -146,7 +148,7 @@ def send_u_recv(x, stop_gradient=True) inputs = {"X": x, "Src_index": src_index, "Dst_index": dst_index} - attrs = {"pool_type": pool_type.upper()} + attrs = {"reduce_op": reduce_op.upper()} get_out_size_tensor_inputs(inputs=inputs, attrs=attrs, out_size=out_size, @@ -160,3 +162,177 @@ def send_u_recv(x, }, attrs=attrs) return out + + +def send_ue_recv(x, + y, + src_index, + dst_index, + message_op="add", + reduce_op="sum", + out_size=None, + name=None): + """ + + Graph Learning message passing api. + + This api is mainly used in Graph Learning domain, and the main purpose is to reduce intermediate memory + consumption in the process of message passing. Take `x` as the input tensor, we first use `src_index` + to gather the corresponding data, after computing with `y` in different message ops like add/sub/mul/div, then use `dst_index` to + update the corresponding position of output tensor in different reduce ops, like sum, mean, max, or min. + Besides, we can use `out_size` to set necessary output shape. + + .. 
code-block:: text
+
+        Given:
+
+            x = [[0, 2, 3],
+                 [1, 4, 5],
+                 [2, 6, 7]]
+
+            y = [1, 1, 1, 1]
+
+            src_index = [0, 1, 2, 0]
+
+            dst_index = [1, 2, 1, 0]
+
+            message_op = "add"
+
+            reduce_op = "sum"
+
+            out_size = None
+
+        Then:
+
+            out = [[1, 3, 4],
+                   [4, 10, 12],
+                   [2, 5, 6]]
+
+    Args:
+        x (Tensor): The input node feature tensor, and the available data type is float32, float64, int32, int64.
+                    And we support float16 in gpu version.
+        y (Tensor): The input edge feature tensor, and the available data type is float32, float64, int32, int64.
+                    And we support float16 in gpu version.
+        src_index (Tensor): An 1-D tensor, and the available data type is int32, int64.
+        dst_index (Tensor): An 1-D tensor, and should have the same shape as `src_index`.
+                            The available data type is int32, int64.
+        message_op (str): Different message ops for x and y, including `add`, `sub`, `mul`, `div`.
+                          Default value is `add`.
+        reduce_op (str): Different reduce ops, including `sum`, `mean`, `max`, `min`.
+                         Default value is `sum`.
+        out_size (int|Tensor|None): We can set `out_size` to get necessary output shape. If not set or
+                                    out_size is smaller or equal to 0, then this input will not be used.
+                                    Otherwise, `out_size` should be equal to or larger than
+                                    max(dst_index) + 1.
+        name (str, optional): Name for the operation (optional, default is None).
+                              For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        out (Tensor): The output tensor, should have the same shape and same dtype as input tensor `x`.
+                      If `out_size` is set correctly, then it should have the same shape as `x` except
+                      the 0th dimension.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
+            y = paddle.to_tensor([1, 1, 1, 1], dtype="float32")
+            indexes = paddle.to_tensor([[0, 1], [1, 2], [2, 1], [0, 0]], dtype="int32")
+            src_index = indexes[:, 0]
+            dst_index = indexes[:, 1]
+            out = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, message_op="add", reduce_op="sum")
+            # Outputs: [[1., 3., 4.], [4., 10., 12.], [2., 5., 6.]]
+
+            x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
+            y = paddle.to_tensor([1, 1, 1], dtype="float32")
+            indexes = paddle.to_tensor([[0, 1], [2, 1], [0, 0]], dtype="int32")
+            src_index = indexes[:, 0]
+            dst_index = indexes[:, 1]
+            out_size = paddle.max(dst_index) + 1
+            out = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, message_op="add", reduce_op="sum", out_size=out_size)
+            # Outputs: [[1., 3., 4.], [4., 10., 12.]]
+
+            x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
+            y = paddle.to_tensor([1, 1, 1], dtype="float32")
+            indexes = paddle.to_tensor([[0, 1], [2, 1], [0, 0]], dtype="int32")
+            src_index = indexes[:, 0]
+            dst_index = indexes[:, 1]
+            out = paddle.geometric.send_ue_recv(x, y, src_index, dst_index, message_op="add", reduce_op="sum")
+            # Outputs: [[1., 3., 4.], [4., 10., 12.], [0., 0., 0.]]
+
+    """
+
+    if message_op not in ["add", "sub", "mul", "div"]:
+        raise ValueError(
+            "message_op should be `add`, `sub`, `mul`, `div`, but received %s" %
+            message_op)
+
+    if reduce_op not in ["sum", "mean", "max", "min"]:
+        raise ValueError(
+            "reduce_op should be `sum`, `mean`, `max` or `min`, but received %s"
+            % reduce_op)
+
+    x, y = reshape_lhs_rhs(x, y)
+
+    if message_op == 'sub':
+        message_op = 'add'
+        y = -y
+    if message_op == "div":
+        message_op = 'mul'
+        y = 1. / y
+
+    # TODO(daisiming): Should we add judgement for out_size: max(dst_index) + 1.
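+    # Note: 'sub' and 'div' have already been rewritten above into 'add' and 'mul'
+    # with a transformed y, so the branches below only ever dispatch the ADD or
+    # MUL message_op to the graph_send_ue_recv kernel.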
+ + if _in_legacy_dygraph(): + out_size = convert_out_size_to_list(out_size) + out, tmp = _C_ops.graph_send_ue_recv(x, y, src_index, dst_index, + None, 'message_op', + message_op.upper(), 'reduce_op', + reduce_op.upper(), 'out_size', + out_size) + return out + if in_dygraph_mode(): + out_size = convert_out_size_to_list(out_size) + return _C_ops.final_state_graph_send_ue_recv(x, y, src_index, dst_index, + message_op.upper(), + reduce_op.upper(), + out_size) + + check_variable_and_dtype( + x, "X", ("float32", "float64", "int32", "int64", "float16"), + "graph_send_ue_recv") + check_variable_and_dtype( + y, "Y", ("float32", "float64", "int32", "int64", "float16"), + "graph_send_ue_recv") + check_variable_and_dtype(src_index, "Src_index", ("int32", "int64"), + "graph_send_ue_recv") + check_variable_and_dtype(dst_index, "Dst_index", ("int32", "int64"), + "graph_send_ue_recv") + if out_size: + check_type(out_size, 'out_size', (int, np.int32, np.int64, Variable), + 'graph_send_ue_recv') + if isinstance(out_size, Variable): + check_dtype(out_size.dtype, 'out_size', ['int32', 'int64'], + 'graph_send_ue_recv') + + helper = LayerHelper("send_ue_recv", **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + dst_count = helper.create_variable_for_type_inference(dtype="int32", + stop_gradient=True) + + inputs = {"X": x, "Y": y, "Src_index": src_index, "Dst_index": dst_index} + attrs = {"message_op": message_op.upper(), "reduce_op": reduce_op.upper()} + get_out_size_tensor_inputs(inputs=inputs, + attrs=attrs, + out_size=out_size, + op_type='graph_send_ue_recv') + + helper.append_op(type="graph_send_ue_recv", + inputs=inputs, + outputs={ + "Out": out, + "Dst_count": dst_count + }, + attrs=attrs) + return out diff --git a/python/paddle/geometric/message_passing/utils.py b/python/paddle/geometric/message_passing/utils.py index 3614f829daf52..51c088522983b 100644 --- a/python/paddle/geometric/message_passing/utils.py +++ b/python/paddle/geometric/message_passing/utils.py @@ -13,6 +13,7 @@ # limitations under the License. import numpy as np +import paddle from paddle.fluid.framework import Variable from paddle.fluid.data_feeder import check_dtype, convert_dtype from paddle.fluid.layers.tensor import cast @@ -50,3 +51,35 @@ def get_out_size_tensor_inputs(inputs, attrs, out_size, op_type): inputs["Out_size"] = out_size else: raise TypeError("Out_size only supports Variable or int.") + + +def reshape_lhs_rhs(x, y): + """ + Expand dims to ensure there will be no broadcasting issues with different + number of dimensions. 
+ """ + if len(x.shape) == 1: + x = paddle.reshape(x, [-1, 1]) + if len(y.shape) == 1: + y = paddle.reshape(y, [-1, 1]) + + x_shape = paddle.shape(x) + y_shape = paddle.shape(y) + if len(x.shape) != len(y.shape): + max_ndims = max(len(x.shape), len(y.shape)) + x_pad_ndims = max_ndims - len(x.shape) + y_pad_ndims = max_ndims - len(y.shape) + new_x_shape = [ + x_shape[0], + ] + [ + 1, + ] * x_pad_ndims + list(x_shape[1:]) + new_y_shape = [ + y_shape[0], + ] + [ + 1, + ] * y_pad_ndims + list(y_shape[1:]) + x = paddle.reshape(x, new_x_shape) + y = paddle.reshape(y, new_y_shape) + + return x, y diff --git a/python/paddle/incubate/operators/graph_send_recv.py b/python/paddle/incubate/operators/graph_send_recv.py index 132a6d4657ca1..4181885d419af 100644 --- a/python/paddle/incubate/operators/graph_send_recv.py +++ b/python/paddle/incubate/operators/graph_send_recv.py @@ -69,7 +69,7 @@ def graph_send_recv(x, src_index (Tensor): An 1-D tensor, and the available data type is int32, int64. dst_index (Tensor): An 1-D tensor, and should have the same shape as `src_index`. The available data type is int32, int64. - pool_type (str): The pooling type of graph_send_recv, including `sum`, `mean`, `max`, `min`. + pool_type (str): The pooling types of graph_send_recv, including `sum`, `mean`, `max`, `min`. Default value is `sum`. out_size (int|Tensor|None): We can set `out_size` to get necessary output shape. If not set or out_size is smaller or equal to 0, then this input will not be used. @@ -123,7 +123,7 @@ def graph_send_recv(x, if _in_legacy_dygraph(): out_size = convert_out_size_to_list(out_size) out, tmp = _C_ops.graph_send_recv(x, src_index, - dst_index, None, 'pool_type', + dst_index, None, 'reduce_op', pool_type.upper(), 'out_size', out_size) return out @@ -151,7 +151,7 @@ def graph_send_recv(x, stop_gradient=True) inputs = {"X": x, "Src_index": src_index, "Dst_index": dst_index} - attrs = {"pool_type": pool_type.upper()} + attrs = {"reduce_op": pool_type.upper()} get_out_size_tensor_inputs(inputs=inputs, attrs=attrs, out_size=out_size,
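As a closing illustration of the broadcasting path added in this PR, a minimal usage sketch (shapes and random values are chosen only for demonstration): when `x` and `y` differ in rank, `reshape_lhs_rhs` inserts a size-1 dimension after dimension 0 of the lower-rank input so that the gathered `x` and `y` broadcast elementwise.

    import numpy as np
    import paddle

    # Node features: [num_nodes=10, 8, 5]; edge features: [num_edges=15, 5].
    x = paddle.to_tensor(np.random.random((10, 8, 5)), dtype="float32")
    y = paddle.to_tensor(np.random.random((15, 5)), dtype="float32")
    index = np.random.randint(0, 10, (15, 2)).astype("int64")
    src_index = paddle.to_tensor(index[:, 0])
    dst_index = paddle.to_tensor(index[:, 1])

    # y is reshaped internally to [15, 1, 5]; gathering x by src_index gives
    # [15, 8, 5]; the two are multiplied with broadcasting and then reduced by
    # dst_index with max.
    out = paddle.geometric.send_ue_recv(x, y, src_index, dst_index,
                                        message_op="mul", reduce_op="max")
    print(out.shape)  # [10, 8, 5]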