[geometric] Add paddle.geometric.send_uv API #44848

Merged · 14 commits · Aug 16, 2022
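This PR adds an op that computes one output row per graph edge by combining source-node features from x with destination-node features from y. A minimal usage sketch of the new API (hedged: the tensor values are illustrative, and the lowercase message_op values assume the Python wrapper's spelling rather than the "ADD" used in the yaml below):

import paddle

# 4 nodes with 8 features each, and 4 directed edges.
x = paddle.rand([4, 8])
y = paddle.rand([4, 8])
src_index = paddle.to_tensor([0, 1, 2, 3], dtype="int64")
dst_index = paddle.to_tensor([1, 2, 3, 0], dtype="int64")

# One message per edge: out[i] = x[src_index[i]] + y[dst_index[i]].
out = paddle.geometric.send_uv(x, y, src_index, dst_index, message_op="add")
print(out.shape)  # [4, 8]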
10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/api.yaml
@@ -135,6 +135,16 @@
func : fft_r2c
backward : fft_r2c_grad

- api : graph_send_uv
args : (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op = "ADD")
output : Tensor(out)
infer_meta :
func : GraphSendUVInferMeta
kernel :
func : graph_send_uv
data_type : x
backward : graph_send_uv_grad

- api : lgamma
args : (Tensor x)
output : Tensor(out)
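For reference, the forward semantics declared above can be sketched in NumPy (an illustrative reference implementation, not Paddle code): gather endpoint features per edge, then combine them elementwise, broadcasting the trailing feature dimensions.

import numpy as np

def send_uv_reference(x, y, src_index, dst_index, message_op="ADD"):
    # Gather per-edge features from both endpoints, then combine them
    # elementwise; trailing dimensions broadcast against each other.
    x_e, y_e = x[src_index], y[dst_index]
    return x_e + y_e if message_op == "ADD" else x_e * y_e

x = np.random.rand(10, 8).astype("float32")
y = np.random.rand(10, 1).astype("float32")   # broadcasts against x
src, dst = np.array([0, 1, 2]), np.array([3, 4, 5])
assert send_uv_reference(x, y, src, dst).shape == (3, 8)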
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
@@ -147,6 +147,17 @@
data_type: out_grad
no_need_buffer: x

- backward_api : graph_send_uv_grad
forward : graph_send_uv (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op = "ADD") -> Tensor(out)
args: (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, Tensor out_grad, str message_op = "ADD")
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : graph_send_uv_grad
data_type : x

- backward_api : lgamma_grad
forward : lgamma(Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
70 changes: 70 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -2687,6 +2687,76 @@ void GraphSendUERecvInferMeta(const MetaTensor& x,
out->set_dims(phi::make_ddim(out_dims_array));
}

void GraphSendUVInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& message_op,
MetaTensor* out) {
auto src_index_dims = src_index.dims();
if (src_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(src_index_dims[1],
1,
phi::errors::InvalidArgument(
"The last dim of Src_index should be 1 when it "
"is 2D, but we get %d",
src_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
src_index_dims.size(),
1,
phi::errors::InvalidArgument(
"The Src_index should be 1D, when it is not 2D, but we get %d",
src_index_dims.size()));
}

auto dst_index_dims = dst_index.dims();
if (dst_index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(dst_index_dims[1],
1,
phi::errors::InvalidArgument(
"The last dim of Dst_index should be 1 when it "
"is 2D, but we get %d",
dst_index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
dst_index_dims.size(),
1,
phi::errors::InvalidArgument("The Dst_index should be 1D, "
"when it is not 2D, but we get %d",
dst_index_dims.size()));
}

PADDLE_ENFORCE_EQ(src_index_dims[0],
dst_index_dims[0],
phi::errors::InvalidArgument(
"Src_index and Dst_index should have the same shape."));

  // Infer out's shape according to x and y (broadcasting may be needed).
out->set_dtype(x.dtype());
auto x_dims = x.dims();
auto y_dims = y.dims();
auto x_dims1 = phi::vectorize<int>(x_dims);
auto y_dims1 = phi::vectorize<int>(y_dims);
std::vector<int> x_dims2(x_dims1.begin() + 1, x_dims1.end());
std::vector<int> y_dims2(y_dims1.begin() + 1, y_dims1.end());
int max_dim = std::max(x_dims2.size(), y_dims2.size());
int axis = std::abs(static_cast<int>(x_dims2.size() - y_dims2.size()));
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
// Only need to broadcast dimensions other than the 0th dimension.
phi::funcs::GetBroadcastDimsArrays(phi::make_ddim(x_dims2),
phi::make_ddim(y_dims2),
x_dims_array.data(),
y_dims_array.data(),
out_dims_array.data(),
max_dim,
axis);
out_dims_array.insert(out_dims_array.begin(), src_index_dims[0]);
out->set_dims(phi::make_ddim(out_dims_array));
}

} // namespace phi

PD_REGISTER_INFER_META_FN(batch_norm, phi::BatchNormInferMeta);
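The shape rule implemented by GraphSendUVInferMeta above: out's leading dimension is the number of edges (the length of src_index), and its remaining dimensions are the broadcast of x's and y's trailing dimensions. A quick sketch of that rule (illustrative Python, assuming NumPy's broadcasting matches phi::funcs::GetBroadcastDimsArrays here):

import numpy as np

def infer_send_uv_out_shape(x_shape, y_shape, num_edges):
    # Broadcast everything except the leading node dimension,
    # then prepend the edge count.
    feat = np.broadcast_shapes(tuple(x_shape[1:]), tuple(y_shape[1:]))
    return (num_edges,) + feat

assert infer_send_uv_out_shape((10, 8), (10, 1), 5) == (5, 8)
assert infer_send_uv_out_shape((10, 3, 1), (10, 4), 5) == (5, 3, 4)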
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -476,4 +476,11 @@ void GraphSendUERecvInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaTensor* dst_count);

void GraphSendUVInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& message_op,
MetaTensor* out);

} // namespace phi
2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/graph_send_ue_recv_grad_kernel.cc
@@ -24,7 +24,7 @@
#include "paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/graph_messaage_passing_impl.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"

namespace phi {
2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/graph_send_ue_recv_kernel.cc
@@ -22,7 +22,7 @@
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/impl/graph_messaage_passing_impl.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"

namespace phi {

254 changes: 254 additions & 0 deletions paddle/phi/kernels/cpu/graph_send_uv_grad_kernel.cc
@@ -0,0 +1,254 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/graph_send_uv_grad_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"

namespace phi {

template <typename Context, typename T, typename IndexT>
void CalculateGrad(const Context& ctx,
const T* out_grad,
const IndexT* s_index,
const IndexT* d_index,
const phi::DDim& out_grad_dims,
const phi::DDim& x_grad_dims,
const std::string& message_op,
int64_t index_size,
int64_t slice_size,
T* x_grad,
const DenseTensor& out_grad_tensor,
const DenseTensor& y) {
std::vector<int64_t> reduce_idx;
bool reduce = ReduceGrad(out_grad_dims, x_grad_dims, reduce_idx);

if (message_op == "ADD") {
if (!reduce) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t i = 0; i < index_size; i++) {
IndexT dst = d_index[i];
T* x_grad_off = x_grad + dst * slice_size;
const T* out_grad_off = out_grad + i * slice_size;
for (int64_t j = 0; j < slice_size; j++) {
if (out_grad_off[j] != 0) {
Contributor:
A point to discuss here: could we just add out_grad_off directly, regardless of whether the gradient is zero? In multi-threaded computation, a plain addition looks better than having a branch.

Contributor Author:
The update here involves an atomic operation, so the branch check reduces the number of atomic operations. (See the sketch after the kernel registration at the end of this file.)

#ifdef PADDLE_WITH_MKLML
#pragma omp atomic
#endif
x_grad_off[j] += out_grad_off[j];
}
}
}
} else {
const auto& bcast_info = phi::CalcBCastInfo(out_grad_dims, x_grad_dims);
auto out_grad_dims_1 = phi::vectorize<int>(out_grad_dims);
std::vector<int> out_grad_dims_2(out_grad_dims_1.begin() + 1,
out_grad_dims_1.end());
out_grad_dims_2.insert(out_grad_dims_2.begin(), x_grad_dims[0]);
Contributor:
Although the computational overhead here is relatively small, I'd suggest using emplace.

Contributor Author:
Fixed.

DenseTensor x_grad_v2 = phi::Empty<T, Context>(ctx, out_grad_dims_2);
phi::funcs::SetConstant<Context, T>()(ctx, &x_grad_v2, T(0));
T* x_grad_v2_data = x_grad_v2.data<T>();
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t i = 0; i < index_size; i++) {
IndexT dst = d_index[i];
T* x_grad_off = x_grad_v2_data + dst * bcast_info.out_len;
const T* out_grad_off = out_grad + i * bcast_info.out_len;
for (int64_t j = 0; j < bcast_info.out_len; j++) {
if (out_grad_off[j] != 0) {
#ifdef PADDLE_WITH_MKLML
#pragma omp atomic
#endif
x_grad_off[j] += out_grad_off[j];
}
}
}
DenseTensor x_grad_out = phi::Sum<T, Context>(
ctx,
x_grad_v2,
reduce_idx,
paddle::experimental::CppTypeToDataType<T>::Type(),
true);
memcpy(x_grad, x_grad_out.data<T>(), x_grad_out.numel() * sizeof(T));
}
} else if (message_op == "MUL") {
const auto& bcast = phi::CalcBCastInfo(y.dims(), out_grad_dims);
const T* y_data = y.data<T>();
if (!reduce) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t i = 0; i < index_size; i++) {
IndexT src = s_index[i];
IndexT dst = d_index[i];
T* x_grad_off = x_grad + dst * bcast.out_len;
const T* y_off = y_data + src * bcast.l_len;
const T* out_grad_off = out_grad + i * bcast.r_len;
for (int64_t j = 0; j < bcast.out_len; j++) {
int64_t y_add = bcast.use_bcast ? bcast.l_offset[j] : j;
int64_t o_add = bcast.use_bcast ? bcast.r_offset[j] : j;
T val = y_off[y_add] * out_grad_off[o_add];
if (val != 0) {
#ifdef PADDLE_WITH_MKLML
#pragma omp atomic
#endif
x_grad_off[j] += val;
}
}
}
} else {
auto out_grad_dims_1 = phi::vectorize<int>(out_grad_dims);
std::vector<int> out_grad_dims_2(out_grad_dims_1.begin() + 1,
out_grad_dims_1.end());
out_grad_dims_2.insert(out_grad_dims_2.begin(), x_grad_dims[0]);
DenseTensor x_grad_v2 = phi::Empty<T, Context>(ctx, out_grad_dims_2);
phi::funcs::SetConstant<Context, T>()(ctx, &x_grad_v2, T(0));
T* x_grad_v2_data = x_grad_v2.data<T>();
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t i = 0; i < index_size; i++) {
IndexT src = s_index[i];
IndexT dst = d_index[i];
T* x_grad_off = x_grad_v2_data + dst * bcast.out_len;
const T* y_off = y_data + src * bcast.l_len;
const T* out_grad_off = out_grad + i * bcast.r_len;
for (int64_t j = 0; j < bcast.out_len; j++) {
int64_t y_add = bcast.use_bcast ? bcast.l_offset[j] : j;
int64_t o_add = bcast.use_bcast ? bcast.r_offset[j] : j;
T val = y_off[y_add] * out_grad_off[o_add];
if (val != 0) {
#ifdef PADDLE_WITH_MKLML
#pragma omp atomic
#endif
x_grad_off[j] += val;
}
}
}
DenseTensor x_grad_out = phi::Sum<T, Context>(
ctx,
x_grad_v2,
reduce_idx,
paddle::experimental::CppTypeToDataType<T>::Type(),
true);
memcpy(x_grad, x_grad_out.data<T>(), x_grad_out.numel() * sizeof(T));
}
}
}

template <typename Context, typename T, typename IndexT>
void GraphSendUVGradOpKernelLaunchHelper(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out_grad,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& message_op,
DenseTensor* x_grad,
DenseTensor* y_grad) {
const int64_t& index_size = dst_index.dims()[0];

ctx.template Alloc<T>(x_grad);
T* x_grad_data = x_grad->data<T>();
ctx.template Alloc<T>(y_grad);
T* y_grad_data = y_grad->data<T>();
const auto& x_grad_dims = x_grad->dims();
const auto& y_grad_dims = y_grad->dims();
int64_t memset_size_x = 1, memset_size_y = 1;
int64_t slice_size_x = 1, slice_size_y = 1;
for (int i = 0; i < x_grad_dims.size(); i++) {
memset_size_x *= x_grad_dims[i];
if (i > 0) slice_size_x *= x_grad_dims[i];
}
for (int i = 0; i < y_grad_dims.size(); i++) {
memset_size_y *= y_grad_dims[i];
if (i > 0) slice_size_y *= y_grad_dims[i];
}
const size_t& memset_bytes_x = memset_size_x * sizeof(T);
const size_t& memset_bytes_y = memset_size_y * sizeof(T);
memset(x_grad_data, 0, memset_bytes_x);
memset(y_grad_data, 0, memset_bytes_y);

if (index_size == 0) return;
Contributor:
If this case is validated in the forward op, this could also just raise an error directly here.

Contributor Author:
done

const T* out_grad_data = out_grad.data<T>();
const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index.data<IndexT>();
const auto& out_grad_dims = out_grad.dims();
// Calculate X Grad.
CalculateGrad<Context, T, IndexT>(ctx,
out_grad_data,
d_index,
s_index,
out_grad_dims,
x_grad_dims,
message_op,
index_size,
slice_size_x,
x_grad_data,
out_grad,
y);
  // Calculate Y Grad.
CalculateGrad<Context, T, IndexT>(ctx,
out_grad_data,
s_index,
d_index,
out_grad_dims,
y_grad_dims,
message_op,
index_size,
slice_size_y,
y_grad_data,
out_grad,
x);
}

template <typename T, typename Context>
void GraphSendUVGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const DenseTensor& out_grad,
const std::string& message_op,
DenseTensor* x_grad,
DenseTensor* y_grad) {
auto index_type = src_index.dtype();
if (index_type == phi::DataType::INT32) {
GraphSendUVGradOpKernelLaunchHelper<Context, T, int32_t>(
ctx, x, y, out_grad, src_index, dst_index, message_op, x_grad, y_grad);
} else if (index_type == phi::DataType::INT64) {
GraphSendUVGradOpKernelLaunchHelper<Context, T, int64_t>(
ctx, x, y, out_grad, src_index, dst_index, message_op, x_grad, y_grad);
}
}

} // namespace phi

PD_REGISTER_KERNEL(graph_send_uv_grad,
CPU,
ALL_LAYOUT,
phi::GraphSendUVGradKernel,
float,
double,
int,
int64_t) {}
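For reference, the scatter-style gradient that the branch-vs-atomic discussion above concerns computes, in effect, the following. This is an illustrative single-threaded Python sketch of the non-broadcast path, not the CPU kernel (the broadcast/reduce path via phi::Sum is omitted):

import numpy as np

def send_uv_x_grad(out_grad, y, src_index, dst_index, message_op, num_nodes):
    # Forward: out[i] = x[src_index[i]] op y[dst_index[i]], so the gradient
    # w.r.t. x scatters out_grad back onto the source nodes. Each += below
    # corresponds to an atomic add in the OpenMP kernel.
    x_grad = np.zeros((num_nodes,) + out_grad.shape[1:], out_grad.dtype)
    for i in range(len(src_index)):
        if message_op == "ADD":
            g = out_grad[i]
        else:  # "MUL": d(x * y)/dx = y, gathered at the destination node
            g = out_grad[i] * y[dst_index[i]]
        if np.any(g != 0):  # the branch that skips needless atomic adds
            x_grad[src_index[i]] += g
    return x_grad

The y gradient is symmetric: swap the roles of src_index and dst_index (and gather from x in the MUL case), which is how the kernel above reuses CalculateGrad with the index arguments exchanged.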