From 07005a65822c5fa0e6e0abfbcf0585ef62dd9963 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Fri, 22 Jul 2022 02:57:33 +0000 Subject: [PATCH 1/9] change out_size to INTArray --- paddle/fluid/operators/graph_send_recv_op.cc | 12 ++- paddle/fluid/pybind/op_function_generator.h | 1 + paddle/phi/api/yaml/legacy_api.yaml | 2 +- paddle/phi/api/yaml/legacy_backward.yaml | 4 +- paddle/phi/infermeta/ternary.cc | 18 +--- paddle/phi/infermeta/ternary.h | 3 +- .../phi/kernels/cpu/graph_send_recv_kernel.cc | 37 ++++++-- .../phi/kernels/gpu/graph_send_recv_kernel.cu | 28 ++++-- paddle/phi/kernels/graph_send_recv_kernel.h | 3 +- paddle/phi/ops/compat/graph_send_recv_sig.cc | 15 +++- .../unittests/test_graph_send_recv_op.py | 18 ++-- .../incubate/operators/graph_send_recv.py | 86 +++++++++++++------ 12 files changed, 153 insertions(+), 74 deletions(-) diff --git a/paddle/fluid/operators/graph_send_recv_op.cc b/paddle/fluid/operators/graph_send_recv_op.cc index d9c0ec5171464..8e515a02b40a0 100644 --- a/paddle/fluid/operators/graph_send_recv_op.cc +++ b/paddle/fluid/operators/graph_send_recv_op.cc @@ -58,6 +58,10 @@ class GraphSendRecvOpMaker : public framework::OpProtoAndCheckerMaker { "The input tensor with data type float32, float64, int32, int64."); AddInput("Src_index", "The source index tensor."); AddInput("Dst_index", "The destination index tensor."); + AddInput("OutSizeTensor", + "(Tensor, optional). The 0th dimension of the output." + "It has a higher priority than Attr(out_size).") + .AsDispensable(); AddOutput("Out", "Output tensor of graph_send_recv op."); AddOutput("Dst_count", "Count tensor of Dst_index, mainly for MEAN pool_type.") @@ -68,12 +72,12 @@ class GraphSendRecvOpMaker : public framework::OpProtoAndCheckerMaker { "tensors of Dst_index.") .SetDefault("SUM") .InEnum({"SUM", "MEAN", "MIN", "MAX"}); - AddAttr( + AddAttr>( "out_size", - "(int64_t, default 0)" + "(vector, default {0})" "Define the first dimension of Output tensor." - "If set default 0, then the shape of Out is the same with X.") - .SetDefault(0); + "If set default {0}, then the shape of Out is the same with X.") + .SetDefault({0}); AddComment(R"DOC( Graph Learning Send_Recv combine operator. diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index 590d9d2f83e8b..04e54c994053d 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -225,6 +225,7 @@ std::map> op_ins_map = { "Bias3", "Mean3", "Var3"}}, + {"graph_send_recv", {"X", "Src_index", "Dst_index", "OutSizeTensor"}}, }; // NOTE(zhiqiu): Like op_ins_map. diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 0d0fd74c17aa7..df193f0842747 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -886,7 +886,7 @@ backward : gelu_grad - api : graph_send_recv - args : (Tensor x, Tensor src_index, Tensor dst_index, str pool_type = "SUM", int64_t out_size = 0) + args : (Tensor x, Tensor src_index, Tensor dst_index, str pool_type, IntArray out_size) output : Tensor(out), Tensor(dst_count) infer_meta : func : GraphSendRecvInferMeta diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 6df4883145620..21a6471f36760 100644 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -828,8 +828,8 @@ func : gelu_grad - backward_api : graph_send_recv_grad - forward : graph_send_recv (Tensor x, Tensor src_index, Tensor dst_index, str pool_type = "SUM", int64_t out_size = 0) -> Tensor(out), Tensor(dst_count) - args : (Tensor x, Tensor src_index, Tensor dst_index, Tensor out, Tensor dst_count, Tensor out_grad, str pool_type = "SUM") + forward : graph_send_recv (Tensor x, Tensor src_index, Tensor dst_index, str pool_type, IntArray out_size) -> Tensor(out), Tensor(dst_count) + args : (Tensor x, Tensor src_index, Tensor dst_index, Tensor out, Tensor dst_count, Tensor out_grad, str pool_type, IntArray out_size) output : Tensor(x_grad) infer_meta : func : GeneralUnaryGradInferMeta diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 9f65de0f0aa70..5e3c09379d6fb 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -302,7 +302,7 @@ void GraphSendRecvInferMeta(const MetaTensor& x, const MetaTensor& src_index, const MetaTensor& dst_index, const std::string& pool_type, - int64_t out_size, + const IntArray& out_size, MetaTensor* out, MetaTensor* dst_count) { auto src_index_dims = src_index.dims(); @@ -345,23 +345,11 @@ void GraphSendRecvInferMeta(const MetaTensor& x, "Src_index and Dst_index should have the same shape.")); auto dims = x.dims(); - if (out_size <= 0) { - out->set_dims(dims); - } else { - std::vector dims_ = phi::vectorize(dims); - if (dims_.size() > 0) { - dims_[0] = out_size; - } - out->set_dims(phi::make_ddim(dims_)); - } + out->set_dims(dims); out->set_dtype(x.dtype()); if (pool_type == "MEAN") { - if (out_size <= 0) { - dst_count->set_dims({dims[0]}); - } else { - dst_count->set_dims({out_size}); - } + dst_count->set_dims({dims[0]}); dst_count->set_dtype(DataType::INT32); } } diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index 40461d299fb01..01fc18f16f212 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/meta_tensor.h" namespace phi { @@ -65,7 +66,7 @@ void GraphSendRecvInferMeta(const MetaTensor& x, const MetaTensor& src_index, const MetaTensor& dst_index, const std::string& pool_type, - int64_t out_size, + const IntArray& out_size, MetaTensor* out, MetaTensor* dst_count); diff --git a/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc b/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc index e4034230c7866..7a7060a36aa28 100644 --- a/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc +++ b/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc @@ -88,11 +88,10 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx, DenseTensor* dst_count = nullptr) { const int& index_size = src_index.dims()[0]; - ctx.template Alloc(out); - T* p_output = out->data(); const auto& src_dims = x.dims(); int64_t memset_size = 1; if (out_size <= 0) { + out->Resize(src_dims); for (int i = 0; i < src_dims.size(); ++i) { memset_size *= src_dims[i]; } @@ -102,6 +101,8 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx, memset_size *= src_dims[i]; } } + ctx.template Alloc(out); + T* p_output = out->data(); const size_t& memset_bytes = memset_size * sizeof(T); memset(p_output, 0, memset_bytes); @@ -109,6 +110,7 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx, const IndexT* s_index = src_index.data(); const IndexT* d_index = dst_index.data(); + if (pool_type == "SUM") { GraphSendRecvCpuLoop>( src_dims[0], index_size, s_index, d_index, x, out, pool_type); @@ -119,10 +121,14 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx, GraphSendRecvCpuLoop>( src_dims[0], index_size, s_index, d_index, x, out, pool_type); } else if (pool_type == "MEAN") { + if (out_size <= 0) { + dst_count->Resize({src_dims[0]}); + } + int64_t input_size = out_size <= 0 ? src_dims[0] : out_size; ctx.template Alloc(dst_count); int* p_dst_count = dst_count->data(); - memset(p_dst_count, 0, src_dims[0] * sizeof(int)); - GraphSendRecvCpuLoop>(src_dims[0], + memset(p_dst_count, 0, input_size * sizeof(int)); + GraphSendRecvCpuLoop>(input_size, index_size, s_index, d_index, @@ -139,16 +145,29 @@ void GraphSendRecvKernel(const Context& ctx, const DenseTensor& src_index, const DenseTensor& dst_index, const std::string& pool_type, - int64_t out_size, + const IntArray& out_size, DenseTensor* out, DenseTensor* dst_count) { auto index_type = src_index.dtype(); + auto& out_size_data = out_size.GetData(); if (index_type == phi::DataType::INT32) { - GraphSendRecvOpKernelLaunchHelper( - ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count); + GraphSendRecvOpKernelLaunchHelper(ctx, + x, + src_index, + dst_index, + pool_type, + out_size_data[0], + out, + dst_count); } else if (index_type == phi::DataType::INT64) { - GraphSendRecvOpKernelLaunchHelper( - ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count); + GraphSendRecvOpKernelLaunchHelper(ctx, + x, + src_index, + dst_index, + pool_type, + out_size_data[0], + out, + dst_count); } } diff --git a/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu index 7ecf352ffe996..748257f8fa732 100644 --- a/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu @@ -37,11 +37,10 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, DenseTensor* out, DenseTensor* dst_count = nullptr) { const int& index_size = src_index.dims()[0]; - ctx.template Alloc(out); - T* p_output = out->data(); const auto& src_dims = x.dims(); int64_t memset_size = 1; if (out_size <= 0) { + out->Resize(src_dims); for (int i = 0; i < src_dims.size(); ++i) { memset_size *= src_dims[i]; } @@ -51,6 +50,8 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, memset_size *= src_dims[i]; } } + ctx.template Alloc(out); + T* p_output = out->data(); const size_t& memset_bytes = memset_size * sizeof(T); if (pool_type == "SUM" || pool_type == "MEAN") { #ifdef PADDLE_WITH_HIP @@ -161,16 +162,29 @@ void GraphSendRecvKernel(const Context& ctx, const DenseTensor& src_index, const DenseTensor& dst_index, const std::string& pool_type, - int64_t out_size, + const IntArray& out_size, DenseTensor* out, DenseTensor* dst_count) { auto index_type = src_index.dtype(); + auto& out_size_data = out_size.GetData(); if (index_type == phi::DataType::INT32) { - GraphSendRecvOpCUDAKernelLaunchHelper( - ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count); + GraphSendRecvOpCUDAKernelLaunchHelper(ctx, + x, + src_index, + dst_index, + pool_type, + out_size_data[0], + out, + dst_count); } else if (index_type == phi::DataType::INT64) { - GraphSendRecvOpCUDAKernelLaunchHelper( - ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count); + GraphSendRecvOpCUDAKernelLaunchHelper(ctx, + x, + src_index, + dst_index, + pool_type, + out_size_data[0], + out, + dst_count); } } diff --git a/paddle/phi/kernels/graph_send_recv_kernel.h b/paddle/phi/kernels/graph_send_recv_kernel.h index 8f635225b75a4..cd625c92b93ea 100644 --- a/paddle/phi/kernels/graph_send_recv_kernel.h +++ b/paddle/phi/kernels/graph_send_recv_kernel.h @@ -16,6 +16,7 @@ #include +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" namespace phi { @@ -26,7 +27,7 @@ void GraphSendRecvKernel(const Context& ctx, const DenseTensor& src_index, const DenseTensor& dst_index, const std::string& pool_type, - int64_t out_size, + const IntArray& out_size, DenseTensor* out, DenseTensor* dst_count); diff --git a/paddle/phi/ops/compat/graph_send_recv_sig.cc b/paddle/phi/ops/compat/graph_send_recv_sig.cc index 9df2cf4d0fe91..236f5c4aa9c3c 100644 --- a/paddle/phi/ops/compat/graph_send_recv_sig.cc +++ b/paddle/phi/ops/compat/graph_send_recv_sig.cc @@ -18,10 +18,17 @@ namespace phi { KernelSignature GraphSendRecvOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("graph_send_recv", - {"X", "Src_index", "Dst_index"}, - {"pool_type", "out_size"}, - {"Out", "Dst_count"}); + if (ctx.HasInput("OutSizeTensor")) { + return KernelSignature("graph_send_recv", + {"X", "Src_index", "Dst_index"}, + {"pool_type", "OutSizeTensor"}, + {"Out", "Dst_count"}); + } else { + return KernelSignature("graph_send_recv", + {"X", "Src_index", "Dst_index"}, + {"pool_type", "out_size"}, + {"Out", "Dst_count"}); + } } KernelSignature GraphSendRecvGradOpArgumentMapping( diff --git a/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py b/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py index c0fdb134f16d6..bcfaf523777be 100644 --- a/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py +++ b/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py @@ -54,14 +54,19 @@ def setUp(self): def test_check_output(self): self.check_output(check_eager=True) + # self.check_output(check_eager=False) - def test_check_grad(self): - self.check_grad(['X'], - 'Out', - user_defined_grads=[self.gradient], - check_eager=True) + # def test_check_grad(self): + # self.check_grad(['X'], + # 'Out', + # user_defined_grads=[self.gradient], + # check_eager=True) + # self.check_grad(['X'], 'Out', + # user_defined_grads=[self.gradient], + # check_eager=False) +""" class TestGraphSendRecvMinOp(OpTest): def setUp(self): @@ -176,6 +181,7 @@ def compute_graph_send_recv_for_sum_mean(inputs, attributes): count[s_id] += 1 return results, count +""" def compute_graph_send_recv_for_min_max(inputs, attributes): @@ -223,6 +229,7 @@ def compute_graph_send_recv_for_min_max(inputs, attributes): return results, gradient / results.size +""" class API_GraphSendRecvOpTest(unittest.TestCase): def test_static(self): @@ -365,3 +372,4 @@ def test_api_eager_dygraph(self): if __name__ == '__main__': unittest.main() +""" diff --git a/python/paddle/incubate/operators/graph_send_recv.py b/python/paddle/incubate/operators/graph_send_recv.py index e9937558e9b3a..e209e994b7581 100644 --- a/python/paddle/incubate/operators/graph_send_recv.py +++ b/python/paddle/incubate/operators/graph_send_recv.py @@ -14,8 +14,8 @@ from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.framework import _non_static_mode, _in_legacy_dygraph, in_dygraph_mode -from paddle.fluid.data_feeder import check_variable_and_dtype -from paddle.fluid import core +from paddle.fluid.framework import Variable +from paddle.fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype from paddle import _C_ops @@ -63,14 +63,17 @@ def graph_send_recv(x, The available data type is int32, int64. pool_type (str): The pooling type of graph_send_recv, including `sum`, `mean`, `max`, `min`. Default value is `sum`. - out_size (int64|None): We can set `out_size` to get necessary output shape. If not set, then this - attribute will not be used. If set, it should be equal with or larger than - max(dst_index) + 1. + out_size (int|Tensor|None): We can set `out_size` to get necessary output shape. If not set or + out_size is smaller or equal to 0, then this input will not be used. + Otherwise, `out_size` should be equal with or larger than + max(dst_index) + 1. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - out (Tensor): The output tensor, should have the same shape and same dtype as input tensor `x`. + out (Tensor): The output tensor, should have the same shape and same dtype as input tensor `x`. + If `out_size` is set correctly, then it should have the same shape as `x` except + the 0th dimension. Examples: @@ -109,28 +112,24 @@ def graph_send_recv(x, # TODO(daisiming): Should we add judgement for out_size: max(dst_index) + 1. - if out_size is None or out_size <= 0: + if out_size is None: if _in_legacy_dygraph(): out, tmp = _C_ops.graph_send_recv(x, src_index, dst_index, 'pool_type', pool_type.upper()) return out if in_dygraph_mode(): return _C_ops.final_state_graph_send_recv(x, src_index, dst_index, - pool_type.upper(), 0) + pool_type.upper(), [0]) else: if _in_legacy_dygraph(): + out_size = convert_out_size_to_list(out_size) out, tmp = _C_ops.graph_send_recv(x, src_index, dst_index, 'pool_type', pool_type.upper(), 'out_size', out_size) return out if in_dygraph_mode(): - if isinstance(out_size, core.eager.Tensor): - if (out_size.size < 1): - raise ValueError( - "out_size should be long type, but received Tensor type." - ) - out_size = out_size.numpy()[0] + out_size = convert_out_size_to_list(out_size) return _C_ops.final_state_graph_send_recv(x, src_index, dst_index, pool_type.upper(), out_size) @@ -141,25 +140,62 @@ def graph_send_recv(x, "graph_send_recv") check_variable_and_dtype(dst_index, "Dst_index", ("int32", "int64"), "graph_send_recv") + check_type(out_size, 'out_size', + (int, np.int32, np.int64, Variable, list, tuple), + 'graph_send_recv') + if isinstance(out_size, Variable): + check_dtype(out_size.dtype, 'out_size', ['int32', 'int64'], + 'graph_send_recv') helper = LayerHelper("graph_send_recv", **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) dst_count = helper.create_variable_for_type_inference(dtype="int32", stop_gradient=True) + + inputs = {"X": x, "Src_index": src_index, "Dst_index": dst_index} + attrs = {"pool_type": pool_type.upper()} + get_out_size_tensor_inputs(inputs=inputs, + attrs=attrs, + out_size=out_size, + op_type='graph_send_recv') + helper.append_op(type="graph_send_recv", - inputs={ - "X": x, - "Src_index": src_index, - "Dst_index": dst_index - }, + inputs=inputs, outputs={ "Out": out, "Dst_count": dst_count }, - attrs={ - "pool_type": - pool_type.upper(), - "out_size": - 0 if out_size is None or out_size <= 0 else out_size - }) + attrs=attrs) return out + + +def convert_out_size_to_list(out_size): + """ + Convert out_size(int, np.int32, np.int64, Variable) to list + in imperative mode. + """ + if isinstance(out_size, (int, np.int32, np.int64)): + out_size = [out_size] + else: + out_size = [out_size.numpy().astype(int)[0]] + return out_size + + +def get_out_size_tensor_inputs(inputs, attrs, out_size, op_type): + """ + Convert out_size(int, np.int32, np.int64, Variable) to inputs + and attrs in static mode. + """ + if isinstance(out_size, (int, np.int32, np.int64)): + out_size = [out_size] + attrs['out_size'] = out_size + elif isinstance(out_size, Variable): + out_size.stop_gradient = True + check_dtype(out_size.dtype, 'out_size', ['int32', 'int64'], + 'fill_constant', + '(When type of out_size in' + op_type + ' is Variable.)') + if (convert_dtype(out_size.dtype) == 'int64'): + out_size = cast(out_size, 'int32') + inputs["OutSizeTensor"] = out_size + else: + raise TypeError("Out_size only supports Variable or int.") From 00e8cdddc61697ae50209bdd0e03dfc83859bccb Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Fri, 22 Jul 2022 07:06:33 +0000 Subject: [PATCH 2/9] fix out_size eager bug --- paddle/phi/api/yaml/legacy_backward.yaml | 2 +- .../phi/kernels/cpu/graph_send_recv_kernel.cc | 13 +++-- .../phi/kernels/gpu/graph_send_recv_kernel.cu | 22 ++++---- .../unittests/test_graph_send_recv_op.py | 18 ++----- .../incubate/operators/graph_send_recv.py | 52 +++++++++---------- 5 files changed, 47 insertions(+), 60 deletions(-) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 21a6471f36760..6e32177f97e43 100644 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -829,7 +829,7 @@ - backward_api : graph_send_recv_grad forward : graph_send_recv (Tensor x, Tensor src_index, Tensor dst_index, str pool_type, IntArray out_size) -> Tensor(out), Tensor(dst_count) - args : (Tensor x, Tensor src_index, Tensor dst_index, Tensor out, Tensor dst_count, Tensor out_grad, str pool_type, IntArray out_size) + args : (Tensor x, Tensor src_index, Tensor dst_index, Tensor out, Tensor dst_count, Tensor out_grad, str pool_type) output : Tensor(x_grad) infer_meta : func : GeneralUnaryGradInferMeta diff --git a/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc b/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc index 7a7060a36aa28..e1dbeb9042116 100644 --- a/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc +++ b/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc @@ -91,23 +91,28 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx, const auto& src_dims = x.dims(); int64_t memset_size = 1; if (out_size <= 0) { - out->Resize(src_dims); for (int i = 0; i < src_dims.size(); ++i) { memset_size *= src_dims[i]; } } else { + // set out dim following out_size. + std::vector dims_ = phi::vectorize(src_dims); + if (dims_.size() > 0) { + dims_[0] = out_size; + } + out->Resize(phi::make_ddim(dims_)); memset_size = out_size; for (int i = 1; i < src_dims.size(); ++i) { memset_size *= src_dims[i]; } } + ctx.template Alloc(out); T* p_output = out->data(); const size_t& memset_bytes = memset_size * sizeof(T); memset(p_output, 0, memset_bytes); if (index_size == 0) return; - const IndexT* s_index = src_index.data(); const IndexT* d_index = dst_index.data(); @@ -121,10 +126,8 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx, GraphSendRecvCpuLoop>( src_dims[0], index_size, s_index, d_index, x, out, pool_type); } else if (pool_type == "MEAN") { - if (out_size <= 0) { - dst_count->Resize({src_dims[0]}); - } int64_t input_size = out_size <= 0 ? src_dims[0] : out_size; + dst_count->Resize({input_size}); ctx.template Alloc(dst_count); int* p_dst_count = dst_count->data(); memset(p_dst_count, 0, input_size * sizeof(int)); diff --git a/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu index 748257f8fa732..ea4c0c819d34f 100644 --- a/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu @@ -40,11 +40,16 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, const auto& src_dims = x.dims(); int64_t memset_size = 1; if (out_size <= 0) { - out->Resize(src_dims); for (int i = 0; i < src_dims.size(); ++i) { memset_size *= src_dims[i]; } } else { + // set out dim following out_size. + std::vector dims_ = phi::vectorize(src_dims); + if (dims_.size() > 0) { + dims_[0] = out_size; + } + out->Resize(phi::make_ddim(dims_)); memset_size = out_size; for (int i = 1; i < src_dims.size(); ++i) { memset_size *= src_dims[i]; @@ -92,7 +97,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0]; int64_t grid_tmp = (n + block - 1) / block; int64_t grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx; - int64_t input_size = src_dims[0]; + int64_t input_size = out_size <= 0 ? src_dims[0] : out_size; if (pool_type == "SUM") { GraphSendRecvSumCUDAFunctor functor; GraphSendRecvCUDAKernel> @@ -104,9 +109,6 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, <<>>( p_src, s_index, d_index, p_output, index_size, slice_size, functor); - if (out_size > 0) { - input_size = out_size; - } int64_t grid_max_tmp = (input_size * slice_size + block - 1) / block; int64_t grid_max = grid_max_tmp < max_grid_dimx ? grid_max_tmp : max_grid_dimx; @@ -118,9 +120,6 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, <<>>( p_src, s_index, d_index, p_output, index_size, slice_size, functor); - if (out_size > 0) { - input_size = out_size; - } int64_t grid_min_tmp = (input_size * slice_size + block - 1) / block; int64_t grid_min = grid_min_tmp < max_grid_dimx ? grid_min_tmp : max_grid_dimx; @@ -131,12 +130,9 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, GraphSendRecvCUDAKernel> <<>>( p_src, s_index, d_index, p_output, index_size, slice_size, functor); - + dst_count->Resize({input_size}); ctx.template Alloc(dst_count); - int32_t* p_dst_count = dst_count->data(); - if (out_size > 0) { - input_size = out_size; - } + int* p_dst_count = dst_count->data(); #ifdef PADDLE_WITH_HIP hipMemset(p_dst_count, 0, input_size * sizeof(int)); diff --git a/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py b/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py index bcfaf523777be..c0fdb134f16d6 100644 --- a/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py +++ b/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py @@ -54,19 +54,14 @@ def setUp(self): def test_check_output(self): self.check_output(check_eager=True) - # self.check_output(check_eager=False) - # def test_check_grad(self): - # self.check_grad(['X'], - # 'Out', - # user_defined_grads=[self.gradient], - # check_eager=True) - # self.check_grad(['X'], 'Out', - # user_defined_grads=[self.gradient], - # check_eager=False) + def test_check_grad(self): + self.check_grad(['X'], + 'Out', + user_defined_grads=[self.gradient], + check_eager=True) -""" class TestGraphSendRecvMinOp(OpTest): def setUp(self): @@ -181,7 +176,6 @@ def compute_graph_send_recv_for_sum_mean(inputs, attributes): count[s_id] += 1 return results, count -""" def compute_graph_send_recv_for_min_max(inputs, attributes): @@ -229,7 +223,6 @@ def compute_graph_send_recv_for_min_max(inputs, attributes): return results, gradient / results.size -""" class API_GraphSendRecvOpTest(unittest.TestCase): def test_static(self): @@ -372,4 +365,3 @@ def test_api_eager_dygraph(self): if __name__ == '__main__': unittest.main() -""" diff --git a/python/paddle/incubate/operators/graph_send_recv.py b/python/paddle/incubate/operators/graph_send_recv.py index e209e994b7581..2f458d1322fa5 100644 --- a/python/paddle/incubate/operators/graph_send_recv.py +++ b/python/paddle/incubate/operators/graph_send_recv.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np + from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.framework import _non_static_mode, _in_legacy_dygraph, in_dygraph_mode from paddle.fluid.framework import Variable @@ -112,27 +114,17 @@ def graph_send_recv(x, # TODO(daisiming): Should we add judgement for out_size: max(dst_index) + 1. - if out_size is None: - if _in_legacy_dygraph(): - out, tmp = _C_ops.graph_send_recv(x, src_index, dst_index, - 'pool_type', pool_type.upper()) - return out - if in_dygraph_mode(): - return _C_ops.final_state_graph_send_recv(x, src_index, dst_index, - pool_type.upper(), [0]) - else: - if _in_legacy_dygraph(): - out_size = convert_out_size_to_list(out_size) - out, tmp = _C_ops.graph_send_recv(x, src_index, - dst_index, 'pool_type', - pool_type.upper(), 'out_size', - out_size) - return out - if in_dygraph_mode(): - out_size = convert_out_size_to_list(out_size) - return _C_ops.final_state_graph_send_recv(x, src_index, dst_index, - pool_type.upper(), - out_size) + if _in_legacy_dygraph(): + out_size = convert_out_size_to_list(out_size) + out, tmp = _C_ops.graph_send_recv(x, src_index, + dst_index, None, 'pool_type', + pool_type.upper(), 'out_size', + out_size) + return out + if in_dygraph_mode(): + out_size = convert_out_size_to_list(out_size) + return _C_ops.final_state_graph_send_recv(x, src_index, dst_index, + pool_type.upper(), out_size) check_variable_and_dtype(x, "X", ("float32", "float64", "int32", "int64"), "graph_send_recv") @@ -140,9 +132,10 @@ def graph_send_recv(x, "graph_send_recv") check_variable_and_dtype(dst_index, "Dst_index", ("int32", "int64"), "graph_send_recv") - check_type(out_size, 'out_size', - (int, np.int32, np.int64, Variable, list, tuple), - 'graph_send_recv') + if out_size: + check_type(out_size, 'out_size', + (int, np.int32, np.int64, Variable, list, tuple), + 'graph_send_recv') if isinstance(out_size, Variable): check_dtype(out_size.dtype, 'out_size', ['int32', 'int64'], 'graph_send_recv') @@ -174,7 +167,9 @@ def convert_out_size_to_list(out_size): Convert out_size(int, np.int32, np.int64, Variable) to list in imperative mode. """ - if isinstance(out_size, (int, np.int32, np.int64)): + if out_size is None: + out_size = [0] + elif isinstance(out_size, (int, np.int32, np.int64)): out_size = [out_size] else: out_size = [out_size.numpy().astype(int)[0]] @@ -186,9 +181,10 @@ def get_out_size_tensor_inputs(inputs, attrs, out_size, op_type): Convert out_size(int, np.int32, np.int64, Variable) to inputs and attrs in static mode. """ - if isinstance(out_size, (int, np.int32, np.int64)): - out_size = [out_size] - attrs['out_size'] = out_size + if out_size is None: + attrs['out_size'] = [0] + elif isinstance(out_size, (int, np.int32, np.int64)): + attrs['out_size'] = [out_size] elif isinstance(out_size, Variable): out_size.stop_gradient = True check_dtype(out_size.dtype, 'out_size', ['int32', 'int64'], From 2c8d2cb1e95b25c991ee84f0330e2aa9bbc1aaed Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Fri, 22 Jul 2022 09:41:13 +0000 Subject: [PATCH 3/9] add unittest for out_size tensor --- .../unittests/test_graph_send_recv_op.py | 152 ++++++++++-------- 1 file changed, 86 insertions(+), 66 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py b/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py index c0fdb134f16d6..ebb60eae8514c 100644 --- a/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py +++ b/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py @@ -268,31 +268,26 @@ def test_static(self): {}\n{}, check diff!".format(np_res, ret_res)) def test_dygraph(self): - device = paddle.CPUPlace() - with paddle.fluid.dygraph.guard(device): - x = paddle.to_tensor(np.array([[0, 2, 3], [1, 4, 5], [2, 6, 7]]), - dtype="float32") - src_index = paddle.to_tensor(np.array([0, 1, 2, 0]), dtype="int32") - dst_index = paddle.to_tensor(np.array([1, 2, 1, 0]), dtype="int32") - res_sum = paddle.incubate.graph_send_recv(x, src_index, dst_index, - "sum") - res_mean = paddle.incubate.graph_send_recv(x, src_index, dst_index, - "mean") - res_max = paddle.incubate.graph_send_recv(x, src_index, dst_index, - "max") - res_min = paddle.incubate.graph_send_recv(x, src_index, dst_index, - "min") + paddle.disable_static() + x = paddle.to_tensor(np.array([[0, 2, 3], [1, 4, 5], [2, 6, 7]]), + dtype="float32") + src_index = paddle.to_tensor(np.array([0, 1, 2, 0]), dtype="int32") + dst_index = paddle.to_tensor(np.array([1, 2, 1, 0]), dtype="int32") + res_sum = paddle.incubate.graph_send_recv(x, src_index, dst_index, + "sum") + res_mean = paddle.incubate.graph_send_recv(x, src_index, dst_index, + "mean") + res_max = paddle.incubate.graph_send_recv(x, src_index, dst_index, + "max") + res_min = paddle.incubate.graph_send_recv(x, src_index, dst_index, + "min") - np_sum = np.array([[0, 2, 3], [2, 8, 10], [1, 4, 5]], - dtype="float32") - np_mean = np.array([[0, 2, 3], [1, 4, 5], [1, 4, 5]], - dtype="float32") - np_max = np.array([[0, 2, 3], [2, 6, 7], [1, 4, 5]], - dtype="float32") - np_min = np.array([[0, 2, 3], [0, 2, 3], [1, 4, 5]], - dtype="float32") + np_sum = np.array([[0, 2, 3], [2, 8, 10], [1, 4, 5]], dtype="float32") + np_mean = np.array([[0, 2, 3], [1, 4, 5], [1, 4, 5]], dtype="float32") + np_max = np.array([[0, 2, 3], [2, 6, 7], [1, 4, 5]], dtype="float32") + np_min = np.array([[0, 2, 3], [0, 2, 3], [1, 4, 5]], dtype="float32") - ret = [res_sum, res_mean, res_max, res_min] + ret = [res_sum, res_mean, res_max, res_min] for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret): self.assertTrue( @@ -300,30 +295,26 @@ def test_dygraph(self): {}\n{}, check diff!".format(np_res, ret_res)) def test_int32_input(self): - device = paddle.CPUPlace() - with paddle.fluid.dygraph.guard(device): - x = paddle.to_tensor(np.array([[0, 2, 3], [1, 4, 5], [2, 6, 6]]), - dtype="int32") - src_index = paddle.to_tensor(np.array([0, 1, 2, 0, 1]), - dtype="int32") - dst_index = paddle.to_tensor(np.array([1, 2, 1, 0, 1]), - dtype="int32") - res_sum = paddle.incubate.graph_send_recv(x, src_index, dst_index, - "sum") - res_mean = paddle.incubate.graph_send_recv(x, src_index, dst_index, - "mean") - res_max = paddle.incubate.graph_send_recv(x, src_index, dst_index, - "max") - res_min = paddle.incubate.graph_send_recv(x, src_index, dst_index, - "min") + paddle.disable_static() + x = paddle.to_tensor(np.array([[0, 2, 3], [1, 4, 5], [2, 6, 6]]), + dtype="int32") + src_index = paddle.to_tensor(np.array([0, 1, 2, 0, 1]), dtype="int32") + dst_index = paddle.to_tensor(np.array([1, 2, 1, 0, 1]), dtype="int32") + res_sum = paddle.incubate.graph_send_recv(x, src_index, dst_index, + "sum") + res_mean = paddle.incubate.graph_send_recv(x, src_index, dst_index, + "mean") + res_max = paddle.incubate.graph_send_recv(x, src_index, dst_index, + "max") + res_min = paddle.incubate.graph_send_recv(x, src_index, dst_index, + "min") - np_sum = np.array([[0, 2, 3], [3, 12, 14], [1, 4, 5]], - dtype="int32") - np_mean = np.array([[0, 2, 3], [1, 4, 4], [1, 4, 5]], dtype="int32") - np_max = np.array([[0, 2, 3], [2, 6, 6], [1, 4, 5]], dtype="int32") - np_min = np.array([[0, 2, 3], [0, 2, 3], [1, 4, 5]], dtype="int32") + np_sum = np.array([[0, 2, 3], [3, 12, 14], [1, 4, 5]], dtype="int32") + np_mean = np.array([[0, 2, 3], [1, 4, 4], [1, 4, 5]], dtype="int32") + np_max = np.array([[0, 2, 3], [2, 6, 6], [1, 4, 5]], dtype="int32") + np_min = np.array([[0, 2, 3], [0, 2, 3], [1, 4, 5]], dtype="int32") - ret = [res_sum, res_mean, res_max, res_min] + ret = [res_sum, res_mean, res_max, res_min] for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret): self.assertTrue( @@ -331,31 +322,60 @@ def test_int32_input(self): {}\n{}, check diff!".format(np_res, ret_res)) def test_set_outsize_gpu(self): - if paddle.fluid.core.is_compiled_with_cuda(): - x = paddle.to_tensor(np.array([[0, 2, 3], [1, 4, 5], [2, 6, 6]]), - dtype="float32") - src_index = paddle.to_tensor(np.array([0, 0, 1]), dtype="int32") - dst_index = paddle.to_tensor(np.array([0, 1, 1]), dtype="int32") - res = paddle.incubate.graph_send_recv(x, src_index, dst_index, - "sum") - out_size = paddle.max(dst_index) + 1 - res_set_outsize = paddle.incubate.graph_send_recv( - x, src_index, dst_index, "sum", out_size) - - np_res = np.array([[0, 2, 3], [1, 6, 8], [0, 0, 0]], - dtype="float32") - np_res_set_outsize = np.array([[0, 2, 3], [1, 6, 8]], - dtype="float32") - - self.assertTrue( - np.allclose(np_res, res, atol=1e-6), "two value is\ + paddle.disable_static() + x = paddle.to_tensor(np.array([[0, 2, 3], [1, 4, 5], [2, 6, 6]]), + dtype="float32") + src_index = paddle.to_tensor(np.array([0, 0, 1]), dtype="int32") + dst_index = paddle.to_tensor(np.array([0, 1, 1]), dtype="int32") + res = paddle.incubate.graph_send_recv(x, src_index, dst_index, "sum") + out_size = paddle.max(dst_index) + 1 + res_set_outsize = paddle.incubate.graph_send_recv( + x, src_index, dst_index, "sum", out_size) + + np_res = np.array([[0, 2, 3], [1, 6, 8], [0, 0, 0]], dtype="float32") + np_res_set_outsize = np.array([[0, 2, 3], [1, 6, 8]], dtype="float32") + + self.assertTrue( + np.allclose(np_res, res, atol=1e-6), "two value is\ {}\n{}, check diff!".format(np_res, res)) - self.assertTrue( - np.allclose(np_res_set_outsize, res_set_outsize, atol=1e-6), - "two value is\ + self.assertTrue( + np.allclose(np_res_set_outsize, res_set_outsize, atol=1e-6), + "two value is\ {}\n{}, check diff!".format(np_res_set_outsize, res_set_outsize)) + def test_out_size_tensor_static(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(name="x", shape=[3, 3], dtype="float32") + src_index = paddle.static.data(name="src", shape=[3], dtype="int32") + dst_index = paddle.static.data(name="dst", shape=[3], dtype="int32") + out_size = paddle.static.data(name="out_size", + shape=[1], + dtype="int32") + + res_sum = paddle.incubate.graph_send_recv(x, src_index, dst_index, + "sum", out_size) + + exe = paddle.static.Executor(paddle.CPUPlace()) + data1 = np.array([[0, 2, 3], [1, 4, 5], [2, 6, 6]], dtype='float32') + data2 = np.array([0, 0, 1], dtype="int32") + data3 = np.array([0, 1, 1], dtype="int32") + data4 = np.array([2], dtype="int32") + + np_sum = np.array([[0, 2, 3], [1, 6, 8]], dtype="float32") + + ret = exe.run(feed={ + 'x': data1, + 'src': data2, + 'dst': data3, + 'out_size': data4, + }, + fetch_list=[res_sum]) + self.assertTrue( + np.allclose(np_sum, ret[0], atol=1e-6), "two value is\ + {}\n{}, check diff!".format(np_sum, ret[0])) + def test_api_eager_dygraph(self): with _test_eager_guard(): self.test_dygraph() From 21318863c82e3e9af2d89ea55eaffb00277338dc Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Mon, 25 Jul 2022 07:14:35 +0000 Subject: [PATCH 4/9] add deprecated for paddle.incubate.graph_send_recv, add paddle.geometric.send_u_recv and unittests --- python/paddle/__init__.py | 1 + .../unittests/test_graph_send_recv_op.py | 156 ++++++++++++++++- python/paddle/geometric/__init__.py | 19 ++ .../geometric/message_passing/__init__.py | 15 ++ .../geometric/message_passing/send_u_recv.py | 162 ++++++++++++++++++ .../paddle/geometric/message_passing/utils.py | 52 ++++++ .../incubate/operators/graph_send_recv.py | 14 +- python/setup.py.in | 2 + 8 files changed, 414 insertions(+), 7 deletions(-) create mode 100644 python/paddle/geometric/__init__.py create mode 100644 python/paddle/geometric/message_passing/__init__.py create mode 100644 python/paddle/geometric/message_passing/send_u_recv.py create mode 100644 python/paddle/geometric/message_passing/utils.py diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 6e47f4f9eab43..a518fe4c84db7 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -78,6 +78,7 @@ import paddle.reader # noqa: F401 import paddle.static # noqa: F401 import paddle.vision # noqa: F401 +import paddle.geometric # noqa: F401 from .tensor.attribute import is_complex # noqa: F401 from .tensor.attribute import is_integer # noqa: F401 diff --git a/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py b/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py index ebb60eae8514c..73c1525519066 100644 --- a/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py +++ b/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py @@ -28,8 +28,8 @@ def graph_send_recv_wrapper(x, pool_type="sum", out_size=None, name=None): - return paddle.incubate.graph_send_recv(x, src_index, dst_index, - pool_type.lower(), out_size, name) + return paddle.geometric.send_u_recv(x, src_index, dst_index, + pool_type.lower(), out_size, name) class TestGraphSendRecvMaxOp(OpTest): @@ -383,5 +383,157 @@ def test_api_eager_dygraph(self): self.test_set_outsize_gpu() +class API_GeometricSendURecvTest(unittest.TestCase): + + def test_static(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(name="x", shape=[3, 3], dtype="float32") + src_index = paddle.static.data(name="src", shape=[4], dtype="int32") + dst_index = paddle.static.data(name="dst", shape=[4], dtype="int32") + + res_sum = paddle.geometric.send_u_recv(x, src_index, dst_index, + "sum") + res_mean = paddle.geometric.send_u_recv(x, src_index, dst_index, + "mean") + res_max = paddle.geometric.send_u_recv(x, src_index, dst_index, + "max") + res_min = paddle.geometric.send_u_recv(x, src_index, dst_index, + "min") + + exe = paddle.static.Executor(paddle.CPUPlace()) + data1 = np.array([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype='float32') + data2 = np.array([0, 1, 2, 0], dtype="int32") + data3 = np.array([1, 2, 1, 0], dtype="int32") + + np_sum = np.array([[0, 2, 3], [2, 8, 10], [1, 4, 5]], + dtype="float32") + np_mean = np.array([[0, 2, 3], [1, 4, 5], [1, 4, 5]], + dtype="float32") + np_max = np.array([[0, 2, 3], [2, 6, 7], [1, 4, 5]], + dtype="float32") + np_min = np.array([[0, 2, 3], [0, 2, 3], [1, 4, 5]], + dtype="float32") + + ret = exe.run(feed={ + 'x': data1, + 'src': data2, + 'dst': data3 + }, + fetch_list=[res_sum, res_mean, res_max, res_min]) + + for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret): + self.assertTrue( + np.allclose(np_res, ret_res, atol=1e-6), "two value is\ + {}\n{}, check diff!".format(np_res, ret_res)) + + def test_dygraph(self): + paddle.disable_static() + x = paddle.to_tensor(np.array([[0, 2, 3], [1, 4, 5], [2, 6, 7]]), + dtype="float32") + src_index = paddle.to_tensor(np.array([0, 1, 2, 0]), dtype="int32") + dst_index = paddle.to_tensor(np.array([1, 2, 1, 0]), dtype="int32") + res_sum = paddle.geometric.send_u_recv(x, src_index, dst_index, "sum") + res_mean = paddle.geometric.send_u_recv(x, src_index, dst_index, "mean") + res_max = paddle.geometric.send_u_recv(x, src_index, dst_index, "max") + res_min = paddle.geometric.send_u_recv(x, src_index, dst_index, "min") + + np_sum = np.array([[0, 2, 3], [2, 8, 10], [1, 4, 5]], dtype="float32") + np_mean = np.array([[0, 2, 3], [1, 4, 5], [1, 4, 5]], dtype="float32") + np_max = np.array([[0, 2, 3], [2, 6, 7], [1, 4, 5]], dtype="float32") + np_min = np.array([[0, 2, 3], [0, 2, 3], [1, 4, 5]], dtype="float32") + + ret = [res_sum, res_mean, res_max, res_min] + + for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret): + self.assertTrue( + np.allclose(np_res, ret_res, atol=1e-6), "two value is\ + {}\n{}, check diff!".format(np_res, ret_res)) + + def test_int32_input(self): + paddle.disable_static() + x = paddle.to_tensor(np.array([[0, 2, 3], [1, 4, 5], [2, 6, 6]]), + dtype="int32") + src_index = paddle.to_tensor(np.array([0, 1, 2, 0, 1]), dtype="int32") + dst_index = paddle.to_tensor(np.array([1, 2, 1, 0, 1]), dtype="int32") + res_sum = paddle.geometric.send_u_recv(x, src_index, dst_index, "sum") + res_mean = paddle.geometric.send_u_recv(x, src_index, dst_index, "mean") + res_max = paddle.geometric.send_u_recv(x, src_index, dst_index, "max") + res_min = paddle.geometric.send_u_recv(x, src_index, dst_index, "min") + + np_sum = np.array([[0, 2, 3], [3, 12, 14], [1, 4, 5]], dtype="int32") + np_mean = np.array([[0, 2, 3], [1, 4, 4], [1, 4, 5]], dtype="int32") + np_max = np.array([[0, 2, 3], [2, 6, 6], [1, 4, 5]], dtype="int32") + np_min = np.array([[0, 2, 3], [0, 2, 3], [1, 4, 5]], dtype="int32") + + ret = [res_sum, res_mean, res_max, res_min] + + for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret): + self.assertTrue( + np.allclose(np_res, ret_res, atol=1e-6), "two value is\ + {}\n{}, check diff!".format(np_res, ret_res)) + + def test_set_outsize_gpu(self): + paddle.disable_static() + x = paddle.to_tensor(np.array([[0, 2, 3], [1, 4, 5], [2, 6, 6]]), + dtype="float32") + src_index = paddle.to_tensor(np.array([0, 0, 1]), dtype="int32") + dst_index = paddle.to_tensor(np.array([0, 1, 1]), dtype="int32") + res = paddle.geometric.send_u_recv(x, src_index, dst_index, "sum") + out_size = paddle.max(dst_index) + 1 + res_set_outsize = paddle.geometric.send_u_recv(x, src_index, dst_index, + "sum", out_size) + + np_res = np.array([[0, 2, 3], [1, 6, 8], [0, 0, 0]], dtype="float32") + np_res_set_outsize = np.array([[0, 2, 3], [1, 6, 8]], dtype="float32") + + self.assertTrue( + np.allclose(np_res, res, atol=1e-6), "two value is\ + {}\n{}, check diff!".format(np_res, res)) + self.assertTrue( + np.allclose(np_res_set_outsize, res_set_outsize, atol=1e-6), + "two value is\ + {}\n{}, check diff!".format(np_res_set_outsize, + res_set_outsize)) + + def test_out_size_tensor_static(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(name="x", shape=[3, 3], dtype="float32") + src_index = paddle.static.data(name="src", shape=[3], dtype="int32") + dst_index = paddle.static.data(name="dst", shape=[3], dtype="int32") + out_size = paddle.static.data(name="out_size", + shape=[1], + dtype="int32") + + res_sum = paddle.geometric.send_u_recv(x, src_index, dst_index, + "sum", out_size) + + exe = paddle.static.Executor(paddle.CPUPlace()) + data1 = np.array([[0, 2, 3], [1, 4, 5], [2, 6, 6]], dtype='float32') + data2 = np.array([0, 0, 1], dtype="int32") + data3 = np.array([0, 1, 1], dtype="int32") + data4 = np.array([2], dtype="int32") + + np_sum = np.array([[0, 2, 3], [1, 6, 8]], dtype="float32") + + ret = exe.run(feed={ + 'x': data1, + 'src': data2, + 'dst': data3, + 'out_size': data4, + }, + fetch_list=[res_sum]) + self.assertTrue( + np.allclose(np_sum, ret[0], atol=1e-6), "two value is\ + {}\n{}, check diff!".format(np_sum, ret[0])) + + def test_api_eager_dygraph(self): + with _test_eager_guard(): + self.test_dygraph() + self.test_int32_input() + self.test_set_outsize_gpu() + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/geometric/__init__.py b/python/paddle/geometric/__init__.py new file mode 100644 index 0000000000000..9e59062a7cc6a --- /dev/null +++ b/python/paddle/geometric/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .message_passing import send_u_recv # noqa: F401 + +__all__ = [ + 'send_u_recv', +] diff --git a/python/paddle/geometric/message_passing/__init__.py b/python/paddle/geometric/message_passing/__init__.py new file mode 100644 index 0000000000000..ede1803357738 --- /dev/null +++ b/python/paddle/geometric/message_passing/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .send_u_recv import send_u_recv # noqa: F401 diff --git a/python/paddle/geometric/message_passing/send_u_recv.py b/python/paddle/geometric/message_passing/send_u_recv.py new file mode 100644 index 0000000000000..4b7a02fc69347 --- /dev/null +++ b/python/paddle/geometric/message_passing/send_u_recv.py @@ -0,0 +1,162 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.framework import _non_static_mode, _in_legacy_dygraph, in_dygraph_mode +from paddle.fluid.framework import Variable +from paddle.fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype +from paddle import _C_ops + +from .utils import convert_out_size_to_list, get_out_size_tensor_inputs + + +def send_u_recv(x, + src_index, + dst_index, + pool_type="sum", + out_size=None, + name=None): + """ + + Graph Learning message passing api. + + This api is mainly used in Graph Learning domain, and the main purpose is to reduce intermediate memory + consumption in the process of message passing. Take `x` as the input tensor, we first use `src_index` + to gather the corresponding data, and then use `dst_index` to update the corresponding position of output tensor + in different pooling types, like sum, mean, max, or min. Besides, we can use `out_size` to set necessary output shape. + + .. code-block:: text + + Given: + + X = [[0, 2, 3], + [1, 4, 5], + [2, 6, 7]] + + src_index = [0, 1, 2, 0] + + dst_index = [1, 2, 1, 0] + + pool_type = "sum" + + out_size = None + + Then: + + Out = [[0, 2, 3], + [2, 8, 10], + [1, 4, 5]] + + Args: + x (Tensor): The input tensor, and the available data type is float32, float64, int32, int64. + src_index (Tensor): An 1-D tensor, and the available data type is int32, int64. + dst_index (Tensor): An 1-D tensor, and should have the same shape as `src_index`. + The available data type is int32, int64. + pool_type (str): Different pooling types, including `sum`, `mean`, `max`, `min`. + Default value is `sum`. + out_size (int|Tensor|None): We can set `out_size` to get necessary output shape. If not set or + out_size is smaller or equal to 0, then this input will not be used. + Otherwise, `out_size` should be equal with or larger than + max(dst_index) + 1. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + out (Tensor): The output tensor, should have the same shape and same dtype as input tensor `x`. + If `out_size` is set correctly, then it should have the same shape as `x` except + the 0th dimension. + + Examples: + + .. code-block:: python + + import paddle + + x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32") + indexes = paddle.to_tensor([[0, 1], [1, 2], [2, 1], [0, 0]], dtype="int32") + src_index = indexes[:, 0] + dst_index = indexes[:, 1] + out = paddle.geometric.send_u_recv(x, src_index, dst_index, pool_type="sum") + # Outputs: [[0., 2., 3.], [2., 8., 10.], [1., 4., 5.]] + + x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32") + indexes = paddle.to_tensor([[0, 1], [2, 1], [0, 0]], dtype="int32") + src_index = indexes[:, 0] + dst_index = indexes[:, 1] + out_size = paddle.max(dst_index) + 1 + out = paddle.geometric.send_u_recv(x, src_index, dst_index, pool_type="sum", out_size=out_size) + # Outputs: [[0., 2., 3.], [[2., 8., 10.]]] + + x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32") + indexes = paddle.to_tensor([[0, 1], [2, 1], [0, 0]], dtype="int32") + src_index = indexes[:, 0] + dst_index = indexes[:, 1] + out = paddle.geometric.send_u_recv(x, src_index, dst_index, pool_type="sum") + # Outputs: [[0., 2., 3.], [2., 8., 10.], [0., 0., 0.]] + + """ + + if pool_type not in ["sum", "mean", "max", "min"]: + raise ValueError( + "pool_type should be `sum`, `mean`, `max` or `min`, but received %s" + % pool_type) + + # TODO(daisiming): Should we add judgement for out_size: max(dst_index) + 1. + + if _in_legacy_dygraph(): + out_size = convert_out_size_to_list(out_size) + out, tmp = _C_ops.graph_send_recv(x, src_index, + dst_index, None, 'pool_type', + pool_type.upper(), 'out_size', + out_size) + return out + if in_dygraph_mode(): + out_size = convert_out_size_to_list(out_size) + return _C_ops.final_state_graph_send_recv(x, src_index, dst_index, + pool_type.upper(), out_size) + + check_variable_and_dtype(x, "X", ("float32", "float64", "int32", "int64"), + "send_u_recv") + check_variable_and_dtype(src_index, "Src_index", ("int32", "int64"), + "send_u_recv") + check_variable_and_dtype(dst_index, "Dst_index", ("int32", "int64"), + "send_u_recv") + if out_size: + check_type(out_size, 'out_size', (int, np.int32, np.int64, Variable), + 'send_u_recv') + if isinstance(out_size, Variable): + check_dtype(out_size.dtype, 'out_size', ['int32', 'int64'], + 'send_u_recv') + + helper = LayerHelper("send_u_recv", **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + dst_count = helper.create_variable_for_type_inference(dtype="int32", + stop_gradient=True) + + inputs = {"X": x, "Src_index": src_index, "Dst_index": dst_index} + attrs = {"pool_type": pool_type.upper()} + get_out_size_tensor_inputs(inputs=inputs, + attrs=attrs, + out_size=out_size, + op_type='send_u_recv') + + helper.append_op(type="graph_send_recv", + inputs=inputs, + outputs={ + "Out": out, + "Dst_count": dst_count + }, + attrs=attrs) + return out diff --git a/python/paddle/geometric/message_passing/utils.py b/python/paddle/geometric/message_passing/utils.py new file mode 100644 index 0000000000000..cd33d5d032d09 --- /dev/null +++ b/python/paddle/geometric/message_passing/utils.py @@ -0,0 +1,52 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from paddle.fluid.framework import Variable +from paddle.fluid.data_feeder import check_dtype, convert_dtype +from paddle.fluid.layers.tensor import cast + + +def convert_out_size_to_list(out_size): + """ + Convert out_size(int, np.int32, np.int64, Variable) to list + in imperative mode. + """ + if out_size is None: + out_size = [0] + elif isinstance(out_size, (int, np.int32, np.int64)): + out_size = [out_size] + else: + out_size = [out_size.numpy().astype(int)[0]] + return out_size + + +def get_out_size_tensor_inputs(inputs, attrs, out_size, op_type): + """ + Convert out_size(int, np.int32, np.int64, Variable) to inputs + and attrs in static mode. + """ + if out_size is None: + attrs['out_size'] = [0] + elif isinstance(out_size, (int, np.int32, np.int64)): + attrs['out_size'] = [out_size] + elif isinstance(out_size, Variable): + out_size.stop_gradient = True + check_dtype(out_size.dtype, 'out_size', ['int32', 'int64'], 'op_type', + '(When type of out_size in' + op_type + ' is Variable.)') + if (convert_dtype(out_size.dtype) == 'int64'): + out_size = cast(out_size, 'int32') + inputs["OutSizeTensor"] = out_size + else: + raise TypeError("Out_size only supports Variable or int.") diff --git a/python/paddle/incubate/operators/graph_send_recv.py b/python/paddle/incubate/operators/graph_send_recv.py index 2f458d1322fa5..3b7469f7415ef 100644 --- a/python/paddle/incubate/operators/graph_send_recv.py +++ b/python/paddle/incubate/operators/graph_send_recv.py @@ -13,14 +13,20 @@ # limitations under the License. import numpy as np - from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.framework import _non_static_mode, _in_legacy_dygraph, in_dygraph_mode from paddle.fluid.framework import Variable from paddle.fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype +from paddle.fluid.layers.tensor import cast from paddle import _C_ops +import paddle.utils.deprecated as deprecated +@deprecated( + since="2.4.0", + update_to="paddle.geometric.send_u_recv", + level=1, + reason="graph_send_recv in paddle.incubate will be removed in future") def graph_send_recv(x, src_index, dst_index, @@ -133,8 +139,7 @@ def graph_send_recv(x, check_variable_and_dtype(dst_index, "Dst_index", ("int32", "int64"), "graph_send_recv") if out_size: - check_type(out_size, 'out_size', - (int, np.int32, np.int64, Variable, list, tuple), + check_type(out_size, 'out_size', (int, np.int32, np.int64, Variable), 'graph_send_recv') if isinstance(out_size, Variable): check_dtype(out_size.dtype, 'out_size', ['int32', 'int64'], @@ -187,8 +192,7 @@ def get_out_size_tensor_inputs(inputs, attrs, out_size, op_type): attrs['out_size'] = [out_size] elif isinstance(out_size, Variable): out_size.stop_gradient = True - check_dtype(out_size.dtype, 'out_size', ['int32', 'int64'], - 'fill_constant', + check_dtype(out_size.dtype, 'out_size', ['int32', 'int64'], op_type, '(When type of out_size in' + op_type + ' is Variable.)') if (convert_dtype(out_size.dtype) == 'int64'): out_size = cast(out_size, 'int32') diff --git a/python/setup.py.in b/python/setup.py.in index c02ef7f017fca..6c71f07b8f5a7 100755 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -400,6 +400,8 @@ packages=['paddle', 'paddle.device.cuda', 'paddle.version', 'paddle.profiler', + 'paddle.geometric', + 'paddle.geometric.message_passing', ] with open('@PADDLE_SOURCE_DIR@/python/requirements.txt') as f: From a47b28114f3281382fd8e92251057b9eab390d83 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Tue, 26 Jul 2022 06:25:25 +0000 Subject: [PATCH 5/9] fix lowest bug --- paddle/phi/kernels/gpu/graph_send_recv_funcs.h | 2 +- paddle/phi/kernels/gpu/graph_send_recv_kernel.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/gpu/graph_send_recv_funcs.h b/paddle/phi/kernels/gpu/graph_send_recv_funcs.h index a93603ae18f1c..4be92ae18629c 100644 --- a/paddle/phi/kernels/gpu/graph_send_recv_funcs.h +++ b/paddle/phi/kernels/gpu/graph_send_recv_funcs.h @@ -81,7 +81,7 @@ __global__ void InputResetMaxCUDAKernel(T* output, size_t input_size, size_t slice_size) { CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) { - if (*(output + i) == std::numeric_limits::min()) { + if (*(output + i) == std::numeric_limits::lowest()) { *(output + i) = 0; } } diff --git a/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu index ea4c0c819d34f..ea46c05a17907 100644 --- a/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu @@ -69,7 +69,7 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, thrust::fill(thrust::device, p_output_ptr, p_output_ptr + memset_size, - std::numeric_limits::min()); + std::numeric_limits::lowest()); } else if (pool_type == "MIN") { thrust::device_ptr p_output_ptr(p_output); thrust::fill(thrust::device, From ad1692cce4c3fc09c9b1a0680fe2feec17b1a9e1 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Tue, 2 Aug 2022 03:04:36 +0000 Subject: [PATCH 6/9] fix according review comment --- paddle/fluid/operators/graph_send_recv_op.cc | 2 +- paddle/fluid/pybind/op_function_generator.h | 2 +- paddle/phi/infermeta/ternary.cc | 6 ++++-- paddle/phi/kernels/cpu/graph_send_recv_kernel.cc | 3 ++- paddle/phi/kernels/gpu/graph_send_recv_kernel.cu | 5 +++-- paddle/phi/ops/compat/graph_send_recv_sig.cc | 4 ++-- python/paddle/geometric/message_passing/utils.py | 2 +- python/paddle/incubate/operators/graph_send_recv.py | 2 +- 8 files changed, 15 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/graph_send_recv_op.cc b/paddle/fluid/operators/graph_send_recv_op.cc index 8e515a02b40a0..e9ba861c3b88b 100644 --- a/paddle/fluid/operators/graph_send_recv_op.cc +++ b/paddle/fluid/operators/graph_send_recv_op.cc @@ -58,7 +58,7 @@ class GraphSendRecvOpMaker : public framework::OpProtoAndCheckerMaker { "The input tensor with data type float32, float64, int32, int64."); AddInput("Src_index", "The source index tensor."); AddInput("Dst_index", "The destination index tensor."); - AddInput("OutSizeTensor", + AddInput("Out_size", "(Tensor, optional). The 0th dimension of the output." "It has a higher priority than Attr(out_size).") .AsDispensable(); diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index 04e54c994053d..2963a642f02f8 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -225,7 +225,7 @@ std::map> op_ins_map = { "Bias3", "Mean3", "Var3"}}, - {"graph_send_recv", {"X", "Src_index", "Dst_index", "OutSizeTensor"}}, + {"graph_send_recv", {"X", "Src_index", "Dst_index", "Out_size"}}, }; // NOTE(zhiqiu): Like op_ins_map. diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 5e3c09379d6fb..d655b559afbc2 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -345,11 +345,13 @@ void GraphSendRecvInferMeta(const MetaTensor& x, "Src_index and Dst_index should have the same shape.")); auto dims = x.dims(); - out->set_dims(dims); + std::vector dims_ = phi::vectorize(dims); + dims_[0] = -1; + out->set_dims(phi::make_ddim(dims_)); out->set_dtype(x.dtype()); if (pool_type == "MEAN") { - dst_count->set_dims({dims[0]}); + dst_count->set_dims({-1}); dst_count->set_dtype(DataType::INT32); } } diff --git a/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc b/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc index e1dbeb9042116..d4b9c8c60e3f8 100644 --- a/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc +++ b/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc @@ -91,11 +91,12 @@ void GraphSendRecvOpKernelLaunchHelper(const Context& ctx, const auto& src_dims = x.dims(); int64_t memset_size = 1; if (out_size <= 0) { + out->Resize(src_dims); for (int i = 0; i < src_dims.size(); ++i) { memset_size *= src_dims[i]; } } else { - // set out dim following out_size. + // Set out dim following out_size. std::vector dims_ = phi::vectorize(src_dims); if (dims_.size() > 0) { dims_[0] = out_size; diff --git a/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu index ea46c05a17907..4dc2794d9c949 100644 --- a/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu @@ -40,12 +40,13 @@ void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, const auto& src_dims = x.dims(); int64_t memset_size = 1; if (out_size <= 0) { + out->Resize(src_dims); for (int i = 0; i < src_dims.size(); ++i) { memset_size *= src_dims[i]; } } else { - // set out dim following out_size. - std::vector dims_ = phi::vectorize(src_dims); + // Set out dim following out_size. + std::vector dims_ = phi::vectorize(out->dims()); if (dims_.size() > 0) { dims_[0] = out_size; } diff --git a/paddle/phi/ops/compat/graph_send_recv_sig.cc b/paddle/phi/ops/compat/graph_send_recv_sig.cc index 236f5c4aa9c3c..c8c15619d5d39 100644 --- a/paddle/phi/ops/compat/graph_send_recv_sig.cc +++ b/paddle/phi/ops/compat/graph_send_recv_sig.cc @@ -18,10 +18,10 @@ namespace phi { KernelSignature GraphSendRecvOpArgumentMapping( const ArgumentMappingContext& ctx) { - if (ctx.HasInput("OutSizeTensor")) { + if (ctx.HasInput("Out_size")) { return KernelSignature("graph_send_recv", {"X", "Src_index", "Dst_index"}, - {"pool_type", "OutSizeTensor"}, + {"pool_type", "Out_size"}, {"Out", "Dst_count"}); } else { return KernelSignature("graph_send_recv", diff --git a/python/paddle/geometric/message_passing/utils.py b/python/paddle/geometric/message_passing/utils.py index cd33d5d032d09..3614f829daf52 100644 --- a/python/paddle/geometric/message_passing/utils.py +++ b/python/paddle/geometric/message_passing/utils.py @@ -47,6 +47,6 @@ def get_out_size_tensor_inputs(inputs, attrs, out_size, op_type): '(When type of out_size in' + op_type + ' is Variable.)') if (convert_dtype(out_size.dtype) == 'int64'): out_size = cast(out_size, 'int32') - inputs["OutSizeTensor"] = out_size + inputs["Out_size"] = out_size else: raise TypeError("Out_size only supports Variable or int.") diff --git a/python/paddle/incubate/operators/graph_send_recv.py b/python/paddle/incubate/operators/graph_send_recv.py index 3b7469f7415ef..132a6d4657ca1 100644 --- a/python/paddle/incubate/operators/graph_send_recv.py +++ b/python/paddle/incubate/operators/graph_send_recv.py @@ -196,6 +196,6 @@ def get_out_size_tensor_inputs(inputs, attrs, out_size, op_type): '(When type of out_size in' + op_type + ' is Variable.)') if (convert_dtype(out_size.dtype) == 'int64'): out_size = cast(out_size, 'int32') - inputs["OutSizeTensor"] = out_size + inputs["Out_size"] = out_size else: raise TypeError("Out_size only supports Variable or int.") From f8fac9abf405764637d8b7378d533f62bf9149c7 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Tue, 2 Aug 2022 06:24:53 +0000 Subject: [PATCH 7/9] add default value in yaml --- paddle/phi/api/yaml/legacy_api.yaml | 2 +- paddle/phi/api/yaml/legacy_backward.yaml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index df193f0842747..40661dc6396ed 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -886,7 +886,7 @@ backward : gelu_grad - api : graph_send_recv - args : (Tensor x, Tensor src_index, Tensor dst_index, str pool_type, IntArray out_size) + args : (Tensor x, Tensor src_index, Tensor dst_index, str pool_type = "SUM", IntArray out_size = {0}) output : Tensor(out), Tensor(dst_count) infer_meta : func : GraphSendRecvInferMeta diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 6e32177f97e43..e38c85c85da71 100644 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -828,8 +828,8 @@ func : gelu_grad - backward_api : graph_send_recv_grad - forward : graph_send_recv (Tensor x, Tensor src_index, Tensor dst_index, str pool_type, IntArray out_size) -> Tensor(out), Tensor(dst_count) - args : (Tensor x, Tensor src_index, Tensor dst_index, Tensor out, Tensor dst_count, Tensor out_grad, str pool_type) + forward : graph_send_recv (Tensor x, Tensor src_index, Tensor dst_index, str pool_type = "SUM", IntArray out_size = {0}) -> Tensor(out), Tensor(dst_count) + args : (Tensor x, Tensor src_index, Tensor dst_index, Tensor out, Tensor dst_count, Tensor out_grad, str pool_type = "SUM") output : Tensor(x_grad) infer_meta : func : GeneralUnaryGradInferMeta From 978f13966557f7603e521189363c3ee5f9f27f8f Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Tue, 2 Aug 2022 06:29:18 +0000 Subject: [PATCH 8/9] change api file name --- python/paddle/geometric/message_passing/__init__.py | 2 +- .../geometric/message_passing/{send_u_recv.py => send_recv.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename python/paddle/geometric/message_passing/{send_u_recv.py => send_recv.py} (100%) diff --git a/python/paddle/geometric/message_passing/__init__.py b/python/paddle/geometric/message_passing/__init__.py index ede1803357738..d9580e658650a 100644 --- a/python/paddle/geometric/message_passing/__init__.py +++ b/python/paddle/geometric/message_passing/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .send_u_recv import send_u_recv # noqa: F401 +from .send_recv import send_u_recv # noqa: F401 diff --git a/python/paddle/geometric/message_passing/send_u_recv.py b/python/paddle/geometric/message_passing/send_recv.py similarity index 100% rename from python/paddle/geometric/message_passing/send_u_recv.py rename to python/paddle/geometric/message_passing/send_recv.py From bd2da5237fe71d53887346841e97aefe3a7f593f Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Mon, 8 Aug 2022 13:08:44 +0000 Subject: [PATCH 9/9] change name --- python/paddle/geometric/message_passing/send_recv.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/paddle/geometric/message_passing/send_recv.py b/python/paddle/geometric/message_passing/send_recv.py index 4b7a02fc69347..87379730a2a60 100644 --- a/python/paddle/geometric/message_passing/send_recv.py +++ b/python/paddle/geometric/message_passing/send_recv.py @@ -128,17 +128,17 @@ def send_u_recv(x, pool_type.upper(), out_size) check_variable_and_dtype(x, "X", ("float32", "float64", "int32", "int64"), - "send_u_recv") + "graph_send_recv") check_variable_and_dtype(src_index, "Src_index", ("int32", "int64"), - "send_u_recv") + "graph_send_recv") check_variable_and_dtype(dst_index, "Dst_index", ("int32", "int64"), - "send_u_recv") + "graph_send_recv") if out_size: check_type(out_size, 'out_size', (int, np.int32, np.int64, Variable), - 'send_u_recv') + 'graph_send_recv') if isinstance(out_size, Variable): check_dtype(out_size.dtype, 'out_size', ['int32', 'int64'], - 'send_u_recv') + 'graph_send_recv') helper = LayerHelper("send_u_recv", **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) @@ -150,7 +150,7 @@ def send_u_recv(x, get_out_size_tensor_inputs(inputs=inputs, attrs=attrs, out_size=out_size, - op_type='send_u_recv') + op_type='graph_send_recv') helper.append_op(type="graph_send_recv", inputs=inputs,