Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[geometric]Add paddle.geometric.send_uv API #44848

Merged
merged 14 commits into from Aug 16, 2022
10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/api.yaml
Expand Up @@ -135,6 +135,16 @@
func : fft_r2c
backward : fft_r2c_grad

- api : graph_send_uv
args : (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op = "ADD")
output : Tensor(out)
infer_meta :
func : GraphSendUVInferMeta
kernel :
func : graph_send_uv
data_type : x
backward : graph_send_uv_grad

- api : lgamma
args : (Tensor x)
output : Tensor(out)
Expand Down
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
Expand Up @@ -147,6 +147,17 @@
data_type: out_grad
no_need_buffer: x

- backward_api : graph_send_uv_grad
forward : graph_send_uv (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op = "ADD") -> Tensor(out)
args: (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, Tensor out_grad, str message_op = "ADD")
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : graph_send_uv_grad
data_type : x

- backward_api : lgamma_grad
forward : lgamma(Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
Expand Down
70 changes: 70 additions & 0 deletions paddle/phi/infermeta/multiary.cc
Expand Up @@ -2687,6 +2687,76 @@ void GraphSendUERecvInferMeta(const MetaTensor& x,
out->set_dims(phi::make_ddim(out_dims_array));
}

void GraphSendUVInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& message_op,
MetaTensor* out) {
auto src_index_dims = src_index.dims();
if (src_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(src_index_dims[1],
1,
phi::errors::InvalidArgument(
"The last dim of Src_index should be 1 when it "
"is 2D, but we get %d",
src_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
src_index_dims.size(),
1,
phi::errors::InvalidArgument(
"The Src_index should be 1D, when it is not 2D, but we get %d",
src_index_dims.size()));
}

auto dst_index_dims = dst_index.dims();
if (dst_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(dst_index_dims[1],
1,
phi::errors::InvalidArgument(
"The last dim of Dst_index should be 1 when it "
"is 2D, but we get %d",
dst_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
dst_index_dims.size(),
1,
phi::errors::InvalidArgument("The Dst_index should be 1D, "
"when it is not 2D, but we get %d",
dst_index_dims.size()));
}

PADDLE_ENFORCE_EQ(src_index_dims[0],
dst_index_dims[0],
phi::errors::InvalidArgument(
"Src_index and Dst_index should have the same shape."));

// Infer out's shape according to x and y(need broadcasting condition)
out->set_dtype(x.dtype());
auto x_dims = x.dims();
auto y_dims = y.dims();
auto x_dims1 = phi::vectorize<int>(x_dims);
auto y_dims1 = phi::vectorize<int>(y_dims);
std::vector<int> x_dims2(x_dims1.begin() + 1, x_dims1.end());
std::vector<int> y_dims2(y_dims1.begin() + 1, y_dims1.end());
int max_dim = std::max(x_dims2.size(), y_dims2.size());
int axis = std::abs(static_cast<int>(x_dims2.size() - y_dims2.size()));
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
// Only need to broadcast dimensions other than the 0th dimension.
phi::funcs::GetBroadcastDimsArrays(phi::make_ddim(x_dims2),
phi::make_ddim(y_dims2),
x_dims_array.data(),
y_dims_array.data(),
out_dims_array.data(),
max_dim,
axis);
out_dims_array.insert(out_dims_array.begin(), src_index_dims[0]);
out->set_dims(phi::make_ddim(out_dims_array));
}

} // namespace phi

PD_REGISTER_INFER_META_FN(batch_norm, phi::BatchNormInferMeta);
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/multiary.h
Expand Up @@ -476,4 +476,11 @@ void GraphSendUERecvInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaTensor* dst_count);

void GraphSendUVInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& message_op,
MetaTensor* out);

} // namespace phi
2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/graph_send_ue_recv_grad_kernel.cc
Expand Up @@ -24,7 +24,7 @@
#include "paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/graph_messaage_passing_impl.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"

namespace phi {
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/graph_send_ue_recv_kernel.cc
Expand Up @@ -22,7 +22,7 @@
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/impl/graph_messaage_passing_impl.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"

namespace phi {

Expand Down