Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hackathon 3rd No.22 ] add paddle.incubate.sparse.reshape #46694

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
b2d2ea5
add sparse reshape
OccupyMars2025 Sep 20, 2022
5277531
change the dtype in all test cases to int64
OccupyMars2025 Sep 21, 2022
ab3e871
just one test case
OccupyMars2025 Sep 21, 2022
a8a4960
modify comments
OccupyMars2025 Sep 22, 2022
8eb27b8
Merge branch 'PaddlePaddle:develop' into hackathon-3rd-task22-add-pad…
OccupyMars2025 Sep 22, 2022
234122f
Update test_sparse_reshape_op.py
OccupyMars2025 Sep 22, 2022
64f98b0
chang the type of "shape" from vector<int64_t> to IntArray
OccupyMars2025 Sep 23, 2022
0f4660d
check whether sp_out.to_dense() is the cause of error
OccupyMars2025 Sep 23, 2022
4761494
print sp_out
OccupyMars2025 Sep 23, 2022
aa72cc7
Update reshape_kernel.cc
OccupyMars2025 Sep 23, 2022
c957c8d
Merge branch 'PaddlePaddle:develop' into hackathon-3rd-task22-add-pad…
OccupyMars2025 Sep 23, 2022
fb434fd
use numpy to generate the equal paddle tensor
OccupyMars2025 Sep 23, 2022
281a3d4
just check dense_tensor.numpy()
OccupyMars2025 Sep 23, 2022
64a3503
check cpu and cuda versions
OccupyMars2025 Sep 23, 2022
84f51db
Update test_sparse_reshape_op.py
OccupyMars2025 Sep 24, 2022
90bfea3
supply all test cases for cpu forward coo kernel
OccupyMars2025 Sep 24, 2022
7e80110
Merge branch 'PaddlePaddle:develop' into hackathon-3rd-task22-add-pad…
OccupyMars2025 Sep 29, 2022
8ad2f55
Merge branch 'PaddlePaddle:develop' into hackathon-3rd-task22-add-pad…
OccupyMars2025 Sep 29, 2022
21d3538
test forward coo cuda kernel
OccupyMars2025 Sep 29, 2022
687778a
change configuration of cuda kernel
OccupyMars2025 Sep 29, 2022
11ee8c3
keep only one test case
OccupyMars2025 Sep 29, 2022
de6d903
test coo cpu kernel (forward and backward)
OccupyMars2025 Sep 29, 2022
497e27a
Merge branch 'PaddlePaddle:develop' into hackathon-3rd-task22-add-pad…
OccupyMars2025 Sep 29, 2022
8b09f24
row major or column major ???
OccupyMars2025 Oct 1, 2022
4420f4f
test cuda coo forward kernel
OccupyMars2025 Oct 1, 2022
953b88e
Merge branch 'hackathon-3rd-task22-add-paddle.incubate.sparse.reshape…
OccupyMars2025 Oct 1, 2022
d74bda9
complete declaration and registration
OccupyMars2025 Oct 1, 2022
17ec7c3
Update __init__.py
OccupyMars2025 Oct 1, 2022
ca08918
rebuild
OccupyMars2025 Oct 1, 2022
cdd78a1
retrigger CI
OccupyMars2025 Oct 1, 2022
471c648
add cudaMalloc and cudaMemcpy in ReshapeCooKernel and change back …
OccupyMars2025 Oct 1, 2022
5f33291
midify minor error
OccupyMars2025 Oct 1, 2022
f1a9d9b
test only cpu coo forward kernel
OccupyMars2025 Oct 2, 2022
6aa3061
add all test cases for coo forward kernel (both cpu and gpu)
OccupyMars2025 Oct 2, 2022
499b78a
test all forward kernels (coo, csr; cpu, gpu)
OccupyMars2025 Oct 2, 2022
2e91335
add all test cases for all kinds of kernels
OccupyMars2025 Oct 2, 2022
cc21e67
just retrigger CI
OccupyMars2025 Oct 3, 2022
1d3ed6d
Merge branch 'PaddlePaddle:develop' into hackathon-3rd-task22-add-pad…
OccupyMars2025 Oct 7, 2022
b4e6d2f
Update sparse_ops.yaml
OccupyMars2025 Oct 10, 2022
b5d6dbc
Update sparse_ops.yaml
OccupyMars2025 Oct 10, 2022
46247a7
Update sparse_ops.yaml
OccupyMars2025 Oct 10, 2022
0661bd2
resolve conflicts
OccupyMars2025 Oct 10, 2022
c01cb39
Merge branch 'PaddlePaddle:develop' into hackathon-3rd-task22-add-pad…
OccupyMars2025 Oct 10, 2022
3df911c
Update sparse_ops.yaml
OccupyMars2025 Oct 10, 2022
eaca3a2
don't specify tensor place
OccupyMars2025 Oct 10, 2022
d80cf1a
new shape has -1 or 0 in it
OccupyMars2025 Oct 10, 2022
1da6d60
Update unary_grad_kernel.h
OccupyMars2025 Oct 10, 2022
bbf4b48
correct lvalue error
OccupyMars2025 Oct 10, 2022
dd2961e
Merge branch 'PaddlePaddle:develop' into hackathon-3rd-task22-add-pad…
OccupyMars2025 Oct 12, 2022
6db2391
code style
OccupyMars2025 Oct 12, 2022
b01ebb9
Update sparse_backward.yaml
OccupyMars2025 Oct 12, 2022
dd30c2e
Update sparse_ops.yaml
OccupyMars2025 Oct 12, 2022
64d7de2
Update unary_kernel.h
OccupyMars2025 Oct 12, 2022
01ec505
Merge branch 'PaddlePaddle:develop' into hackathon-3rd-task22-add-pad…
OccupyMars2025 Oct 12, 2022
118c429
Update unary.py
OccupyMars2025 Oct 12, 2022
4a510d2
Update sparse_backward.yaml
OccupyMars2025 Oct 12, 2022
f011729
Update unary.py
OccupyMars2025 Oct 12, 2022
9db2c6a
Merge branch 'PaddlePaddle:develop' into hackathon-3rd-task22-add-pad…
OccupyMars2025 Oct 12, 2022
659c034
code style
OccupyMars2025 Oct 12, 2022
213145b
code style
OccupyMars2025 Oct 13, 2022
0d77ea7
code style
OccupyMars2025 Oct 13, 2022
eb33c53
Merge branch 'PaddlePaddle:develop' into hackathon-3rd-task22-add-pad…
OccupyMars2025 Oct 14, 2022
67cfa50
Update unary.py
OccupyMars2025 Oct 14, 2022
32cd23c
specify tensor place explicitly
OccupyMars2025 Oct 14, 2022
a96baa6
do not use numpy array
OccupyMars2025 Oct 14, 2022
8c7f144
use numpy array in unit test again
OccupyMars2025 Oct 14, 2022
518038c
modify example code in docstring
OccupyMars2025 Oct 17, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/sparse_backward.yaml
Expand Up @@ -272,6 +272,17 @@
func : relu_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
relu_csr_grad {sparse_csr, sparse_csr -> sparse_csr}

- backward_op : reshape_grad
forward : reshape(Tensor x, IntArray shape) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : reshape_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
reshape_csr_grad {sparse_csr, sparse_csr -> sparse_csr}

- backward_op : scale_grad
forward : scale(Tensor x, float scale, float bias, bool bias_after_scale) -> Tensor(out)
args : (Tensor out_grad, float scale)
Expand Down
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/sparse_ops.yaml
Expand Up @@ -489,3 +489,14 @@
func : sync_batch_norm_coo{sparse_coo, dense, dense, dense, dense -> sparse_coo, dense, dense, dense, dense, dense}
data_type : x
backward : sync_batch_norm_grad

- op : reshape
args : (Tensor x, IntArray shape)
output : Tensor(out)
infer_meta :
func : ReshapeInferMeta
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"func : ReshapeInferMeta" , I have used the infermeta function for dense tensor, so there is no need to write new ones.

kernel :
func : reshape_coo{sparse_coo -> sparse_coo},
reshape_csr{sparse_csr -> sparse_csr}
Copy link
Contributor

@jeff41404 jeff41404 Oct 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the design of reshape_csr needs to be added to rfc, also ReshapeCsrKernel and ReshapeCsrGradKernel below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean I need to modify the RFC to make the content of the RFC agree with my actual implementation ? Ok, I will do it.

layout : x
backward : reshape_grad
73 changes: 73 additions & 0 deletions paddle/phi/kernels/sparse/cpu/reshape_grad_kernel.cc
@@ -0,0 +1,73 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
#include "paddle/phi/kernels/sparse/unary_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void ReshapeCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& dout,
SparseCooTensor* dx) {
EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx);
phi::IntArray x_shape(phi::vectorize(x.dims()));
ReshapeCooKernel<T, Context>(dev_ctx, dout, x_shape, dx);
}

template <typename T, typename Context>
void ReshapeCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& dout,
SparseCsrTensor* dx) {
EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx);
phi::IntArray x_shape(phi::vectorize(x.dims()));
ReshapeCsrKernel<T, Context>(dev_ctx, dout, x_shape, dx);
}

} // namespace sparse
} // namespace phi

PD_REGISTER_KERNEL(reshape_coo_grad,
CPU,
ALL_LAYOUT,
phi::sparse::ReshapeCooGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}

PD_REGISTER_KERNEL(reshape_csr_grad,
CPU,
ALL_LAYOUT,
phi::sparse::ReshapeCsrGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
117 changes: 117 additions & 0 deletions paddle/phi/kernels/sparse/cpu/reshape_kernel.cc
@@ -0,0 +1,117 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/sparse/unary_kernel.h"

#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"
#include "paddle/phi/kernels/sparse/impl/unary_kernel_impl.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void ReshapeCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& shape,
SparseCooTensor* out) {
// TODO(OccupyMars2025): Currently, reshape is only applicable to sparse dims
int64_t x_nnz = x.nnz();

// Use DDim::reshape to handle -1 and 0 in the argument "shape"
std::vector<int> new_shape(shape.GetData().begin(), shape.GetData().end());
phi::DDim out_dims = x.dims().reshape(new_shape);
// get sparse part dimensions of x and out
std::vector<int64_t> x_sparse_part_dims;
std::vector<int64_t> out_sparse_part_dims;
for (int i = 0; i < x.sparse_dim(); ++i) {
x_sparse_part_dims.push_back(x.dims()[i]);
}
for (int i = 0; i < out_dims.size() - x.dense_dim(); ++i) {
out_sparse_part_dims.push_back(out_dims[i]);
}
DenseTensor out_indices = Empty<int64_t, Context>(
dev_ctx, {static_cast<int64_t>(out_sparse_part_dims.size()), x_nnz});
DenseTensor out_values(x.values());
out->SetMember(out_indices, out_values, out_dims, x.coalesced());

// compute values of indices
const DenseTensor& x_indices = x.indices();
const auto* x_indices_data = x_indices.data<int64_t>();
auto* out_indices_data = out_indices.data<int64_t>();

const phi::DDim& x_sparse_part_strides =
phi::stride(phi::make_ddim(x_sparse_part_dims));
const phi::DDim& out_sparse_part_strides =
phi::stride(phi::make_ddim(out_sparse_part_dims));
int64_t location = 0;
for (int64_t j = 0; j < x_nnz; ++j) {
location = 0;
for (int i = 0; i < x.sparse_dim(); ++i) {
location += x_indices_data[i * x_nnz + j] * x_sparse_part_strides[i];
}
for (size_t i = 0; i < out_sparse_part_dims.size(); ++i) {
out_indices_data[i * x_nnz + j] = location / out_sparse_part_strides[i];
location %= out_sparse_part_strides[i];
}
}
}

template <typename T, typename Context>
void ReshapeCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const phi::IntArray& shape,
SparseCsrTensor* out) {
// transform csr format to coo format, and then use coo kernel
const SparseCooTensor x_coo = CsrToCoo<T, Context>(dev_ctx, x);
SparseCooTensor out_coo;
ReshapeCooKernel<T, Context>(dev_ctx, x_coo, shape, &out_coo);
CooToCsrKernel<T, Context>(dev_ctx, out_coo, out);
}

} // namespace sparse
} // namespace phi

PD_REGISTER_KERNEL(reshape_coo,
CPU,
ALL_LAYOUT,
phi::sparse::ReshapeCooKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}

PD_REGISTER_KERNEL(reshape_csr,
CPU,
ALL_LAYOUT,
phi::sparse::ReshapeCsrKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
6 changes: 4 additions & 2 deletions paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc
Expand Up @@ -329,7 +329,8 @@ PD_REGISTER_KERNEL(csr_to_coo,
int8_t,
int16_t,
int,
int64_t) {}
int64_t,
bool) {}

PD_REGISTER_KERNEL(coo_to_csr,
CPU,
Expand All @@ -342,7 +343,8 @@ PD_REGISTER_KERNEL(coo_to_csr,
int8_t,
int16_t,
int,
int64_t) {}
int64_t,
bool) {}

PD_REGISTER_KERNEL(dense_to_csr,
CPU,
Expand Down
77 changes: 77 additions & 0 deletions paddle/phi/kernels/sparse/gpu/reshape_grad_kernel.cu
@@ -0,0 +1,77 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
#include "paddle/phi/kernels/sparse/unary_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"

namespace phi {
namespace sparse {

// just copy from paddle\phi\kernels\sparse\cpu\reshape_grad_kernel.cc
template <typename T, typename Context>
void ReshapeCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& dout,
SparseCooTensor* dx) {
EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx);
phi::IntArray x_shape(phi::vectorize(x.dims()));
ReshapeCooKernel<T, Context>(dev_ctx, dout, x_shape, dx);
}

// just copy from paddle\phi\kernels\sparse\cpu\reshape_grad_kernel.cc
template <typename T, typename Context>
void ReshapeCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& dout,
SparseCsrTensor* dx) {
EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx);
phi::IntArray x_shape(phi::vectorize(x.dims()));
ReshapeCsrKernel<T, Context>(dev_ctx, dout, x_shape, dx);
}

} // namespace sparse
} // namespace phi

PD_REGISTER_KERNEL(reshape_coo_grad,
GPU,
ALL_LAYOUT,
phi::sparse::ReshapeCooGradKernel,
phi::dtype::float16,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}

PD_REGISTER_KERNEL(reshape_csr_grad,
GPU,
ALL_LAYOUT,
phi::sparse::ReshapeCsrGradKernel,
phi::dtype::float16,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}