Skip to content

Commit

Permalink
[cherry-pick 2.4] add sparse api transpose/reshape/is_same_shape (#47076
Browse files Browse the repository at this point in the history
)

新增sparse.is_same_shape、sparse.reshape、sparse.transpose 三个API
  • Loading branch information
zhwesky2010 committed Oct 18, 2022
1 parent 5a44c12 commit 5fef043
Show file tree
Hide file tree
Showing 25 changed files with 1,970 additions and 36 deletions.
13 changes: 13 additions & 0 deletions paddle/fluid/pybind/eager_method.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1588,6 +1588,15 @@ static PyObject* tensor_method_to_sparse_csr(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}

static PyObject* tensor_method_is_same_shape(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
auto other = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
return ToPyObject(self->tensor.shape() == other.shape());
EAGER_CATCH_AND_THROW_RETURN_NULL
}

static PyObject* tensor__inplace_version(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
Expand Down Expand Up @@ -1983,6 +1992,10 @@ PyMethodDef variable_methods[] = {
(PyCFunction)(void (*)(void))tensor_method_is_sparse_csr,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"is_same_shape",
(PyCFunction)(void (*)(void))tensor_method_is_same_shape,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"to_sparse_csr",
(PyCFunction)(void (*)(void))tensor_method_to_sparse_csr,
METH_VARARGS | METH_KEYWORDS,
Expand Down
22 changes: 22 additions & 0 deletions paddle/phi/api/yaml/sparse_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,17 @@
func : relu_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
relu_csr_grad {sparse_csr, sparse_csr -> sparse_csr}

- backward_op : reshape_grad
forward : reshape(Tensor x, IntArray shape) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : reshape_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
reshape_csr_grad {sparse_csr, sparse_csr -> sparse_csr}

- backward_op : scale_grad
forward : scale(Tensor x, float scale, float bias, bool bias_after_scale) -> Tensor(out)
args : (Tensor out_grad, float scale)
Expand Down Expand Up @@ -385,6 +396,17 @@
kernel :
func : coo_to_dense { sparse_coo -> dense }

- backward_op : transpose_grad
forward : transpose(Tensor x, int[] perm) -> Tensor(out)
args : (Tensor out_grad, int[] perm)
output : Tensor(x_grad)
infer_meta :
func : TransposeGradInferMeta
param : [out_grad, perm]
kernel :
func : transpose_coo_grad {sparse_coo -> sparse_coo},
transpose_csr_grad {sparse_csr -> sparse_csr}

- backward_op : values_grad
forward : values_coo(Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
Expand Down
23 changes: 23 additions & 0 deletions paddle/phi/api/yaml/sparse_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -457,3 +457,26 @@
mv_csr{sparse_csr, dense -> dense}
layout : x
backward: mv_grad

- op : transpose
args : (Tensor x, int[] perm)
output : Tensor(out)
infer_meta :
func : TransposeInferMeta
param: [ x, perm ]
kernel :
func : transpose_coo{sparse_coo -> sparse_coo},
transpose_csr{sparse_csr -> sparse_csr}
layout : x
backward : transpose_grad

- op : reshape
args : (Tensor x, IntArray shape)
output : Tensor(out)
infer_meta :
func : ReshapeInferMeta
kernel :
func : reshape_coo{sparse_coo -> sparse_coo},
reshape_csr{sparse_csr -> sparse_csr}
layout : x
backward : reshape_grad
2 changes: 1 addition & 1 deletion paddle/phi/core/sparse_coo_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ class SparseCooTensor : public TensorBase,
[0, 0, 0, 0]]
dims_ = (4, 4)
non_zero_elements_ = [[0, 1, 0, 0], [0, 0, 4, 0]]
non_zero_indices_ = [0, 2],
non_zero_indices_ = [[0, 2], [1, 2]]
*/
};

Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/core/sparse_csr_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ class SparseCsrTensor : public TensorBase,
[0, 0, 4, 0],
[0, 5, 0, 6]]
dims_ = (4, 4)
non_zero_elements_ = [1, 2, 3, 4, 5 ,6]
non_zero_elements_ = [1, 2, 3, 4, 5, 6]
non_zero_crows_ = [0, 1, 3, 4, 6]
non_zero_cols_ = [1, 0, 3, 2, 1, 3]
*/
Expand All @@ -228,7 +228,7 @@ class SparseCsrTensor : public TensorBase,
[0, 0, 4, 0],
[0, 5, 0, 0]]]
dims_ = (2, 4, 4)
non_zero_elements_ = [1, 2, 3, 4, 5 ,6, 1, 2, 3, 4, 5]
non_zero_elements_ = [1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5]
non_zero_crows_ = [0, 1, 3, 4, 6, 0, 1, 2, 4, 5]
non_zero_cols_ = [1, 0, 3, 2, 1, 3, 1, 0, 3, 2, 1]
*/
Expand Down
73 changes: 73 additions & 0 deletions paddle/phi/kernels/sparse/cpu/reshape_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
#include "paddle/phi/kernels/sparse/unary_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void ReshapeCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& dout,
SparseCooTensor* dx) {
EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx);
phi::IntArray x_shape(phi::vectorize(x.dims()));
ReshapeCooKernel<T, Context>(dev_ctx, dout, x_shape, dx);
}

template <typename T, typename Context>
void ReshapeCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& dout,
SparseCsrTensor* dx) {
EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx);
phi::IntArray x_shape(phi::vectorize(x.dims()));
ReshapeCsrKernel<T, Context>(dev_ctx, dout, x_shape, dx);
}

} // namespace sparse
} // namespace phi

PD_REGISTER_KERNEL(reshape_coo_grad,
CPU,
ALL_LAYOUT,
phi::sparse::ReshapeCooGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}

PD_REGISTER_KERNEL(reshape_csr_grad,
CPU,
ALL_LAYOUT,
phi::sparse::ReshapeCsrGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
117 changes: 117 additions & 0 deletions paddle/phi/kernels/sparse/cpu/reshape_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/sparse/unary_kernel.h"

#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"
#include "paddle/phi/kernels/sparse/impl/unary_kernel_impl.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void ReshapeCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& shape,
SparseCooTensor* out) {
// TODO(OccupyMars2025): Currently, reshape is only applicable to sparse dims
int64_t x_nnz = x.nnz();

// Use DDim::reshape to handle -1 and 0 in the argument "shape"
std::vector<int> new_shape(shape.GetData().begin(), shape.GetData().end());
phi::DDim out_dims = x.dims().reshape(new_shape);
// get sparse part dimensions of x and out
std::vector<int64_t> x_sparse_part_dims;
std::vector<int64_t> out_sparse_part_dims;
for (int i = 0; i < x.sparse_dim(); ++i) {
x_sparse_part_dims.push_back(x.dims()[i]);
}
for (int i = 0; i < out_dims.size() - x.dense_dim(); ++i) {
out_sparse_part_dims.push_back(out_dims[i]);
}
DenseTensor out_indices = Empty<int64_t, Context>(
dev_ctx, {static_cast<int64_t>(out_sparse_part_dims.size()), x_nnz});
DenseTensor out_values(x.values());
out->SetMember(out_indices, out_values, out_dims, x.coalesced());

// compute values of indices
const DenseTensor& x_indices = x.indices();
const auto* x_indices_data = x_indices.data<int64_t>();
auto* out_indices_data = out_indices.data<int64_t>();

const phi::DDim& x_sparse_part_strides =
phi::stride(phi::make_ddim(x_sparse_part_dims));
const phi::DDim& out_sparse_part_strides =
phi::stride(phi::make_ddim(out_sparse_part_dims));
int64_t location = 0;
for (int64_t j = 0; j < x_nnz; ++j) {
location = 0;
for (int i = 0; i < x.sparse_dim(); ++i) {
location += x_indices_data[i * x_nnz + j] * x_sparse_part_strides[i];
}
for (size_t i = 0; i < out_sparse_part_dims.size(); ++i) {
out_indices_data[i * x_nnz + j] = location / out_sparse_part_strides[i];
location %= out_sparse_part_strides[i];
}
}
}

template <typename T, typename Context>
void ReshapeCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const phi::IntArray& shape,
SparseCsrTensor* out) {
// transform csr format to coo format, and then use coo kernel
const SparseCooTensor x_coo = CsrToCoo<T, Context>(dev_ctx, x);
SparseCooTensor out_coo;
ReshapeCooKernel<T, Context>(dev_ctx, x_coo, shape, &out_coo);
CooToCsrKernel<T, Context>(dev_ctx, out_coo, out);
}

} // namespace sparse
} // namespace phi

PD_REGISTER_KERNEL(reshape_coo,
CPU,
ALL_LAYOUT,
phi::sparse::ReshapeCooKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}

PD_REGISTER_KERNEL(reshape_csr,
CPU,
ALL_LAYOUT,
phi::sparse::ReshapeCsrKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
6 changes: 4 additions & 2 deletions paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,8 @@ PD_REGISTER_KERNEL(csr_to_coo,
int8_t,
int16_t,
int,
int64_t) {}
int64_t,
bool) {}

PD_REGISTER_KERNEL(coo_to_csr,
CPU,
Expand All @@ -342,7 +343,8 @@ PD_REGISTER_KERNEL(coo_to_csr,
int8_t,
int16_t,
int,
int64_t) {}
int64_t,
bool) {}

PD_REGISTER_KERNEL(dense_to_csr,
CPU,
Expand Down

0 comments on commit 5fef043

Please sign in to comment.