Skip to content

Commit

Permalink
[geometric] Add paddle.geometric.send_uv API (PaddlePaddle#44848)
Browse files Browse the repository at this point in the history
* initial commit

* fix op maker bug

* fix mul grad bug

* add unittest

* fix add grad bug, add cpu kernel

* add paddle.geometric.message_passing

* add paddle.geometric.send_uv api, add unittest

* add fp16 judgement

* fix file typo, move compute_type to message_op

* add impl file

* fix unittest timeout time

* add review revise
  • Loading branch information
DesmonDay committed Oct 12, 2022
1 parent 6e1e27f commit b941c68
Show file tree
Hide file tree
Showing 21 changed files with 1,437 additions and 5 deletions.
10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/api.yaml
Expand Up @@ -98,6 +98,16 @@
func : erf
backward : erf_grad

# Forward definition of paddle.geometric.send_uv: combines source-node
# features (x, gathered via src_index) with destination-node features
# (y, gathered via dst_index) using the elementwise message_op
# (defaults to "ADD"), producing one message row per edge.
- api : graph_send_uv
args : (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op = "ADD")
output : Tensor(out)
infer_meta :
func : GraphSendUVInferMeta
kernel :
func : graph_send_uv
# Output dtype follows x; fp16 support is gated in the kernel.
data_type : x
backward : graph_send_uv_grad

- api : lgamma
args : (Tensor x)
output : Tensor(out)
Expand Down
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
Expand Up @@ -105,6 +105,17 @@
func : erf_grad
data_type : out_grad

# Backward definition for graph_send_uv: given out_grad (one row per
# edge), produces gradients for both node-feature inputs x and y.
# Grad shapes mirror x and y, hence GeneralBinaryGradInferMeta on [x, y].
- backward_api : graph_send_uv_grad
forward : graph_send_uv (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op = "ADD") -> Tensor(out)
args: (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, Tensor out_grad, str message_op = "ADD")
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : graph_send_uv_grad
data_type : x

- backward_api : lgamma_grad
forward : lgamma(Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
Expand Down
70 changes: 70 additions & 0 deletions paddle/phi/infermeta/multiary.cc
Expand Up @@ -2499,6 +2499,76 @@ void GraphSendUERecvInferMeta(const MetaTensor& x,
out->set_dims(phi::make_ddim(out_dims_array));
}

// Infers output meta (dtype and dims) for the graph_send_uv op.
//
// Validates that src_index/dst_index are 1-D (or 2-D with trailing dim 1)
// and have the same number of entries (one per edge). The output dtype
// follows x. The output shape is [num_edges, broadcast(x[1:], y[1:])]:
// dim 0 (the node count) of x and y is replaced by the edge count, and
// the remaining feature dims of x and y are broadcast together.
void GraphSendUVInferMeta(const MetaTensor& x,
                          const MetaTensor& y,
                          const MetaTensor& src_index,
                          const MetaTensor& dst_index,
                          const std::string& message_op,
                          MetaTensor* out) {
  auto src_index_dims = src_index.dims();
  if (src_index_dims.size() == 2) {
    PADDLE_ENFORCE_EQ(src_index_dims[1],
                      1,
                      phi::errors::InvalidArgument(
                          "The last dim of Src_index should be 1 when it "
                          "is 2D, but we get %d",
                          src_index_dims[1]));
  } else {
    PADDLE_ENFORCE_EQ(
        src_index_dims.size(),
        1,
        phi::errors::InvalidArgument(
            "The Src_index should be 1D, when it is not 2D, but we get %d",
            src_index_dims.size()));
  }

  auto dst_index_dims = dst_index.dims();
  if (dst_index_dims.size() == 2) {
    PADDLE_ENFORCE_EQ(dst_index_dims[1],
                      1,
                      phi::errors::InvalidArgument(
                          "The last dim of Dst_index should be 1 when it "
                          "is 2D, but we get %d",
                          dst_index_dims[1]));
  } else {
    PADDLE_ENFORCE_EQ(
        dst_index_dims.size(),
        1,
        phi::errors::InvalidArgument("The Dst_index should be 1D, "
                                     "when it is not 2D, but we get %d",
                                     dst_index_dims.size()));
  }

  PADDLE_ENFORCE_EQ(src_index_dims[0],
                    dst_index_dims[0],
                    phi::errors::InvalidArgument(
                        "Src_index and Dst_index should have the same shape."));

  // Infer out's shape according to x and y(need broadcasting condition)
  out->set_dtype(x.dtype());
  auto x_dims = x.dims();
  auto y_dims = y.dims();
  auto x_dims1 = phi::vectorize<int>(x_dims);
  auto y_dims1 = phi::vectorize<int>(y_dims);
  // Drop dim 0 (node count); only the feature dims participate in
  // broadcasting.
  std::vector<int> x_dims2(x_dims1.begin() + 1, x_dims1.end());
  std::vector<int> y_dims2(y_dims1.begin() + 1, y_dims1.end());
  // Cast ranks to int BEFORE subtracting: size_t - size_t underflows to a
  // huge unsigned value whenever y has more feature dims than x, and the
  // old code only recovered via implementation-defined narrowing in
  // static_cast<int> followed by std::abs.
  const int x_rank = static_cast<int>(x_dims2.size());
  const int y_rank = static_cast<int>(y_dims2.size());
  int max_dim = std::max(x_rank, y_rank);
  int axis = std::abs(x_rank - y_rank);
  std::vector<int> x_dims_array(max_dim);
  std::vector<int> y_dims_array(max_dim);
  std::vector<int> out_dims_array(max_dim);
  // Only need to broadcast dimensions other than the 0th dimension.
  phi::funcs::GetBroadcastDimsArrays(phi::make_ddim(x_dims2),
                                     phi::make_ddim(y_dims2),
                                     x_dims_array.data(),
                                     y_dims_array.data(),
                                     out_dims_array.data(),
                                     max_dim,
                                     axis);
  // Re-attach the edge count (rows of src_index) as the leading dim.
  out_dims_array.insert(out_dims_array.begin(), src_index_dims[0]);
  out->set_dims(phi::make_ddim(out_dims_array));
}

} // namespace phi

PD_REGISTER_INFER_META_FN(batch_norm, phi::BatchNormInferMeta);
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/multiary.h
Expand Up @@ -431,4 +431,11 @@ void GraphSendUERecvInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaTensor* dst_count);

// Infers output meta for graph_send_uv: validates the edge index tensors
// and sets out to dtype(x) with shape [num_edges, broadcast(x[1:], y[1:])].
// Defined in multiary.cc.
void GraphSendUVInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& message_op,
MetaTensor* out);

} // namespace phi
2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/graph_send_ue_recv_grad_kernel.cc
Expand Up @@ -24,7 +24,7 @@
#include "paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/graph_messaage_passing_impl.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"

namespace phi {
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/graph_send_ue_recv_kernel.cc
Expand Up @@ -22,7 +22,7 @@
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/impl/graph_messaage_passing_impl.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"

namespace phi {

Expand Down

0 comments on commit b941c68

Please sign in to comment.