Commit

review code

DesmonDay committed Aug 10, 2022
1 parent 666bd68 commit a480092
Showing 5 changed files with 54 additions and 56 deletions.
50 changes: 25 additions & 25 deletions paddle/phi/kernels/cpu/graph_send_ue_recv_grad_kernel.cc
@@ -356,53 +356,53 @@ void GraphSendUERecvGradOpKernelLaunchHelper(
const Context& ctx,
const DenseTensor& out_grad,
const DenseTensor& x,
-const DenseTensor& e,
+const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& compute_type,
const std::string& pool_type,
DenseTensor* x_grad,
-DenseTensor* e_grad,
+DenseTensor* y_grad,
const DenseTensor* dst_count = nullptr,
const DenseTensor* out = nullptr) {
const int& index_size = dst_index.dims()[0];

ctx.template Alloc<T>(x_grad);
T* x_grad_data = x_grad->data<T>();
-ctx.template Alloc<T>(e_grad);
-T* e_grad_data = e_grad->data<T>();
+ctx.template Alloc<T>(y_grad);
+T* y_grad_data = y_grad->data<T>();
const auto& x_dims = x.dims();
-const auto& e_dims = e.dims();
-int64_t memset_size_x = 1, memset_size_e = 1;
+const auto& y_dims = y.dims();
+int64_t memset_size_x = 1, memset_size_y = 1;
int64_t slice_size = 1;
for (int i = 0; i < x_dims.size(); i++) {
memset_size_x *= x_dims[i];
if (i > 0) slice_size *= x_dims[i];
}
-for (int i = 0; i < e_dims.size(); i++) {
-  memset_size_e *= e_dims[i];
+for (int i = 0; i < y_dims.size(); i++) {
+  memset_size_y *= y_dims[i];
}
const size_t& memset_bytes_x = memset_size_x * sizeof(T);
-const size_t& memset_bytes_e = memset_size_e * sizeof(T);
+const size_t& memset_bytes_y = memset_size_y * sizeof(T);
memset(x_grad_data, 0, memset_bytes_x);
-memset(e_grad_data, 0, memset_bytes_e);
+memset(y_grad_data, 0, memset_bytes_y);

if (index_size == 0) return;

const T* out_grad_data = out_grad.data<T>();
const T* x_data = x.data<T>();
-const T* e_data = e.data<T>();
+const T* y_data = y.data<T>();
const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index.data<IndexT>();

if (pool_type == "SUM" || pool_type == "MEAN") {
CalculateXGrad<Context, T, IndexT>(ctx,
out_grad_data,
x_data,
-e_data,
+y_data,
out_grad.dims(),
x_dims,
-e_dims,
+y_dims,
d_index,
s_index,
compute_type,
@@ -415,37 +415,37 @@ void GraphSendUERecvGradOpKernelLaunchHelper(
out);
CalculateEGrad<T, IndexT>(out_grad_data,
x_data,
-e_data,
+y_data,
x_dims,
-e_dims,
+y_dims,
s_index,
d_index,
compute_type,
pool_type,
index_size,
-e_grad_data,
+y_grad_data,
dst_count);
} else if (pool_type == "MIN" || pool_type == "MAX") {
CalculateXEGradForMinMax<T, IndexT>(out_grad_data,
x_data,
-e_data,
+y_data,
x_dims,
-e_dims,
+y_dims,
d_index,
s_index,
compute_type,
pool_type,
index_size,
x_grad_data,
-e_grad_data,
+y_grad_data,
out);
}
}

template <typename T, typename Context>
void GraphSendUERecvGradKernel(const Context& ctx,
const DenseTensor& x,
-const DenseTensor& e,
+const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const paddle::optional<DenseTensor>& out,
@@ -454,34 +454,34 @@ void GraphSendUERecvGradKernel(const Context& ctx,
const std::string& compute_type,
const std::string& pool_type,
DenseTensor* x_grad,
-DenseTensor* e_grad) {
+DenseTensor* y_grad) {
auto index_type = src_index.dtype();
if (index_type == phi::DataType::INT32) {
GraphSendUERecvGradOpKernelLaunchHelper<Context, T, int32_t>(
ctx,
out_grad,
x,
-e,
+y,
src_index,
dst_index,
compute_type,
pool_type,
x_grad,
-e_grad,
+y_grad,
dst_count.get_ptr(),
out.get_ptr());
} else if (index_type == phi::DataType::INT64) {
GraphSendUERecvGradOpKernelLaunchHelper<Context, T, int64_t>(
ctx,
out_grad,
x,
-e,
+y,
src_index,
dst_index,
compute_type,
pool_type,
x_grad,
-e_grad,
+y_grad,
dst_count.get_ptr(),
out.get_ptr());
}
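To make the intent of the gradient path concrete, here is a minimal numpy sketch (illustrative only, not the Paddle implementation) of the SUM-pool, "ADD"-compute backward pass that `CalculateXGrad` and `CalculateEGrad` cover: the upstream gradient at each destination row flows back unchanged to the source row of `x` and to the matching per-edge entry of `y`. The helper name and the per-edge layout of `y` are assumptions for the sketch.

```python
import numpy as np

# Sketch of the SUM-pool, "ADD"-compute backward pass (hypothetical helper,
# not Paddle code). Forward: out[d] = sum over edges (s, d) of (x[s] + y[i]).
def send_ue_recv_add_sum_grad(out_grad, src_index, dst_index, x_shape, y_shape):
    x_grad = np.zeros(x_shape)
    y_grad = np.zeros(y_shape)
    for i, (s, d) in enumerate(zip(src_index, dst_index)):
        x_grad[s] += out_grad[d]  # d out[d] / d x[s] = 1 for "ADD"
        y_grad[i] += out_grad[d]  # d out[d] / d y[i] = 1 for "ADD"
    return x_grad, y_grad
```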
42 changes: 21 additions & 21 deletions paddle/phi/kernels/cpu/graph_send_ue_recv_kernel.cc
@@ -29,7 +29,7 @@ namespace phi {
template <typename T, typename IndexT, typename ComputeFunctor>
void GraphSendUERecvSumCpuKernel(const BroadCastInfo& bcast,
const T* x_data,
-const T* e_data,
+const T* y_data,
const IndexT* src_indices,
const IndexT* dst_indices,
T* output,
@@ -43,11 +43,11 @@ void GraphSendUERecvSumCpuKernel(const BroadCastInfo& bcast,
IndexT dst = dst_indices[i];
T* out_off = output + dst * bcast.out_len;
const T* x_off = x_data + src * bcast.l_len;
-const T* e_off = e_data + i * bcast.r_len;
+const T* y_off = y_data + i * bcast.r_len;
for (int64_t j = 0; j < bcast.out_len; j++) {
int64_t x_add = bcast.use_bcast ? bcast.l_offset[j] : j;
-int64_t e_add = bcast.use_bcast ? bcast.r_offset[j] : j;
-T val = cfunctor(x_off[x_add], e_off[e_add]);
+int64_t y_add = bcast.use_bcast ? bcast.r_offset[j] : j;
+T val = cfunctor(x_off[x_add], y_off[y_add]);
if (val != 0) {
#ifdef PADDLE_WITH_MKLML
#pragma omp atomic
@@ -64,7 +64,7 @@ template <typename T,
typename CmpFunctor>
void GraphSendUERecvMinMaxCpuKernel(const BroadCastInfo& bcast,
const T* x_data,
-const T* e_data,
+const T* y_data,
const IndexT* src_indices,
const IndexT* dst_indices,
T* output,
@@ -80,17 +80,17 @@ void GraphSendUERecvMinMaxCpuKernel(const BroadCastInfo& bcast,
IndexT dst = dst_indices[i];
T* out_off = output + dst * bcast.out_len;
const T* x_off = x_data + src * bcast.l_len;
-const T* e_off = e_data + i * bcast.r_len;
+const T* y_off = y_data + i * bcast.r_len;
bool in_set = existed_dst.find(dst) != existed_dst.end();
for (int64_t j = 0; j < bcast.out_len; j++) {
int64_t x_add = bcast.use_bcast ? bcast.l_offset[j] : j;
-int64_t e_add = bcast.use_bcast ? bcast.r_offset[j] : j;
-T val = cfunctor(x_off[x_add], e_off[e_add]);
+int64_t y_add = bcast.use_bcast ? bcast.r_offset[j] : j;
+T val = cfunctor(x_off[x_add], y_off[y_add]);
#ifdef PADDLE_WITH_MKLML
#pragma omp critical
#endif
if (!in_set) {
-out_off[j] += val;
+out_off[j] = val;
} else {
out_off[j] = pfunctor(out_off[j], val);
}
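The notable change in this hunk is `out_off[j] = val;`: on the first visit to a destination row the value is now assigned outright instead of accumulated into the zero-initialized output, which is the appropriate initialization for a min/max reduction. A small numpy sketch of the same first-visit pattern (hypothetical helper, illustrative only):

```python
import numpy as np

# First-visit pattern for a MIN reduction (hypothetical helper, not Paddle
# code): the first value reaching a row is assigned; later ones are combined.
def scatter_min(edge_vals, dst_index, num_rows):
    out = np.zeros((num_rows,) + edge_vals.shape[1:], dtype=edge_vals.dtype)
    seen = set()
    for i, d in enumerate(dst_index):
        if d not in seen:
            out[d] = edge_vals[i]  # assign on first visit ("="), never "+="
            seen.add(d)
        else:
            out[d] = np.minimum(out[d], edge_vals[i])
    return out
```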
@@ -107,7 +107,7 @@ void GraphSendUERecvMinMaxCpuKernel(const BroadCastInfo& bcast,
template <typename Context, typename T, typename IndexT>
void GraphSendUERecvOpKernelLaunchHelper(const Context& ctx,
const DenseTensor& x,
-const DenseTensor& e,
+const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& compute_type,
@@ -135,17 +135,17 @@ void GraphSendUERecvOpKernelLaunchHelper(const Context& ctx,
memset(out_data, 0, memset_bytes);

if (index_size == 0) return;
-const auto& bcast_info = phi::CalcBCastInfo(x.dims(), e.dims());
+const auto& bcast_info = phi::CalcBCastInfo(x.dims(), y.dims());
const T* x_data = x.data<T>();
-const T* e_data = e.data<T>();
+const T* y_data = y.data<T>();
const IndexT* s_index = src_index.data<IndexT>();
const IndexT* d_index = dst_index.data<IndexT>();
if (pool_type == "SUM" || pool_type == "MEAN") {
if (compute_type == "ADD") {
GraphAddFunctor<T> add_functor;
GraphSendUERecvSumCpuKernel<T, IndexT, GraphAddFunctor<T>>(bcast_info,
x_data,
-e_data,
+y_data,
s_index,
d_index,
out_data,
@@ -155,7 +155,7 @@ void GraphSendUERecvOpKernelLaunchHelper(const Context& ctx,
GraphMulFunctor<T> mul_functor;
GraphSendUERecvSumCpuKernel<T, IndexT, GraphMulFunctor<T>>(bcast_info,
x_data,
-e_data,
+y_data,
s_index,
d_index,
out_data,
@@ -187,7 +187,7 @@ void GraphSendUERecvOpKernelLaunchHelper(const Context& ctx,
GraphAddFunctor<T>,
GraphMinFunctor<T>>(bcast_info,
x_data,
-e_data,
+y_data,
s_index,
d_index,
out_data,
@@ -201,7 +201,7 @@ void GraphSendUERecvOpKernelLaunchHelper(const Context& ctx,
GraphMulFunctor<T>,
GraphMinFunctor<T>>(bcast_info,
x_data,
-e_data,
+y_data,
s_index,
d_index,
out_data,
@@ -218,7 +218,7 @@ void GraphSendUERecvOpKernelLaunchHelper(const Context& ctx,
GraphAddFunctor<T>,
GraphMaxFunctor<T>>(bcast_info,
x_data,
-e_data,
+y_data,
s_index,
d_index,
out_data,
@@ -232,7 +232,7 @@ void GraphSendUERecvOpKernelLaunchHelper(const Context& ctx,
GraphMulFunctor<T>,
GraphMaxFunctor<T>>(bcast_info,
x_data,
-e_data,
+y_data,
s_index,
d_index,
out_data,
@@ -246,7 +246,7 @@ void GraphSendUERecvOpKernelLaunchHelper(const Context& ctx,
template <typename T, typename Context>
void GraphSendUERecvKernel(const Context& ctx,
const DenseTensor& x,
-const DenseTensor& e,
+const DenseTensor& y,
const DenseTensor& src_index,
const DenseTensor& dst_index,
const std::string& compute_type,
@@ -259,7 +259,7 @@ void GraphSendUERecvKernel(const Context& ctx,
if (index_type == phi::DataType::INT32) {
GraphSendUERecvOpKernelLaunchHelper<Context, T, int32_t>(ctx,
x,
-e,
+y,
src_index,
dst_index,
compute_type,
@@ -270,7 +270,7 @@ void GraphSendUERecvKernel(const Context& ctx,
} else if (index_type == phi::DataType::INT64) {
GraphSendUERecvOpKernelLaunchHelper<Context, T, int64_t>(ctx,
x,
-e,
+y,
src_index,
dst_index,
compute_type,
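For reference, the gather-compute-scatter semantics these CPU kernels implement can be sketched in numpy (illustrative only, not the Paddle kernels; assumes `y` carries one broadcastable entry per edge and covers just the "ADD"/"MUL" compute types with "SUM" pooling):

```python
import numpy as np

# Reference semantics of send_ue_recv (illustrative sketch): gather rows of x
# by src_index, combine with the per-edge slice of y, then scatter-reduce into
# the row of out selected by dst_index.
def send_ue_recv_ref(x, y, src_index, dst_index, compute_type="ADD"):
    out = np.zeros_like(x)
    for i, (s, d) in enumerate(zip(src_index, dst_index)):
        msg = x[s] + y[i] if compute_type == "ADD" else x[s] * y[i]
        out[d] += msg  # "SUM" pooling; "MEAN" divides by dst_count afterwards
    return out
```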
3 changes: 0 additions & 3 deletions paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h
@@ -16,9 +16,6 @@
#include <thrust/device_vector.h>
#include <thrust/fill.h>

-#include <algorithm>
-#include <vector>
-
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/hostdevice.h"
@@ -1,5 +1,6 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
+# Copyright 2022 The DGL team.
+
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
12 changes: 6 additions & 6 deletions python/paddle/geometric/message_passing/send_recv.py
@@ -41,7 +41,7 @@ def send_u_recv(x,
Given:
-X = [[0, 2, 3],
+x = [[0, 2, 3],
[1, 4, 5],
[2, 6, 7]]
@@ -55,7 +55,7 @@
Then:
-Out = [[0, 2, 3],
+out = [[0, 2, 3],
[2, 8, 10],
[1, 4, 5]]
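These numbers can be checked by hand: with sum pooling, row `d` of the output is the sum of `x[s]` over every edge `(s, d)`, using the same indices that appear in the `send_ue_recv` example below.

```python
import numpy as np

# Hand-check of the send_u_recv docstring example (sum pooling).
x = np.array([[0, 2, 3], [1, 4, 5], [2, 6, 7]])
src_index = [0, 1, 2, 0]
dst_index = [1, 2, 1, 0]

out = np.zeros_like(x)
for s, d in zip(src_index, dst_index):
    out[d] += x[s]

assert (out == [[0, 2, 3], [2, 8, 10], [1, 4, 5]]).all()
```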
@@ -176,19 +176,19 @@ def send_ue_recv(x,
This api is mainly used in Graph Learning domain, and the main purpose is to reduce intermediate memory
consumption in the process of message passing. Take `x` as the input tensor, we first use `src_index`
-to gather the corresponding data, after computing with `y` in different compute types, then use `dst_index` to
+to gather the corresponding data, after computing with `y` in different compute types like add/sub/mul/div, then use `dst_index` to
update the corresponding position of output tensor in different pooling types, like sum, mean, max, or min.
Besides, we can use `out_size` to set necessary output shape.
.. code-block:: text
Given:
-X = [[0, 2, 3],
+x = [[0, 2, 3],
[1, 4, 5],
[2, 6, 7]]
-Y = [1, 1, 1]
+y = [1, 1, 1]
src_index = [0, 1, 2, 0]
@@ -202,7 +202,7 @@
Then:
-Out = [[1, 3, 4],
+out = [[1, 3, 4],
[4, 10, 12],
[2, 5, 6]]
Args:
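The updated `send_ue_recv` example likewise checks out: with compute type `add` and pool type `sum`, row `d` accumulates `x[s] + y` over every edge `(s, d)`, where `y` broadcasts over each gathered row as in the docstring.

```python
import numpy as np

# Hand-check of the send_ue_recv docstring example ("add" compute, "sum" pool).
x = np.array([[0, 2, 3], [1, 4, 5], [2, 6, 7]])
y = np.array([1, 1, 1])
src_index = [0, 1, 2, 0]
dst_index = [1, 2, 1, 0]

out = np.zeros_like(x)
for s, d in zip(src_index, dst_index):
    out[d] += x[s] + y

assert (out == [[1, 3, 4], [4, 10, 12], [2, 5, 6]]).all()
```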
