[geometric] Add paddle.geometric.send_u_recv API (#44580)
* change out_size to IntArray

* fix out_size bug in eager mode

* add unittest for out_size tensor

* add deprecation notice for paddle.incubate.graph_send_recv; add paddle.geometric.send_u_recv and unittests

* fix std::numeric_limits lowest() bug

* fix according to review comments

* add default value in yaml

* change api file name

* change name
DesmonDay committed Aug 9, 2022
1 parent 76e0926 commit 34b4355
Showing 19 changed files with 657 additions and 160 deletions.
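For orientation before the per-file diff: a minimal sketch of how the new API is intended to be called. The keyword names below (pool_type, out_size) mirror the op attribute and yaml args in this commit; the exact Python signature of paddle.geometric.send_u_recv is an assumption, since the wrapper itself is not among the files shown.

import paddle

# Toy graph: 3 nodes with 3 features each, 4 directed edges (src -> dst).
x = paddle.to_tensor([[0., 2., 3.], [1., 4., 5.], [2., 6., 7.]])
src_index = paddle.to_tensor([0, 1, 2, 0], dtype="int32")
dst_index = paddle.to_tensor([1, 2, 1, 0], dtype="int32")

# Gather ("send") rows of x at src_index, then scatter-reduce ("recv")
# them into the rows named by dst_index.
out = paddle.geometric.send_u_recv(x, src_index, dst_index, pool_type="sum")
print(out)
# Row 0 receives x[0]; row 1 receives x[0] + x[2]; row 2 receives x[1].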
12 changes: 8 additions & 4 deletions paddle/fluid/operators/graph_send_recv_op.cc
@@ -58,6 +58,10 @@ class GraphSendRecvOpMaker : public framework::OpProtoAndCheckerMaker {
"The input tensor with data type float32, float64, int32, int64.");
AddInput("Src_index", "The source index tensor.");
AddInput("Dst_index", "The destination index tensor.");
AddInput("Out_size",
"(Tensor<int>, optional). The 0th dimension of the output."
"It has a higher priority than Attr(out_size).")
.AsDispensable();
AddOutput("Out", "Output tensor of graph_send_recv op.");
AddOutput("Dst_count",
"Count tensor of Dst_index, mainly for MEAN pool_type.")
@@ -68,12 +72,12 @@ class GraphSendRecvOpMaker : public framework::OpProtoAndCheckerMaker {
"tensors of Dst_index.")
.SetDefault("SUM")
.InEnum({"SUM", "MEAN", "MIN", "MAX"});
AddAttr<int64_t>(
AddAttr<std::vector<int64_t>>(
"out_size",
"(int64_t, default 0)"
"(vector<int64_t>, default {0})"
"Define the first dimension of Output tensor."
"If set default 0, then the shape of Out is the same with X.")
.SetDefault(0);
"If set default {0}, then the shape of Out is the same with X.")
.SetDefault({0});
AddComment(R"DOC(
Graph Learning Send_Recv combine operator.
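The dispensable Out_size input added above lets the 0th dimension of the output come from a runtime tensor rather than the compile-time attribute, and it takes priority when both are present. A hedged sketch of what that enables from Python (that out_size may be passed as a Tensor is what the new unit tests in the commit message cover; the exact call form is an assumption):

import paddle

x = paddle.to_tensor([[0., 2., 3.], [1., 4., 5.], [2., 6., 7.]])
src_index = paddle.to_tensor([0, 1, 2, 0], dtype="int32")
dst_index = paddle.to_tensor([1, 2, 1, 0], dtype="int32")

# out_size as a runtime tensor feeds the Out_size input, overriding
# the static out_size attribute.
out_size = paddle.to_tensor([5], dtype="int32")
out = paddle.geometric.send_u_recv(
    x, src_index, dst_index, pool_type="sum", out_size=out_size)
print(out.shape)  # [5, 3]; rows that receive no message stay zero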
1 change: 1 addition & 0 deletions paddle/fluid/pybind/op_function_generator.h
@@ -225,6 +225,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
"Bias3",
"Mean3",
"Var3"}},
{"graph_send_recv", {"X", "Src_index", "Dst_index", "Out_size"}},
};

// NOTE(zhiqiu): Like op_ins_map.
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/legacy_api.yaml
@@ -1060,7 +1060,7 @@
func : generate_proposals_v2

- api : graph_send_recv
-  args : (Tensor x, Tensor src_index, Tensor dst_index, str pool_type = "SUM", int64_t out_size = 0)
+  args : (Tensor x, Tensor src_index, Tensor dst_index, str pool_type = "SUM", IntArray out_size = {0})
output : Tensor(out), Tensor(dst_count)
infer_meta :
func : GraphSendRecvInferMeta
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/legacy_backward.yaml
@@ -941,7 +941,7 @@
func : gelu_grad

- backward_api : graph_send_recv_grad
-  forward : graph_send_recv (Tensor x, Tensor src_index, Tensor dst_index, str pool_type = "SUM", int64_t out_size = 0) -> Tensor(out), Tensor(dst_count)
+  forward : graph_send_recv (Tensor x, Tensor src_index, Tensor dst_index, str pool_type = "SUM", IntArray out_size = {0}) -> Tensor(out), Tensor(dst_count)
args : (Tensor x, Tensor src_index, Tensor dst_index, Tensor out, Tensor dst_count, Tensor out_grad, str pool_type = "SUM")
output : Tensor(x_grad)
infer_meta :
20 changes: 5 additions & 15 deletions paddle/phi/infermeta/ternary.cc
@@ -412,7 +412,7 @@ void GraphSendRecvInferMeta(const MetaTensor& x,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& pool_type,
-                            int64_t out_size,
+                            const IntArray& out_size,
MetaTensor* out,
MetaTensor* dst_count) {
auto src_index_dims = src_index.dims();
@@ -455,23 +455,13 @@ void GraphSendRecvInferMeta(const MetaTensor& x,
"Src_index and Dst_index should have the same shape."));

auto dims = x.dims();
if (out_size <= 0) {
out->set_dims(dims);
} else {
std::vector<int64_t> dims_ = phi::vectorize(dims);
if (dims_.size() > 0) {
dims_[0] = out_size;
}
out->set_dims(phi::make_ddim(dims_));
}
std::vector<int64_t> dims_ = phi::vectorize(dims);
dims_[0] = -1;
out->set_dims(phi::make_ddim(dims_));
out->set_dtype(x.dtype());

if (pool_type == "MEAN") {
if (out_size <= 0) {
dst_count->set_dims({dims[0]});
} else {
dst_count->set_dims({out_size});
}
dst_count->set_dims({-1});
dst_count->set_dtype(DataType::INT32);
}
}
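Since out_size can now arrive as a runtime tensor, its value is unknown during shape inference, so the 0th dimension of Out (and of Dst_count for MEAN) is marked dynamic (-1) here and resolved inside the kernel. In static-graph mode the compiled shape therefore carries a -1; a sketch, assuming send_u_recv is usable under paddle.enable_static():

import paddle

paddle.enable_static()
x = paddle.static.data(name="x", shape=[3, 3], dtype="float32")
src = paddle.static.data(name="src", shape=[4], dtype="int32")
dst = paddle.static.data(name="dst", shape=[4], dtype="int32")
out = paddle.geometric.send_u_recv(x, src, dst, pool_type="sum")
print(out.shape)  # [-1, 3]: the 0th dim is only known at run time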
3 changes: 2 additions & 1 deletion paddle/phi/infermeta/ternary.h
@@ -14,6 +14,7 @@ limitations under the License. */

#pragma once

#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/meta_tensor.h"

namespace phi {
@@ -75,7 +76,7 @@ void GraphSendRecvInferMeta(const MetaTensor& x,
const MetaTensor& src_index,
const MetaTensor& dst_index,
const std::string& pool_type,
-                            int64_t out_size,
+                            const IntArray& out_size,
MetaTensor* out,
MetaTensor* dst_count);

43 changes: 33 additions & 10 deletions paddle/phi/kernels/cpu/graph_send_recv_kernel.cc
@@ -88,27 +88,35 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx,
DenseTensor* dst_count = nullptr) {
const int& index_size = src_index.dims()[0];

ctx.template Alloc<T>(out);
T* p_output = out->data<T>();
const auto& src_dims = x.dims();
int64_t memset_size = 1;
if (out_size <= 0) {
out->Resize(src_dims);
for (int i = 0; i < src_dims.size(); ++i) {
memset_size *= src_dims[i];
}
} else {
// Set out dim following out_size.
std::vector<int64_t> dims_ = phi::vectorize(src_dims);
if (dims_.size() > 0) {
dims_[0] = out_size;
}
out->Resize(phi::make_ddim(dims_));
memset_size = out_size;
for (int i = 1; i < src_dims.size(); ++i) {
memset_size *= src_dims[i];
}
}

ctx.template Alloc<T>(out);
T* p_output = out->data<T>();
const size_t& memset_bytes = memset_size * sizeof(T);
memset(p_output, 0, memset_bytes);

if (index_size == 0) return;

const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index.data<IndexT>();

if (pool_type == "SUM") {
GraphSendRecvCpuLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(
src_dims[0], index_size, s_index, d_index, x, out, pool_type);
@@ -119,10 +127,12 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx,
GraphSendRecvCpuLoop<T, IndexT, GraphSendRecvMaxFunctor<T>>(
src_dims[0], index_size, s_index, d_index, x, out, pool_type);
} else if (pool_type == "MEAN") {
+    int64_t input_size = out_size <= 0 ? src_dims[0] : out_size;
+    dst_count->Resize({input_size});
    ctx.template Alloc<int>(dst_count);
    int* p_dst_count = dst_count->data<int>();
-    memset(p_dst_count, 0, src_dims[0] * sizeof(int));
-    GraphSendRecvCpuLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(src_dims[0],
+    memset(p_dst_count, 0, input_size * sizeof(int));
+    GraphSendRecvCpuLoop<T, IndexT, GraphSendRecvSumFunctor<T>>(input_size,
index_size,
s_index,
d_index,
Expand All @@ -139,16 +149,29 @@ void GraphSendRecvKernel(const Context& ctx,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& pool_type,
-                         int64_t out_size,
+                         const IntArray& out_size,
DenseTensor* out,
DenseTensor* dst_count) {
auto index_type = src_index.dtype();
auto& out_size_data = out_size.GetData();
if (index_type == phi::DataType::INT32) {
GraphSendRecvOpKernelLaunchHelper<Context, T, int32_t>(
ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count);
GraphSendRecvOpKernelLaunchHelper<Context, T, int32_t>(ctx,
x,
src_index,
dst_index,
pool_type,
out_size_data[0],
out,
dst_count);
} else if (index_type == phi::DataType::INT64) {
GraphSendRecvOpKernelLaunchHelper<Context, T, int64_t>(
ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count);
GraphSendRecvOpKernelLaunchHelper<Context, T, int64_t>(ctx,
x,
src_index,
dst_index,
pool_type,
out_size_data[0],
out,
dst_count);
}
}

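As a reference for what this launch helper computes, here is a NumPy sketch of the scatter-reduce semantics for 2-D input (a reading aid, not Paddle code): gather x at src_index, reduce the gathered rows into dst_index, and take the output's 0th dimension from out_size when it is positive. For MAX/MIN, rows that receive no message end up as zeros, mirroring the GPU reset kernels below.

import numpy as np

def send_u_recv_ref(x, src_index, dst_index, pool_type="SUM", out_size=0):
    # Output row count: out_size when positive, else the same as x.
    n = out_size if out_size > 0 else x.shape[0]
    out = np.zeros((n, x.shape[1]), dtype=x.dtype)
    msgs = x[src_index]                      # "send": gather source rows
    if pool_type in ("SUM", "MEAN"):
        np.add.at(out, dst_index, msgs)      # "recv": scatter-add
        if pool_type == "MEAN":
            count = np.zeros(n, dtype=np.int32)
            np.add.at(count, dst_index, 1)   # this is Dst_count
            out /= np.maximum(count, 1)[:, None]
    else:
        op = np.maximum if pool_type == "MAX" else np.minimum
        fill = np.finfo(x.dtype).min if pool_type == "MAX" else np.finfo(x.dtype).max
        out[:] = fill
        op.at(out, dst_index, msgs)          # scatter-max / scatter-min
        received = np.zeros(n, dtype=bool)
        received[dst_index] = True
        out[~received] = 0                   # reset untouched rows to zero
    return out

x = np.array([[0., 2., 3.], [1., 4., 5.], [2., 6., 7.]], dtype=np.float32)
print(send_u_recv_ref(x, [0, 1, 2, 0], [1, 2, 1, 0], "SUM"))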
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/graph_send_recv_funcs.h
@@ -81,7 +81,7 @@ __global__ void InputResetMaxCUDAKernel(T* output,
size_t input_size,
size_t slice_size) {
CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) {
-    if (*(output + i) == std::numeric_limits<T>::min()) {
+    if (*(output + i) == std::numeric_limits<T>::lowest()) {
*(output + i) = 0;
}
}
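The min() to lowest() change above matters for floating-point MAX pooling: for floating types, std::numeric_limits<T>::min() is the smallest positive normal value, not the most negative one, so output slots initialized with min() were never matched by this reset kernel. NumPy exposes the analogous constants, which makes the distinction easy to see:

import numpy as np

# Analogue of std::numeric_limits<float>::min(): smallest positive normal.
print(np.finfo(np.float32).tiny)  # ~1.1754944e-38
# Analogue of std::numeric_limits<float>::lowest(): most negative finite.
print(np.finfo(np.float32).min)   # ~-3.4028235e+38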
51 changes: 31 additions & 20 deletions paddle/phi/kernels/gpu/graph_send_recv_kernel.cu
@@ -37,20 +37,27 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
DenseTensor* out,
DenseTensor* dst_count = nullptr) {
const int& index_size = src_index.dims()[0];
ctx.template Alloc<T>(out);
T* p_output = out->data<T>();
const auto& src_dims = x.dims();
int64_t memset_size = 1;
if (out_size <= 0) {
out->Resize(src_dims);
for (int i = 0; i < src_dims.size(); ++i) {
memset_size *= src_dims[i];
}
} else {
// Set out dim following out_size.
std::vector<int64_t> dims_ = phi::vectorize(out->dims());
if (dims_.size() > 0) {
dims_[0] = out_size;
}
out->Resize(phi::make_ddim(dims_));
memset_size = out_size;
for (int i = 1; i < src_dims.size(); ++i) {
memset_size *= src_dims[i];
}
}
ctx.template Alloc<T>(out);
T* p_output = out->data<T>();
const size_t& memset_bytes = memset_size * sizeof(T);
if (pool_type == "SUM" || pool_type == "MEAN") {
#ifdef PADDLE_WITH_HIP
@@ -63,7 +70,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
thrust::fill(thrust::device,
p_output_ptr,
p_output_ptr + memset_size,
std::numeric_limits<T>::min());
std::numeric_limits<T>::lowest());
} else if (pool_type == "MIN") {
thrust::device_ptr<T> p_output_ptr(p_output);
thrust::fill(thrust::device,
@@ -91,7 +98,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0];
int64_t grid_tmp = (n + block - 1) / block;
int64_t grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx;
-  int64_t input_size = src_dims[0];
+  int64_t input_size = out_size <= 0 ? src_dims[0] : out_size;
if (pool_type == "SUM") {
GraphSendRecvSumCUDAFunctor<T, IndexT> functor;
GraphSendRecvCUDAKernel<T, IndexT, GraphSendRecvSumCUDAFunctor<T, IndexT>>
@@ -103,9 +110,6 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
<<<grid, block, 0, ctx.stream()>>>(
p_src, s_index, d_index, p_output, index_size, slice_size, functor);

if (out_size > 0) {
input_size = out_size;
}
int64_t grid_max_tmp = (input_size * slice_size + block - 1) / block;
int64_t grid_max =
grid_max_tmp < max_grid_dimx ? grid_max_tmp : max_grid_dimx;
@@ -117,9 +121,6 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
<<<grid, block, 0, ctx.stream()>>>(
p_src, s_index, d_index, p_output, index_size, slice_size, functor);

if (out_size > 0) {
input_size = out_size;
}
int64_t grid_min_tmp = (input_size * slice_size + block - 1) / block;
int64_t grid_min =
grid_min_tmp < max_grid_dimx ? grid_min_tmp : max_grid_dimx;
@@ -130,12 +131,9 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
GraphSendRecvCUDAKernel<T, IndexT, GraphSendRecvSumCUDAFunctor<T, IndexT>>
<<<grid, block, 0, ctx.stream()>>>(
p_src, s_index, d_index, p_output, index_size, slice_size, functor);

dst_count->Resize({input_size});
ctx.template Alloc<int32_t>(dst_count);
int32_t* p_dst_count = dst_count->data<int32_t>();
if (out_size > 0) {
input_size = out_size;
}
int* p_dst_count = dst_count->data<int>();

#ifdef PADDLE_WITH_HIP
hipMemset(p_dst_count, 0, input_size * sizeof(int));
@@ -161,16 +159,29 @@ void GraphSendRecvKernel(const Context& ctx,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& pool_type,
-                         int64_t out_size,
+                         const IntArray& out_size,
DenseTensor* out,
DenseTensor* dst_count) {
auto index_type = src_index.dtype();
auto& out_size_data = out_size.GetData();
if (index_type == phi::DataType::INT32) {
GraphSendRecvOpCUDAKernelLaunchHelper<Context, T, int32_t>(
ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count);
GraphSendRecvOpCUDAKernelLaunchHelper<Context, T, int32_t>(ctx,
x,
src_index,
dst_index,
pool_type,
out_size_data[0],
out,
dst_count);
} else if (index_type == phi::DataType::INT64) {
GraphSendRecvOpCUDAKernelLaunchHelper<Context, T, int64_t>(
ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count);
GraphSendRecvOpCUDAKernelLaunchHelper<Context, T, int64_t>(ctx,
x,
src_index,
dst_index,
pool_type,
out_size_data[0],
out,
dst_count);
}
}

3 changes: 2 additions & 1 deletion paddle/phi/kernels/graph_send_recv_kernel.h
@@ -16,6 +16,7 @@

#include <string>

#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/dense_tensor.h"

namespace phi {
Expand All @@ -26,7 +27,7 @@ void GraphSendRecvKernel(const Context& ctx,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& pool_type,
-                         int64_t out_size,
+                         const IntArray& out_size,
DenseTensor* out,
DenseTensor* dst_count);

15 changes: 11 additions & 4 deletions paddle/phi/ops/compat/graph_send_recv_sig.cc
@@ -18,10 +18,17 @@ namespace phi {

KernelSignature GraphSendRecvOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("graph_send_recv",
{"X", "Src_index", "Dst_index"},
{"pool_type", "out_size"},
{"Out", "Dst_count"});
if (ctx.HasInput("Out_size")) {
return KernelSignature("graph_send_recv",
{"X", "Src_index", "Dst_index"},
{"pool_type", "Out_size"},
{"Out", "Dst_count"});
} else {
return KernelSignature("graph_send_recv",
{"X", "Src_index", "Dst_index"},
{"pool_type", "out_size"},
{"Out", "Dst_count"});
}
}

KernelSignature GraphSendRecvGradOpArgumentMapping(
1 change: 1 addition & 0 deletions python/paddle/__init__.py
@@ -78,6 +78,7 @@
import paddle.reader # noqa: F401
import paddle.static # noqa: F401
import paddle.vision # noqa: F401
+import paddle.geometric  # noqa: F401

from .tensor.attribute import is_complex # noqa: F401
from .tensor.attribute import is_integer # noqa: F401
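With the module imported at package level, the new API is reachable as paddle.geometric.send_u_recv. Per the commit message, paddle.incubate.graph_send_recv stays available but is deprecated in its favor; a sketch of the migration (keyword names are assumptions based on this diff):

import paddle

x = paddle.to_tensor([[0., 2., 3.], [1., 4., 5.], [2., 6., 7.]])
src = paddle.to_tensor([0, 1, 2, 0], dtype="int32")
dst = paddle.to_tensor([1, 2, 1, 0], dtype="int32")

# Deprecated by this commit, but still functional:
out_old = paddle.incubate.graph_send_recv(x, src, dst, pool_type="sum")
# Replacement:
out_new = paddle.geometric.send_u_recv(x, src, dst, pool_type="sum")
assert bool((out_old == out_new).all())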
