Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon No.21】为 Paddle 新增 paddle.incubate.sparse.transpose 稀疏 API #45849

Merged
merged 40 commits into from Sep 30, 2022
Merged
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
9f403dc
add sparse transpose op
zrr1999 Sep 5, 2022
2275d63
add sparse transpose op
zrr1999 Sep 7, 2022
c6c33ff
Merge branch 'develop' into sparse_transpose
zrr1999 Sep 7, 2022
f36cc0e
complete 2D CSR sparse transpose op
zrr1999 Sep 8, 2022
5158792
add more unitests
zrr1999 Sep 8, 2022
10f1682
fix bugs
zrr1999 Sep 8, 2022
1a7684a
complete 3D CSR sparse transpose op
zrr1999 Sep 9, 2022
1bcdfac
fix bugs
zrr1999 Sep 9, 2022
553a792
add grad
zrr1999 Sep 9, 2022
744667a
fix bugs
zrr1999 Sep 9, 2022
0135acc
Merge branch 'develop' into sparse_transpose
zrr1999 Sep 9, 2022
c6060b9
optimize unittest
zrr1999 Sep 9, 2022
4d0e052
fix bugs
zrr1999 Sep 9, 2022
129a2f0
add cuda impl
zrr1999 Sep 12, 2022
2116140
fix bugs
zrr1999 Sep 13, 2022
8116fe7
Merge branch 'develop' into sparse_transpose
zrr1999 Sep 14, 2022
2b156e3
add ops
zrr1999 Sep 14, 2022
00f18ed
Merge branch 'develop' into sparse_transpose
zrr1999 Sep 14, 2022
8d1e718
add ops
zrr1999 Sep 14, 2022
0c68706
reduce shape
zrr1999 Sep 14, 2022
1592b92
fix bugs
zrr1999 Sep 15, 2022
70a4b79
replace dims to perm
zrr1999 Sep 17, 2022
2053889
add hip impl
zrr1999 Sep 17, 2022
152d0b8
replace dims to perm
zrr1999 Sep 17, 2022
ecaea58
move transpose op to transpose file
zrr1999 Sep 17, 2022
c92e076
restore unary file
zrr1999 Sep 17, 2022
c5ec3fb
modified docs
zrr1999 Sep 20, 2022
b7090dd
Merge branch 'PaddlePaddle:develop' into sparse_transpose
zrr1999 Sep 21, 2022
fb59429
add infer_meta
zrr1999 Sep 21, 2022
003c044
fix bugs
zrr1999 Sep 21, 2022
5cc042c
optimize cuda
zrr1999 Sep 21, 2022
acdd53f
optimize cuda
zrr1999 Sep 22, 2022
3884ff0
optimize cuda
zrr1999 Sep 22, 2022
43be57e
share memory
zrr1999 Sep 22, 2022
046607d
share memory
zrr1999 Sep 23, 2022
64b0343
Merge branch 'develop' into sparse_transpose
zrr1999 Sep 23, 2022
2157b31
share memory
zrr1999 Sep 23, 2022
8bf4799
remove parameter x of grad function
zrr1999 Sep 23, 2022
535cb46
fix bugs
zrr1999 Sep 23, 2022
6522826
fix bugs
zrr1999 Sep 27, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/sparse_backward.yaml
Expand Up @@ -385,6 +385,17 @@
kernel :
func : coo_to_dense { sparse_coo -> dense }

- backward_op : transpose_grad
forward : transpose(Tensor x, int[] perm) -> Tensor(out)
args : (Tensor out_grad, int[] perm)
output : Tensor(x_grad)
infer_meta :
func : TransposeGradInferMeta
param : [out_grad, perm]
kernel :
func : transpose_coo_grad {sparse_coo -> sparse_coo},
transpose_csr_grad {sparse_csr -> sparse_csr}

- backward_op : values_grad
forward : values_coo(Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
Expand Down
12 changes: 12 additions & 0 deletions paddle/phi/api/yaml/sparse_ops.yaml
Expand Up @@ -457,3 +457,15 @@
mv_csr{sparse_csr, dense -> dense}
layout : x
backward: mv_grad

- op : transpose
args : (Tensor x, int[] perm)
output : Tensor(out)
infer_meta :
func : TransposeInferMeta
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the used " func : TransposeInferMeta" is actually the one for dense tensor. But the TransposeInferMeta for dense tensor is also applicable to sparse tensor. "func : TransposeGradInferMeta" has the same situation. So I submit a PR to delete maybe unused code in paddle\phi\infermeta\sparse\unary.h

#46844

param: [ x, perm ]
kernel :
func : transpose_coo{sparse_coo -> sparse_coo},
transpose_csr{sparse_csr -> sparse_csr}
layout : x
backward : transpose_grad
2 changes: 1 addition & 1 deletion paddle/phi/core/sparse_coo_tensor.h
Expand Up @@ -274,7 +274,7 @@ class SparseCooTensor : public TensorBase,
[0, 0, 0, 0]]
dims_ = (4, 4)
non_zero_elements_ = [[0, 1, 0, 0], [0, 0, 4, 0]]
non_zero_indices_ = [0, 2],
non_zero_indices_ = [[0, 2], [1, 2]]
*/
};

Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/core/sparse_csr_tensor.h
Expand Up @@ -209,7 +209,7 @@ class SparseCsrTensor : public TensorBase,
[0, 0, 4, 0],
[0, 5, 0, 6]]
dims_ = (4, 4)
non_zero_elements_ = [1, 2, 3, 4, 5 ,6]
non_zero_elements_ = [1, 2, 3, 4, 5, 6]
non_zero_crows_ = [0, 1, 3, 4, 6]
non_zero_cols_ = [1, 0, 3, 2, 1, 3]
*/
Expand All @@ -228,7 +228,7 @@ class SparseCsrTensor : public TensorBase,
[0, 0, 4, 0],
[0, 5, 0, 0]]]
dims_ = (2, 4, 4)
non_zero_elements_ = [1, 2, 3, 4, 5 ,6, 1, 2, 3, 4, 5]
non_zero_elements_ = [1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5]
non_zero_crows_ = [0, 1, 3, 4, 6, 0, 1, 2, 4, 5]
non_zero_cols_ = [1, 0, 3, 2, 1, 3, 1, 0, 3, 2, 1]
*/
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/sparse/unary.h
Expand Up @@ -24,5 +24,12 @@ void IndicesInferMeta(const MetaTensor& x, MetaTensor* out);

void ValuesInferMeta(const MetaTensor& x, MetaTensor* out);

void TransposeInferMeta(const MetaTensor& x,
const std::vector<int>& axis,
MetaTensor* out);

void TransposeGradInferMeta(const MetaTensor& x,
const std::vector<int>& axis,
MetaTensor* out);
} // namespace sparse
} // namespace phi
78 changes: 78 additions & 0 deletions paddle/phi/kernels/sparse/cpu/transpose_grad_kernel.cc
@@ -0,0 +1,78 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
#include "paddle/phi/kernels/sparse/unary_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"

namespace phi {
namespace sparse {

std::vector<int> get_cpu_grad_perm(std::vector<int> perm) {
std::vector<int> grad_perm(perm.size());
for (unsigned int i = 0; i < perm.size(); ++i) {
grad_perm[perm[i]] = i;
}
return grad_perm;
}

template <typename T, typename Context>
void TransposeCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& dout,
const std::vector<int>& perm,
SparseCooTensor* dx) {
std::vector<int> grad_perm = get_cpu_grad_perm(perm);
TransposeCooKernel<T, Context>(dev_ctx, dout, grad_perm, dx);
}

template <typename T, typename Context>
void TransposeCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& dout,
const std::vector<int>& perm,
SparseCsrTensor* dx) {
std::vector<int> grad_perm = get_cpu_grad_perm(perm);
TransposeCsrKernel<T, Context>(dev_ctx, dout, grad_perm, dx);
}
} // namespace sparse
} // namespace phi

PD_REGISTER_KERNEL(transpose_coo_grad,
CPU,
ALL_LAYOUT,
phi::sparse::TransposeCooGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}

PD_REGISTER_KERNEL(transpose_csr_grad,
CPU,
ALL_LAYOUT,
phi::sparse::TransposeCsrGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
228 changes: 228 additions & 0 deletions paddle/phi/kernels/sparse/cpu/transpose_kernel.cc
@@ -0,0 +1,228 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/sparse/unary_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void TransposeCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const std::vector<int>& perm,
SparseCooTensor* out) {
// create out sparse tensor
int64_t x_nnz = x.nnz();
DDim out_dims = x.dims().transpose(perm);
DenseTensor out_indices = EmptyLike<int64_t, Context>(dev_ctx, x.indices());
DenseTensor out_values(x.values());
out->SetMember(out_indices, out_values, out_dims, x.coalesced());

// compute values of indices
const DenseTensor& x_indices = x.indices();
const auto* x_indices_data = x_indices.data<int64_t>();
auto* out_indices_data = out_indices.data<int64_t>();
for (unsigned int i = 0; i < perm.size(); ++i) {
for (int64_t j = 0; j < x_nnz; ++j) {
out_indices_data[j + i * x_nnz] = x_indices_data[j + perm[i] * x_nnz];
}
}
}

template <typename T, typename Context>
void TransposeCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const std::vector<int>& perm,
SparseCsrTensor* out) {
unsigned int n_dim = perm.size();
const DenseTensor& x_crows = x.crows();
const DenseTensor& x_cols = x.cols();
const DenseTensor& x_values = x.values();
// return a copy of x
if (perm[0] == 0 && perm[1] == 1 && (n_dim == 2 || perm[2] == 2)) {
out->SetMember(x_crows, x_cols, x_values, x.dims());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

用DenseTensor a=b; 触发左值构造,底层会共享Allocation内存,不要直接就完全用同一个DenseTensor,这个可能会引入风险

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,已经修改

return;
}
// create out sparse tensor
DDim out_dims = x.dims().transpose(perm);
DenseTensor out_crows;
if (n_dim == 2) {
out_crows = Empty<int64_t, Context>(dev_ctx, {out_dims[0] + 1});
} else {
out_crows =
Empty<int64_t, Context>(dev_ctx, {out_dims[0] * (out_dims[1] + 1)});
}
DenseTensor out_cols = EmptyLike<int64_t, Context>(dev_ctx, x.cols());
DenseTensor out_values = EmptyLike<T, Context>(dev_ctx, x.values());
out->SetMember(out_crows, out_cols, out_values, out_dims);
// transpose by two stages
if (perm[0] == 1 && perm[1] == 2) { // perm == {1, 2, 0}
SparseCsrTensor temp;
TransposeCsrKernel<T, Context>(dev_ctx, x, {1, 0, 2}, &temp);
TransposeCsrKernel<T, Context>(dev_ctx, temp, {0, 2, 1}, out);
return;
} else if (perm[0] == 2 && perm[1] == 0) { // perm == {2, 0, 1}
SparseCsrTensor temp;
TransposeCsrKernel<T, Context>(dev_ctx, x, {0, 2, 1}, &temp);
TransposeCsrKernel<T, Context>(dev_ctx, temp, {1, 0, 2}, out);
return;
} else if (perm[0] == 2 && perm[1] == 1) { // perm == {2, 1, 0}
SparseCsrTensor temp;
TransposeCsrKernel<T, Context>(dev_ctx, x, {1, 0, 2}, &temp);
TransposeCsrKernel<T, Context>(dev_ctx, temp, {2, 0, 1}, out);
return;
}

int64_t* out_crows_data = out_crows.data<int64_t>();
int64_t* out_cols_data = out_cols.data<int64_t>();
T* out_values_data = out_values.data<T>();
const int64_t* x_crows_data = x_crows.data<int64_t>();
const int64_t* x_cols_data = x_cols.data<int64_t>();
const T* x_values_data = x_values.data<T>();

int64_t x_nnz = x.nnz();
if (n_dim == 2) { // perm == {1, 0}
// compute out_crows_data by x_cols_data
for (int i = 0; i < out_dims[0]; ++i) {
out_crows_data[i] = 0;
}
for (int i = 0; i < x_nnz; ++i) {
int j = x_cols_data[i];
out_crows_data[j + 1]++;
}
out_crows_data[out_dims[0]] = x_nnz;
for (int i = 1; i < out_dims[0]; ++i) {
out_crows_data[i] += out_crows_data[i - 1];
}
// compute out_cols_data and out_values_data by out_crows_data and x
std::unordered_map<int64_t, int> cols_offset;
for (int i = 0; i < x.dims()[0]; ++i) {
int64_t start = x_crows_data[i];
int64_t end = x_crows_data[i + 1];
for (int64_t j = start; j < end; ++j) {
int64_t x_cols_j = x_cols_data[j];
int64_t jjj = out_crows_data[x_cols_j];
if (cols_offset.count(jjj)) {
cols_offset[jjj]++;
} else {
cols_offset[jjj] = 0;
}
int64_t jjj_offset = jjj + cols_offset[jjj];
out_cols_data[jjj_offset] = i;
out_values_data[jjj_offset] = x_values_data[j];
}
}
} else { // n_dim == 3
int out_n_rows = out_dims[1];
int x_n_rows = x.dims()[1];
for (int k = 0; k < out_dims[0]; ++k) {
if (perm[0] == 0) { // perm == {0, 2, 1}
// compute out_crows_data by x_cols_data
for (int i = 0; i < out_n_rows; ++i) {
out_crows_data[i] = 0;
}
for (int i = 0; i < x_crows_data[x_n_rows]; ++i) {
int j = x_cols_data[i];
out_crows_data[j + 1]++;
}
out_crows_data[out_n_rows] = x_crows_data[x_n_rows];
for (int i = 1; i < out_n_rows; ++i) {
out_crows_data[i] += out_crows_data[i - 1];
}
// compute out_cols_data and out_values_data by out_crows_data and x
std::unordered_map<int64_t, int> cols_offset;
for (int i = 0; i < x_n_rows; ++i) {
int64_t start = x_crows_data[i];
int64_t end = x_crows_data[i + 1];
for (int64_t j = start; j < end; ++j) {
int64_t x_cols_j = x_cols_data[j];
int64_t jjj = out_crows_data[x_cols_j];
if (cols_offset.count(jjj)) {
cols_offset[jjj]++;
} else {
cols_offset[jjj] = 0;
}
int64_t jjj_offset = jjj + cols_offset[jjj];
out_cols_data[jjj_offset] = i;
out_values_data[jjj_offset] = x_values_data[j];
}
}
// x offset
x_cols_data += x_crows_data[x_n_rows];
x_values_data += x_crows_data[x_n_rows];
x_crows_data += x_n_rows + 1;
} else if (perm[0] == 1 && perm[1] == 0) { // perm == {1, 0, 2}
for (int i = 0; i < out_n_rows; ++i) {
out_crows_data[i] = 0;
}
int x_cols_offset = 0;
int out_cols_index = 0;
for (int i = 0; i < x.dims()[0]; ++i) {
int x_crows_index = i * (x_n_rows + 1);
int start = x_crows_data[x_crows_index + k];
int end = x_crows_data[x_crows_index + 1 + k];
out_crows_data[i + 1] = end - start;
for (int j = start; j < end; ++j) {
out_cols_data[out_cols_index] = x_cols_data[x_cols_offset + j];
out_values_data[out_cols_index] = x_values_data[x_cols_offset + j];
out_cols_index++;
}
x_cols_offset += x_crows_data[x_crows_index + x_n_rows];
}
for (int i = 1; i <= out_n_rows; ++i) {
out_crows_data[i] += out_crows_data[i - 1];
}
}
// out offset
out_cols_data += out_crows_data[out_n_rows];
out_values_data += out_crows_data[out_n_rows];
out_crows_data += out_n_rows + 1;
}
}
}
} // namespace sparse
} // namespace phi

PD_REGISTER_KERNEL(transpose_coo,
CPU,
ALL_LAYOUT,
phi::sparse::TransposeCooKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}

PD_REGISTER_KERNEL(transpose_csr,
CPU,
ALL_LAYOUT,
phi::sparse::TransposeCsrKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}