Commit

fix broadcast error, add kernel sig, register e_grad, change unit test
DesmonDay committed Jun 6, 2022
1 parent f1ea92f commit f961f9b
Showing 4 changed files with 95 additions and 50 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/operators/graph_send_e_recv_op.cc
@@ -124,6 +124,7 @@ class GraphSendERecvGradOpMaker : public framework::SingleGradOpMaker<T> {

    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetOutput(framework::GradVarName("E"), this->InputGrad("E"));
    op->SetAttrMap(this->Attrs());
  }
};
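For context: registering `E@GRAD` is what makes the backward op emit a gradient for the edge tensor `E` alongside the existing `X@GRAD`. A minimal NumPy sketch of the reference gradients, assuming `compute_type='ADD'`, `pool_type='SUM'`, and no broadcasting (the helper name is illustrative, not a Paddle API):

```python
import numpy as np

def reference_grads_add_sum(x, e, src_index, dst_index, out_grad):
    # Forward: out[d] = sum over edges i with dst_index[i] == d
    #          of (x[src_index[i]] + e[i]).
    # Each edge therefore passes out_grad[dst] once to its source
    # node's gradient and once to its own edge-feature gradient.
    x_grad = np.zeros_like(x)
    e_grad = np.zeros_like(e)
    for i, (s, d) in enumerate(zip(src_index, dst_index)):
        x_grad[s] += out_grad[d]
        e_grad[i] = out_grad[d]
    return x_grad, e_grad
```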
2 changes: 1 addition & 1 deletion paddle/phi/kernels/impl/graph_send_e_recv_kernel_impl.h
@@ -66,7 +66,7 @@ inline BroadCastInfo CalcBCastInfo(const phi::DDim& l_dims,
          (l_dims.size() - 1 - i < 1) ? 1 : l_dims[l_dims.size() - 1 - i];
      const int dr =
          (r_dims.size() - 1 - i < 1) ? 1 : r_dims[r_dims.size() - 1 - i];
-      for (int j = 0; j < std::max(dl, dr); j++) {
+      for (int j = 1; j < std::max(dl, dr); j++) {
        for (int k = 0; k < out_len; k++) {
          binfo.l_offset.emplace_back(binfo.l_offset[k] +
                                      j * (j < dl) * stride_l);
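To see why the loop must start at `j = 1` (the "fix broadcast error" part of this commit): the offset tables are seeded with the base entry for offset 0, and each pass over an axis extends the first `out_len` entries once per extra index `j` along that axis. Starting at `j = 0` re-emits those base entries, duplicating offsets and corrupting the table. A standalone Python sketch of the same computation, feature dims only, aligned from the back as in the surrounding C++:

```python
def calc_bcast_offsets(l_shape, r_shape):
    # Mirrors CalcBCastInfo's offset loop; context only, not Paddle API.
    l_offset, r_offset = [0], [0]
    stride_l = stride_r = 1
    out_len = 1
    for i in range(max(len(l_shape), len(r_shape))):
        dl = l_shape[-1 - i] if i < len(l_shape) else 1
        dr = r_shape[-1 - i] if i < len(r_shape) else 1
        for j in range(1, max(dl, dr)):  # j == 0 entries already exist
            for k in range(out_len):
                l_offset.append(l_offset[k] + j * (j < dl) * stride_l)
                r_offset.append(r_offset[k] + j * (j < dr) * stride_r)
        out_len *= max(dl, dr)
        stride_l *= dl
        stride_r *= dr
    return l_offset, r_offset, out_len

# Broadcasting feature shapes (2, 3) against (1, 3) gives
# l_offset == [0, 1, 2, 3, 4, 5], r_offset == [0, 1, 2, 0, 1, 2],
# out_len == 6; with range(0, ...) the base entries would be
# appended a second time.
```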
42 changes: 42 additions & 0 deletions paddle/phi/ops/compat/graph_send_e_recv_sig.cc
@@ -0,0 +1,42 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature GraphSendERecvOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("graph_send_e_recv",
+                         {"X", "E", "Src_index", "Dst_index"},
+                         {"compute_type", "pool_type", "out_size"},
+                         {"Out", "Dst_count"});
+}
+
+KernelSignature GraphSendERecvGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "graph_send_e_recv_grad",
+      {"X", "E", "Src_index", "Dst_index", "Out", "Dst_count", "Out@GRAD"},
+      {"compute_type", "pool_type"},
+      {"X@GRAD", "E@GRAD"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(graph_send_e_recv,
+                           phi::GraphSendERecvOpArgumentMapping);
+
+PD_REGISTER_ARG_MAPPING_FN(graph_send_e_recv_grad,
+                           phi::GraphSendERecvGradOpArgumentMapping);
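These signatures bind the legacy operator to the phi kernels: the `@GRAD` names line up with the `GradVarName` outputs registered in the op maker above, which is why `E@GRAD` now appears in both files. For orientation, a hedged sketch of the tensor shapes involved, inferred from the unit test below rather than specified by the op itself:

```python
import numpy as np

num_nodes, num_edges = 10, 15

# Shape conventions inferred from the unit test in this commit.
X = np.random.random((num_nodes, 20))   # per-node features
E = np.random.random((num_edges, 1))    # per-edge features, broadcastable
Src_index = np.random.randint(0, num_nodes, (num_edges,)).astype(np.int64)
Dst_index = np.random.randint(0, num_nodes, (num_edges,)).astype(np.int64)

# Forward produces Out with one row per node, plus Dst_count, which
# (an assumption here) holds the number of incoming edges per node,
# as needed for MEAN pooling.
```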
100 changes: 51 additions & 49 deletions python/paddle/fluid/tests/unittests/test_graph_send_e_recv_op.py
@@ -13,7 +13,6 @@
# limitations under the License.

import unittest
-
import numpy as np
import paddle
import paddle.fluid as fluid
@@ -22,6 +21,40 @@
from op_test import OpTest


+def get_broadcast_shape(shp1, shp2):
+    pad_shp1, pad_shp2 = shp1, shp2
+    if len(shp1) > len(shp2):
+        pad_shp2 = [1, ] * (len(shp1) - len(shp2)) + shp2
+    elif len(shp1) < len(shp2):
+        pad_shp1 = [1, ] * (len(shp2) - len(shp1)) + shp1
+    for d1, d2 in zip(pad_shp1, pad_shp2):
+        if d1 != d2 and d1 != 1 and d2 != 1:
+            raise ValueError
+    rst = [max(d1, d2) for d1, d2 in zip(pad_shp1, pad_shp2)]
+    return rst
+
+
+def compute_graph_send_e_recv_for_sum(inputs, attributes):
+    x = inputs['X']
+    e = inputs['E']
+    src_index = inputs['Src_index']
+    dst_index = inputs['Dst_index']
+    compute_type = attributes['compute_type']
+
+    gather_x = x[src_index]
+    out_shp = [x.shape[0], ] + get_broadcast_shape(x.shape[1:], e.shape[1:])
+    results = np.zeros(out_shp, dtype=x.dtype)
+
+    # Calculate forward output
+    if compute_type == 'ADD':
+        x_compute_e = gather_x + e
+    elif compute_type == 'MUL':
+        x_compute_e = gather_x * e
+    for index, s_id in enumerate(dst_index):
+        results[s_id, :] += x_compute_e[index, :]
+    return results
+
+
class TestGraphSendERecvSumOp(OpTest):
def setUp(self):
paddle.enable_static()
@@ -50,28 +83,31 @@ def set_config(self):
    def test_check_output(self):
        self.check_output()

+    def test_check_grad(self):
+        self.check_grad(['X', 'E'], 'Out')


-def TestSumCase1(TestGraphSendERecvSumOp):
+class TestSumCase1(TestGraphSendERecvSumOp):
    def set_config(self):
-        self.x = np.random.random((10, 20)).astype("float64")
-        self.e = np.random.random((15, 1)).astype("float64")
-        index = np.random.randint(0, 10, (15, 2)).astype(np.int64)
+        self.x = np.random.random((100, 20)).astype("float64")
+        self.e = np.random.random((150, 1)).astype("float64")
+        index = np.random.randint(0, 100, (150, 2)).astype(np.int64)
        self.src_index = index[:, 0]
        self.dst_index = index[:, 1]
        self.compute_type = 'ADD'


-def TestSumCase2(TestGraphSendERecvSumOp):
+class TestSumCase2(TestGraphSendERecvSumOp):
    def set_config(self):
-        self.x = np.random.random((10, 1)).astype("float64")
+        self.x = np.random.random((100, 1)).astype("float64")
        self.e = np.random.random((15, 20)).astype("float64")
-        index = np.random.randint(0, 10, (15, 2)).astype(np.int64)
+        index = np.random.randint(0, 100, (15, 2)).astype(np.int64)
        self.src_index = index[:, 0]
        self.dst_index = index[:, 1]
        self.compute_type = 'ADD'


-def TestSumCase3(TestGraphSendERecvSumOp):
+class TestSumCase3(TestGraphSendERecvSumOp):
    def set_config(self):
        self.x = np.random.random((10, 20)).astype("float64")
        self.e = np.random.random((15, 20)).astype("float64")
@@ -81,55 +117,21 @@ def set_config(self):
        self.compute_type = 'MUL'


-def TestSumCase4(TestGraphSendERecvSumOp):
+class TestSumCase4(TestGraphSendERecvSumOp):
    def set_config(self):
        self.x = np.random.random((10, 20)).astype("float64")
-        self.e = np.random.random((15, 1)).astype("float64")
-        index = np.random.randint(0, 10, (15, 2)).astype(np.int64)
+        self.e = np.random.random((150, 1)).astype("float64")
+        index = np.random.randint(0, 10, (150, 2)).astype(np.int64)
        self.src_index = index[:, 0]
        self.dst_index = index[:, 1]
        self.compute_type = 'MUL'


-def TestSumCase5(TestGraphSendERecvSumOp):
+class TestSumCase5(TestGraphSendERecvSumOp):
    def set_config(self):
-        self.x = np.random.random((10, 1)).astype("float64")
+        self.x = np.random.random((100, 1)).astype("float64")
        self.e = np.random.random((15, 20)).astype("float64")
-        index = np.random.randint(0, 10, (15, 2)).astype(np.int64)
+        index = np.random.randint(0, 100, (15, 2)).astype(np.int64)
        self.src_index = index[:, 0]
        self.dst_index = index[:, 1]
        self.compute_type = 'MUL'


-def get_broadcast_shape(shp1, shp2):
-    pad_shp1, pad_shp2 = shp1, shp2
-    if len(shp1) > len(shp2):
-        pad_shp2 = [1, ] * (len(shp1) - len(shp2)) + shp2
-    elif len(shp1) < len(shp2):
-        pad_shp1 = [1, ] * (len(shp2) - len(shp1)) + shp1
-    for d1, d2 in zip(pad_shp1, pad_shp2):
-        if d1 != d2 and d1 != 1 and d2 != 1:
-            raise ValueError
-    rst = [max(d1, d2) for d1, d2 in zip(pad_shp1, pad_shp2)]
-    return rst
-
-
-def compute_graph_send_e_recv_for_sum(inputs, attributes):
-    x = inputs['X']
-    e = inputs['E']
-    src_index = inputs['Src_index']
-    dst_index = inputs['Dst_index']
-    compute_type = attributes['compute_type']
-
-    gather_x = x[src_index]
-    out_shp = [x.shape[0], ] + get_broadcast_shape(x.shape[1:], e.shape[1:])
-    results = np.zeros(out_shp, dtype=x.dtype)
-
-    # Calculate forward output
-    if compute_type == 'ADD':
-        x_compute_e = gather_x + e
-    elif compute_type == 'MUL':
-        x_compute_e = gather_x * e
-    for index, s_id in enumerate(dst_index):
-        results[s_id, :] += x_compute_e[index, :]
-    return results
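Worth noting for readers of the test diff: the `def` → `class` changes are not cosmetic. `def TestSumCase1(TestGraphSendERecvSumOp):` declares a plain function whose parameter happens to be the base class, so unittest never collected or ran these cases; subclassing fixes that. A minimal illustration with hypothetical names:

```python
import unittest

class BaseCase(unittest.TestCase):
    def test_something(self):
        self.assertTrue(True)

# Never collected by unittest: this is a function, not a TestCase
# subclass, and its body never runs unless called explicitly.
def BrokenCase(BaseCase):
    pass

# Collected and run as a test case: a genuine subclass.
class FixedCase(BaseCase):
    pass
```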

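To make the reference helpers above concrete, here is a tiny hand-checked example (values invented for illustration; it assumes `compute_graph_send_e_recv_for_sum` from the diff above and `numpy as np` are in scope):

```python
x = np.array([[1., 2.], [3., 4.], [5., 6.]])  # 3 nodes, 2 features
e = np.array([[10.], [20.], [30.], [40.]])    # 4 edges, broadcastable dim
src_index = np.array([0, 1, 2, 0])
dst_index = np.array([1, 2, 1, 0])

out = compute_graph_send_e_recv_for_sum(
    {'X': x, 'E': e, 'Src_index': src_index, 'Dst_index': dst_index},
    {'compute_type': 'ADD'})

# Per edge, gather_x + e is [[11, 12], [23, 24], [35, 36], [41, 42]];
# scatter-added by dst_index this gives
# out == [[41., 42.], [46., 48.], [23., 24.]]
```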
1 comment on commit f961f9b

@paddle-bot-old (bot) commented on f961f9b, Jun 6, 2022

🕵️ CI failures summary

🔍 PR: #43174 Commit ID: f961f9b contains failed CI.

🔹 Failed: PR-CI-APPROVAL
🔹 Failed: PR-CI-Windows-Inference
🔹 Failed: PR-CI-NPU
🔹 Failed: PR-CI-ROCM-Compile
🔹 Failed: PR-CI-Windows-OPENBLAS
