From f961f9b31156e9b6796e2dcf707d68f91a291ed8 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Mon, 6 Jun 2022 08:37:11 +0000 Subject: [PATCH] fix broadcast error, add kernel sig, register e_grad, change unit test --- .../fluid/operators/graph_send_e_recv_op.cc | 1 + .../impl/graph_send_e_recv_kernel_impl.h | 2 +- .../phi/ops/compat/graph_send_e_recv_sig.cc | 42 ++++++++ .../unittests/test_graph_send_e_recv_op.py | 100 +++++++++--------- 4 files changed, 95 insertions(+), 50 deletions(-) create mode 100644 paddle/phi/ops/compat/graph_send_e_recv_sig.cc diff --git a/paddle/fluid/operators/graph_send_e_recv_op.cc b/paddle/fluid/operators/graph_send_e_recv_op.cc index 8b1e0f87c9f79..994153f50af81 100644 --- a/paddle/fluid/operators/graph_send_e_recv_op.cc +++ b/paddle/fluid/operators/graph_send_e_recv_op.cc @@ -124,6 +124,7 @@ class GraphSendERecvGradOpMaker : public framework::SingleGradOpMaker { op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetOutput(framework::GradVarName("E"), this->InputGrad("E")); op->SetAttrMap(this->Attrs()); } }; diff --git a/paddle/phi/kernels/impl/graph_send_e_recv_kernel_impl.h b/paddle/phi/kernels/impl/graph_send_e_recv_kernel_impl.h index 5c8d5340693ae..23bcae5e077fa 100644 --- a/paddle/phi/kernels/impl/graph_send_e_recv_kernel_impl.h +++ b/paddle/phi/kernels/impl/graph_send_e_recv_kernel_impl.h @@ -66,7 +66,7 @@ inline BroadCastInfo CalcBCastInfo(const phi::DDim& l_dims, (l_dims.size() - 1 - i < 1) ? 1 : l_dims[l_dims.size() - 1 - i]; const int dr = (r_dims.size() - 1 - i < 1) ? 1 : r_dims[r_dims.size() - 1 - i]; - for (int j = 0; j < std::max(dl, dr); j++) { + for (int j = 1; j < std::max(dl, dr); j++) { for (int k = 0; k < out_len; k++) { binfo.l_offset.emplace_back(binfo.l_offset[k] + j * (j < dl) * stride_l); diff --git a/paddle/phi/ops/compat/graph_send_e_recv_sig.cc b/paddle/phi/ops/compat/graph_send_e_recv_sig.cc new file mode 100644 index 0000000000000..a89708cf35736 --- /dev/null +++ b/paddle/phi/ops/compat/graph_send_e_recv_sig.cc @@ -0,0 +1,42 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature GraphSendERecvOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("graph_send_e_recv",
+                         {"X", "E", "Src_index", "Dst_index"},
+                         {"compute_type", "pool_type", "out_size"},
+                         {"Out", "Dst_count"});
+}
+
+KernelSignature GraphSendERecvGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "graph_send_e_recv_grad",
+      {"X", "E", "Src_index", "Dst_index", "Out", "Dst_count", "Out@GRAD"},
+      {"compute_type", "pool_type"},
+      {"X@GRAD", "E@GRAD"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(graph_send_e_recv,
+                           phi::GraphSendERecvOpArgumentMapping);
+
+PD_REGISTER_ARG_MAPPING_FN(graph_send_e_recv_grad,
+                           phi::GraphSendERecvGradOpArgumentMapping);
diff --git a/python/paddle/fluid/tests/unittests/test_graph_send_e_recv_op.py b/python/paddle/fluid/tests/unittests/test_graph_send_e_recv_op.py
index 0a1a995fdf39b..f2abf65e05361 100644
--- a/python/paddle/fluid/tests/unittests/test_graph_send_e_recv_op.py
+++ b/python/paddle/fluid/tests/unittests/test_graph_send_e_recv_op.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import unittest
-
 import numpy as np
 import paddle
 import paddle.fluid as fluid
@@ -22,6 +21,40 @@
 from op_test import OpTest
 
 
+def get_broadcast_shape(shp1, shp2):
+    pad_shp1, pad_shp2 = list(shp1), list(shp2)
+    if len(shp1) > len(shp2):
+        pad_shp2 = [1, ] * (len(shp1) - len(shp2)) + pad_shp2
+    elif len(shp1) < len(shp2):
+        pad_shp1 = [1, ] * (len(shp2) - len(shp1)) + pad_shp1
+    for d1, d2 in zip(pad_shp1, pad_shp2):
+        if d1 != d2 and d1 != 1 and d2 != 1:
+            raise ValueError("Invalid broadcast shapes: %s, %s" % (shp1, shp2))
+    rst = [max(d1, d2) for d1, d2 in zip(pad_shp1, pad_shp2)]
+    return rst
+
+
+def compute_graph_send_e_recv_for_sum(inputs, attributes):
+    x = inputs['X']
+    e = inputs['E']
+    src_index = inputs['Src_index']
+    dst_index = inputs['Dst_index']
+    compute_type = attributes['compute_type']
+
+    gather_x = x[src_index]
+    out_shp = [x.shape[0], ] + get_broadcast_shape(x.shape[1:], e.shape[1:])
+    results = np.zeros(out_shp, dtype=x.dtype)
+
+    # Calculate forward output
+    if compute_type == 'ADD':
+        x_compute_e = gather_x + e
+    elif compute_type == 'MUL':
+        x_compute_e = gather_x * e
+    for index, s_id in enumerate(dst_index):
+        results[s_id, :] += x_compute_e[index, :]
+    return results
+
+
 class TestGraphSendERecvSumOp(OpTest):
     def setUp(self):
         paddle.enable_static()
@@ -50,28 +83,31 @@ def set_config(self):
 
     def test_check_output(self):
         self.check_output()
 
+    def test_check_grad(self):
+        self.check_grad(['X', 'E'], 'Out')
+
 
-def TestSumCase1(TestGraphSendERecvSumOp):
+class TestSumCase1(TestGraphSendERecvSumOp):
     def set_config(self):
-        self.x = np.random.random((10, 20)).astype("float64")
-        self.e = np.random.random((15, 1)).astype("float64")
-        index = np.random.randint(0, 10, (15, 2)).astype(np.int64)
+        self.x = np.random.random((100, 20)).astype("float64")
+        self.e = np.random.random((150, 1)).astype("float64")
+        index = np.random.randint(0, 100, (150, 2)).astype(np.int64)
         self.src_index = index[:, 0]
         self.dst_index = index[:, 1]
         self.compute_type = 'ADD'
 
 
-def TestSumCase2(TestGraphSendERecvSumOp):
+class TestSumCase2(TestGraphSendERecvSumOp):
     def set_config(self):
-        self.x = np.random.random((10, 1)).astype("float64")
+        self.x = np.random.random((100, 1)).astype("float64")
         self.e = np.random.random((15, 20)).astype("float64")
-        index = np.random.randint(0, 10, (15, 2)).astype(np.int64)
+        index = np.random.randint(0, 100, (15, 2)).astype(np.int64)
         self.src_index = index[:, 0]
self.dst_index = index[:, 1] self.compute_type = 'ADD' -def TestSumCase3(TestGraphSendERecvSumOp): +class TestSumCase3(TestGraphSendERecvSumOp): def set_config(self): self.x = np.random.random((10, 20)).astype("float64") self.e = np.random.random((15, 20)).astype("float64") @@ -81,55 +117,21 @@ def set_config(self): self.compute_type = 'MUL' -def TestSumCase4(TestGraphSendERecvSumOp): +class TestSumCase4(TestGraphSendERecvSumOp): def set_config(self): self.x = np.random.random((10, 20)).astype("float64") - self.e = np.random.random((15, 1)).astype("float64") - index = np.random.randint(0, 10, (15, 2)).astype(np.int64) + self.e = np.random.random((150, 1)).astype("float64") + index = np.random.randint(0, 10, (150, 2)).astype(np.int64) self.src_index = index[:, 0] self.dst_index = index[:, 1] self.compute_type = 'MUL' -def TestSumCase5(TestGraphSendERecvSumOp): +class TestSumCase5(TestGraphSendERecvSumOp): def set_config(self): - self.x = np.random.random((10, 1)).astype("float64") + self.x = np.random.random((100, 1)).astype("float64") self.e = np.random.random((15, 20)).astype("float64") - index = np.random.randint(0, 10, (15, 2)).astype(np.int64) + index = np.random.randint(0, 100, (15, 2)).astype(np.int64) self.src_index = index[:, 0] self.dst_index = index[:, 1] self.compute_type = 'MUL' - - -def get_broadcast_shape(shp1, shp2): - pad_shp1, pad_shp2 = shp1, shp2 - if len(shp1) > len(shp2): - pad_shp2 = [1, ] * (len(shp1) - len(shp2)) + shp2 - elif len(shp1) < len(shp2): - pad_shp1 = [1, ] * (len(shp2) - len(shp1)) + shp1 - for d1, d2 in zip(pad_shp1, pad_shp2): - if d1 != d2 and d1 != 1 and d2 != 1: - raise ValueError - rst = [max(d1, d2) for d1, d2 in zip(pad_shp1, pad_shp2)] - return rst - - -def compute_graph_send_e_recv_for_sum(inputs, attributes): - x = inputs['X'] - e = inputs['E'] - src_index = inputs['Src_index'] - dst_index = inputs['Dst_index'] - compute_type = attributes['compute_type'] - - gather_x = x[src_index] - out_shp = [x.shape[0], ] + get_broadcast_shape(x.shape[1:], e.shape[1:]) - results = np.zeros(out_shp, dtype=x.dtype) - - # Calculate forward output - if compute_type == 'ADD': - x_compute_e = gather_x + e - elif compute_type == 'MUL': - x_compute_e = gather_x * e - for index, s_id in enumerate(dst_index): - results[s_id, :] += x_compute_e[index, :] - return results
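
A note on the broadcast fix in graph_send_e_recv_kernel_impl.h: inside CalcBCastInfo the inner loop reads binfo.l_offset[k] for k < out_len, so the offset vectors must already hold the base (j == 0) offsets before the loop runs; starting the loop at j = 0 therefore re-appended those base entries on every dimension and corrupted the broadcast indexing. Below is a minimal NumPy sketch of that offset construction, not part of the patch; calc_bcast_offsets is a hypothetical stand-in that only mirrors the logic of the hunk shown above.

import numpy as np

def calc_bcast_offsets(l_shape, r_shape, start_j=1):
    # Flattened per-element offsets for two broadcastable shapes,
    # walking dimensions from innermost to outermost, as the kernel
    # does for the feature dims of X and E.
    l_offset, r_offset = [0], [0]
    out_len, stride_l, stride_r = 1, 1, 1
    for i in range(max(len(l_shape), len(r_shape))):
        dl = l_shape[-1 - i] if i < len(l_shape) else 1
        dr = r_shape[-1 - i] if i < len(r_shape) else 1
        # j == 0 would duplicate the offsets already stored in the
        # first out_len slots; the fixed kernel starts at j == 1.
        for j in range(start_j, max(dl, dr)):
            for k in range(out_len):
                l_offset.append(l_offset[k] + j * (j < dl) * stride_l)
                r_offset.append(r_offset[k] + j * (j < dr) * stride_r)
        out_len *= max(dl, dr)
        stride_l *= dl
        stride_r *= dr
    return out_len, l_offset, r_offset

# Broadcasting (2, 1) against (1, 3): each side gets exactly 6 offsets,
# and gathering through them reproduces the dense broadcast result.
out_len, lo, ro = calc_bcast_offsets((2, 1), (1, 3))
l = 10.0 * np.arange(2).reshape(2, 1)
r = np.arange(3.0).reshape(1, 3)
assert out_len == len(lo) == len(ro) == 6
assert (l.ravel()[lo] + r.ravel()[ro] == (l + r).ravel()).all()

# With the old bound (j starting at 0) the arrays collect duplicates of
# the base offsets and no longer match out_len.
out_len, lo, ro = calc_bcast_offsets((2, 1), (1, 3), start_j=0)
assert len(lo) == 10 and out_len == 6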