diff --git a/paddle/phi/api/yaml/sparse_backward.yaml b/paddle/phi/api/yaml/sparse_backward.yaml
index de49f6f27fe36..40b646cb38996 100644
--- a/paddle/phi/api/yaml/sparse_backward.yaml
+++ b/paddle/phi/api/yaml/sparse_backward.yaml
@@ -260,6 +260,17 @@
     func : relu_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
            relu_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
 
+- backward_op : reshape_grad
+  forward : reshape(Tensor x, IntArray shape) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : reshape_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
+           reshape_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
+
 - backward_op : scale_grad
   forward : scale(Tensor x, float scale, float bias, bool bias_after_scale) -> Tensor(out)
   args : (Tensor out_grad, float scale)
diff --git a/paddle/phi/api/yaml/sparse_ops.yaml b/paddle/phi/api/yaml/sparse_ops.yaml
index 1d7a4c0bafe53..984686ce80746 100644
--- a/paddle/phi/api/yaml/sparse_ops.yaml
+++ b/paddle/phi/api/yaml/sparse_ops.yaml
@@ -469,3 +469,24 @@
            transpose_csr{sparse_csr -> sparse_csr}
   layout : x
   backward : transpose_grad
+
+- op : sync_batch_norm
+  args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
+  output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
+  infer_meta :
+    func : BatchNormInferMeta
+  kernel :
+    func : sync_batch_norm_coo{sparse_coo, dense, dense, dense, dense -> sparse_coo, dense, dense, dense, dense, dense}
+    data_type : x
+  backward : sync_batch_norm_grad
+
+- op : reshape
+  args : (Tensor x, IntArray shape)
+  output : Tensor(out)
+  infer_meta :
+    func : ReshapeInferMeta
+  kernel :
+    func : reshape_coo{sparse_coo -> sparse_coo},
+           reshape_csr{sparse_csr -> sparse_csr}
+  layout : x
+  backward : reshape_grad
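The two YAML entries above register the sparse `reshape` forward op and its `reshape_grad` backward op for both COO and CSR layouts; after code generation the op is reachable from Python as `_C_ops.sparse_reshape` (see `unary.py` below). A minimal sketch of what this wiring enables, assuming a Paddle build that includes this patch:

```python
import paddle

# 2x3 dense tensor with explicit zeros -> 1-D sparse COO reshape
dense = paddle.to_tensor([[0.0, 1.0, 0.0], [2.0, 0.0, 3.0]])
sp = dense.to_sparse_coo(2)
sp.stop_gradient = False

out = paddle.incubate.sparse.reshape(sp, [6])   # forward: reshape_coo
out.backward()                                  # backward: reshape_coo_grad
print(out.to_dense().numpy())  # [0. 1. 0. 2. 0. 3.]
```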
diff --git a/paddle/phi/kernels/sparse/cpu/reshape_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/reshape_grad_kernel.cc
new file mode 100644
index 0000000000000..fc843f81c31ee
--- /dev/null
+++ b/paddle/phi/kernels/sparse/cpu/reshape_grad_kernel.cc
@@ -0,0 +1,73 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
+#include "paddle/phi/kernels/sparse/unary_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/sparse/empty_kernel.h"
+#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"
+
+namespace phi {
+namespace sparse {
+
+// The gradient of reshape is a reshape: map dout (which has out's shape)
+// back to x's original shape.
+template <typename T, typename Context>
+void ReshapeCooGradKernel(const Context& dev_ctx,
+                          const SparseCooTensor& x,
+                          const SparseCooTensor& dout,
+                          SparseCooTensor* dx) {
+  EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx);
+  phi::IntArray x_shape(phi::vectorize(x.dims()));
+  ReshapeCooKernel<T, Context>(dev_ctx, dout, x_shape, dx);
+}
+
+template <typename T, typename Context>
+void ReshapeCsrGradKernel(const Context& dev_ctx,
+                          const SparseCsrTensor& x,
+                          const SparseCsrTensor& dout,
+                          SparseCsrTensor* dx) {
+  EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx);
+  phi::IntArray x_shape(phi::vectorize(x.dims()));
+  ReshapeCsrKernel<T, Context>(dev_ctx, dout, x_shape, dx);
+}
+
+}  // namespace sparse
+}  // namespace phi
+
+PD_REGISTER_KERNEL(reshape_coo_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::sparse::ReshapeCooGradKernel,
+                   float,
+                   double,
+                   int8_t,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool) {}
+
+PD_REGISTER_KERNEL(reshape_csr_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::sparse::ReshapeCsrGradKernel,
+                   float,
+                   double,
+                   int8_t,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool) {}
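Both grad kernels reduce to the forward kernel: the incoming gradient arrives in `out`'s shape and is simply reshaped back to `x`'s shape. The same identity in numpy terms (an illustrative sketch, not the Paddle internals):

```python
import numpy as np

x = np.arange(6).reshape(2, 3)   # forward: out = x.reshape(3, 2)
dout = np.ones((3, 2))           # gradient arriving in out's shape
dx = dout.reshape(x.shape)       # backward: reshape back to x's shape
assert dx.shape == x.shape
```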
+ +#include "paddle/phi/kernels/sparse/unary_kernel.h" + +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +#include "paddle/phi/kernels/sparse/empty_kernel.h" +#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h" +#include "paddle/phi/kernels/sparse/impl/unary_kernel_impl.h" + +namespace phi { +namespace sparse { + +template +void ReshapeCooKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const phi::IntArray& shape, + SparseCooTensor* out) { + // TODO(OccupyMars2025): Currently, reshape is only applicable to sparse dims + int64_t x_nnz = x.nnz(); + + // Use DDim::reshape to handle -1 and 0 in the argument "shape" + std::vector new_shape(shape.GetData().begin(), shape.GetData().end()); + phi::DDim out_dims = x.dims().reshape(new_shape); + // get sparse part dimensions of x and out + std::vector x_sparse_part_dims; + std::vector out_sparse_part_dims; + for (int i = 0; i < x.sparse_dim(); ++i) { + x_sparse_part_dims.push_back(x.dims()[i]); + } + for (int i = 0; i < out_dims.size() - x.dense_dim(); ++i) { + out_sparse_part_dims.push_back(out_dims[i]); + } + DenseTensor out_indices = Empty( + dev_ctx, {static_cast(out_sparse_part_dims.size()), x_nnz}); + DenseTensor out_values(x.values()); + out->SetMember(out_indices, out_values, out_dims, x.coalesced()); + + // compute values of indices + const DenseTensor& x_indices = x.indices(); + const auto* x_indices_data = x_indices.data(); + auto* out_indices_data = out_indices.data(); + + const phi::DDim& x_sparse_part_strides = + phi::stride(phi::make_ddim(x_sparse_part_dims)); + const phi::DDim& out_sparse_part_strides = + phi::stride(phi::make_ddim(out_sparse_part_dims)); + int64_t location = 0; + for (int64_t j = 0; j < x_nnz; ++j) { + location = 0; + for (int i = 0; i < x.sparse_dim(); ++i) { + location += x_indices_data[i * x_nnz + j] * x_sparse_part_strides[i]; + } + for (size_t i = 0; i < out_sparse_part_dims.size(); ++i) { + out_indices_data[i * x_nnz + j] = location / out_sparse_part_strides[i]; + location %= out_sparse_part_strides[i]; + } + } +} + +template +void ReshapeCsrKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const phi::IntArray& shape, + SparseCsrTensor* out) { + // transform csr format to coo format, and then use coo kernel + const SparseCooTensor x_coo = CsrToCoo(dev_ctx, x); + SparseCooTensor out_coo; + ReshapeCooKernel(dev_ctx, x_coo, shape, &out_coo); + CooToCsrKernel(dev_ctx, out_coo, out); +} + +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(reshape_coo, + CPU, + ALL_LAYOUT, + phi::sparse::ReshapeCooKernel, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} + +PD_REGISTER_KERNEL(reshape_csr, + CPU, + ALL_LAYOUT, + phi::sparse::ReshapeCsrKernel, + float, + double, + int8_t, + uint8_t, + int16_t, + int, + int64_t, + bool) {} diff --git a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc index d0016099cd759..dcb4399aa2862 100644 --- a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc @@ -329,7 +329,8 @@ PD_REGISTER_KERNEL(csr_to_coo, int8_t, int16_t, int, - int64_t) {} + int64_t, + bool) {} PD_REGISTER_KERNEL(coo_to_csr, CPU, @@ -342,7 +343,8 @@ 
diff --git a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc
index d0016099cd759..dcb4399aa2862 100644
--- a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc
+++ b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc
@@ -329,7 +329,8 @@ PD_REGISTER_KERNEL(csr_to_coo,
                    int8_t,
                    int16_t,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   bool) {}
 
 PD_REGISTER_KERNEL(coo_to_csr,
                    CPU,
@@ -342,7 +343,8 @@ PD_REGISTER_KERNEL(coo_to_csr,
                    int8_t,
                    int16_t,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   bool) {}
 
 PD_REGISTER_KERNEL(dense_to_csr,
                    CPU,
diff --git a/paddle/phi/kernels/sparse/gpu/reshape_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/reshape_grad_kernel.cu
new file mode 100644
index 0000000000000..bfc81676eb804
--- /dev/null
+++ b/paddle/phi/kernels/sparse/gpu/reshape_grad_kernel.cu
@@ -0,0 +1,77 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
+#include "paddle/phi/kernels/sparse/unary_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/sparse/empty_kernel.h"
+#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"
+
+namespace phi {
+namespace sparse {
+
+// Same implementation as paddle/phi/kernels/sparse/cpu/reshape_grad_kernel.cc
+template <typename T, typename Context>
+void ReshapeCooGradKernel(const Context& dev_ctx,
+                          const SparseCooTensor& x,
+                          const SparseCooTensor& dout,
+                          SparseCooTensor* dx) {
+  EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx);
+  phi::IntArray x_shape(phi::vectorize(x.dims()));
+  ReshapeCooKernel<T, Context>(dev_ctx, dout, x_shape, dx);
+}
+
+// Same implementation as paddle/phi/kernels/sparse/cpu/reshape_grad_kernel.cc
+template <typename T, typename Context>
+void ReshapeCsrGradKernel(const Context& dev_ctx,
+                          const SparseCsrTensor& x,
+                          const SparseCsrTensor& dout,
+                          SparseCsrTensor* dx) {
+  EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx);
+  phi::IntArray x_shape(phi::vectorize(x.dims()));
+  ReshapeCsrKernel<T, Context>(dev_ctx, dout, x_shape, dx);
+}
+
+}  // namespace sparse
+}  // namespace phi
+
+PD_REGISTER_KERNEL(reshape_coo_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::ReshapeCooGradKernel,
+                   phi::dtype::float16,
+                   float,
+                   double,
+                   int8_t,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool) {}
+
+PD_REGISTER_KERNEL(reshape_csr_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::ReshapeCsrGradKernel,
+                   phi::dtype::float16,
+                   float,
+                   double,
+                   int8_t,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool) {}
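`ReshapeCsrKernel` is implemented as CSR → COO → reshape → CSR, so the conversion kernels must be registered for every dtype the reshape kernels accept; the `bool` additions above close that gap. A hedged sketch of the user-visible CSR path, assuming a build with this patch:

```python
import paddle

dense = paddle.to_tensor([[0., 5., 0., 0.], [6., 0., 7., 0.]])
sp = dense.to_sparse_csr()

# internally: csr_to_coo -> reshape_coo -> coo_to_csr
out = paddle.incubate.sparse.reshape(sp, [4, 2])
print(out.to_dense().numpy())
# [[0. 5.] [0. 0.] [6. 0.] [7. 0.]]
```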
diff --git a/paddle/phi/kernels/sparse/gpu/reshape_kernel.cu b/paddle/phi/kernels/sparse/gpu/reshape_kernel.cu
new file mode 100644
index 0000000000000..6e3a9842e8c30
--- /dev/null
+++ b/paddle/phi/kernels/sparse/gpu/reshape_kernel.cu
@@ -0,0 +1,165 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/sparse/unary_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
+#include "paddle/phi/kernels/sparse/empty_kernel.h"
+#include "paddle/phi/kernels/sparse/impl/unary_kernel_impl.h"
+
+#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"
+
+namespace phi {
+namespace sparse {
+
+__global__ void ReshapeCooCudaKernel(const int64_t* x_indices_data,
+                                     const int num_x_sparse_part_dims,
+                                     const int num_out_sparse_part_dims,
+                                     const int64_t x_nnz,
+                                     const int64_t* x_sparse_part_strides,
+                                     const int64_t* out_sparse_part_strides,
+                                     int64_t* out_indices_data) {
+  // one grid-stride loop iteration per nonzero element
+  CUDA_KERNEL_LOOP_TYPE(j, x_nnz, int64_t) {
+    int64_t location = 0;
+    for (int i = 0; i < num_x_sparse_part_dims; ++i) {
+      location += x_indices_data[i * x_nnz + j] * x_sparse_part_strides[i];
+    }
+    for (int i = 0; i < num_out_sparse_part_dims; ++i) {
+      out_indices_data[i * x_nnz + j] = location / out_sparse_part_strides[i];
+      location %= out_sparse_part_strides[i];
+    }
+  }
+}
+
+template <typename T, typename Context>
+void ReshapeCooKernel(const Context& dev_ctx,
+                      const SparseCooTensor& x,
+                      const phi::IntArray& shape,
+                      SparseCooTensor* out) {
+  int64_t x_nnz = x.nnz();
+  // Use DDim::reshape to handle -1 and 0 in the argument "shape"
+  std::vector<int> new_shape(shape.GetData().begin(), shape.GetData().end());
+  phi::DDim out_dims = x.dims().reshape(new_shape);
+  // get sparse part dimensions of x and out
+  std::vector<int64_t> x_sparse_part_dims;
+  std::vector<int64_t> out_sparse_part_dims;
+  for (int i = 0; i < x.sparse_dim(); ++i) {
+    x_sparse_part_dims.push_back(x.dims()[i]);
+  }
+  for (int i = 0; i < out_dims.size() - x.dense_dim(); ++i) {
+    out_sparse_part_dims.push_back(out_dims[i]);
+  }
+
+  DenseTensor out_indices = Empty<int64_t, Context>(
+      dev_ctx, {static_cast<int64_t>(out_sparse_part_dims.size()), x_nnz});
+  DenseTensor out_values(x.values());
+  out->SetMember(out_indices, out_values, out_dims, x.coalesced());
+
+  // compute values of out indices
+  const auto* x_indices_data = x.indices().data<int64_t>();
+  auto* out_indices_data = out_indices.data<int64_t>();
+  const phi::DDim& x_sparse_part_strides =
+      phi::stride(phi::make_ddim(x_sparse_part_dims));
+  const phi::DDim& out_sparse_part_strides =
+      phi::stride(phi::make_ddim(out_sparse_part_dims));
+
+  int64_t *destination_x_sparse_part_strides,
+      *destination_out_sparse_part_strides;
+
+  // copy the host-side stride arrays to device memory so the kernel can
+  // read them
+#ifdef PADDLE_WITH_HIP
+  hipMalloc(reinterpret_cast<void**>(&destination_x_sparse_part_strides),
+            sizeof(int64_t) * x_sparse_part_strides.size());
+  hipMemcpy(destination_x_sparse_part_strides,
+            x_sparse_part_strides.Get(),
+            sizeof(int64_t) * x_sparse_part_strides.size(),
+            hipMemcpyHostToDevice);
+  hipMalloc(reinterpret_cast<void**>(&destination_out_sparse_part_strides),
+            sizeof(int64_t) * out_sparse_part_strides.size());
+  hipMemcpy(destination_out_sparse_part_strides,
+            out_sparse_part_strides.Get(),
+            sizeof(int64_t) * out_sparse_part_strides.size(),
+            hipMemcpyHostToDevice);
+#else
+  cudaMalloc(reinterpret_cast<void**>(&destination_x_sparse_part_strides),
+             sizeof(int64_t) * x_sparse_part_strides.size());
+  cudaMemcpy(destination_x_sparse_part_strides,
+             x_sparse_part_strides.Get(),
+             sizeof(int64_t) * x_sparse_part_strides.size(),
+             cudaMemcpyHostToDevice);
+  cudaMalloc(reinterpret_cast<void**>(&destination_out_sparse_part_strides),
+             sizeof(int64_t) * out_sparse_part_strides.size());
+  cudaMemcpy(destination_out_sparse_part_strides,
+             out_sparse_part_strides.Get(),
+             sizeof(int64_t) * out_sparse_part_strides.size(),
+             cudaMemcpyHostToDevice);
+#endif
+
+  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_nnz, 1);
+  ReshapeCooCudaKernel<<<config.block_per_grid.x,
+                         config.thread_per_block.x,
+                         0,
+                         dev_ctx.stream()>>>(
+      x_indices_data,
+      x_sparse_part_dims.size(),
+      out_sparse_part_dims.size(),
+      x_nnz,
+      destination_x_sparse_part_strides,
+      destination_out_sparse_part_strides,
+      out_indices_data);
+}
+
+// Same implementation as paddle/phi/kernels/sparse/cpu/reshape_kernel.cc
+template <typename T, typename Context>
+void ReshapeCsrKernel(const Context& dev_ctx,
+                      const SparseCsrTensor& x,
+                      const phi::IntArray& shape,
+                      SparseCsrTensor* out) {
+  // transform csr format to coo format, and then use coo kernel
+  const SparseCooTensor x_coo = CsrToCoo<T, Context>(dev_ctx, x);
+  SparseCooTensor out_coo;
+  ReshapeCooKernel<T, Context>(dev_ctx, x_coo, shape, &out_coo);
+  CooToCsrKernel<T, Context>(dev_ctx, out_coo, out);
+}
+
+}  // namespace sparse
+}  // namespace phi
+
+PD_REGISTER_KERNEL(reshape_coo,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::ReshapeCooKernel,
+                   phi::dtype::float16,
+                   float,
+                   double,
+                   int8_t,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool) {}
+
+PD_REGISTER_KERNEL(reshape_csr,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::ReshapeCsrKernel,
+                   phi::dtype::float16,
+                   float,
+                   double,
+                   int8_t,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool) {}
diff --git a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu
index c037f6b1b8360..c72a38cd8fd32 100644
--- a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu
@@ -539,7 +539,8 @@ PD_REGISTER_KERNEL(csr_to_coo,
                    int8_t,
                    int16_t,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   bool) {}
 
 PD_REGISTER_KERNEL(coo_to_csr,
                    GPU,
@@ -552,7 +553,8 @@ PD_REGISTER_KERNEL(coo_to_csr,
                    int8_t,
                    int16_t,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   bool) {}
 
 PD_REGISTER_KERNEL(dense_to_csr,
                    GPU,
diff --git a/paddle/phi/kernels/sparse/unary_grad_kernel.h b/paddle/phi/kernels/sparse/unary_grad_kernel.h
index 933e1967e68c3..b446e1b99ed41 100644
--- a/paddle/phi/kernels/sparse/unary_grad_kernel.h
+++ b/paddle/phi/kernels/sparse/unary_grad_kernel.h
@@ -89,5 +89,17 @@ void TransposeCsrGradKernel(const Context& dev_ctx,
                             const std::vector<int>& perm,
                             SparseCsrTensor* dx);
 
+template <typename T, typename Context>
+void ReshapeCooGradKernel(const Context& dev_ctx,
+                          const SparseCooTensor& x,
+                          const SparseCooTensor& dout,
+                          SparseCooTensor* dx);
+
+template <typename T, typename Context>
+void ReshapeCsrGradKernel(const Context& dev_ctx,
+                          const SparseCsrTensor& x,
+                          const SparseCsrTensor& dout,
+                          SparseCsrTensor* dx);
+
 }  // namespace sparse
 }  // namespace phi
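The CSR kernels never reshape compressed rows directly: `ReshapeCsrKernel` expands to COO, remaps the COO indices, and recompresses. The expansion step is the classic `crows`-to-row-indices decompression; an illustrative numpy sketch with hypothetical data, not the Paddle internals:

```python
import numpy as np

# CSR of a 3x4 matrix with 4 nonzeros
crows = np.array([0, 1, 1, 4])   # compressed row pointers
cols = np.array([2, 0, 1, 3])
values = np.array([9., 8., 7., 6.])

# decompress: row i contributes crows[i+1] - crows[i] entries
rows = np.repeat(np.arange(len(crows) - 1), np.diff(crows))
coo_indices = np.stack([rows, cols])   # [sparse_dim, nnz], ready to reshape
print(coo_indices)  # [[0 2 2 2], [2 0 1 3]]
```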
diff --git a/paddle/phi/kernels/sparse/unary_kernel.h b/paddle/phi/kernels/sparse/unary_kernel.h
index fb5cd21ed3921..a81e724d1fe48 100644
--- a/paddle/phi/kernels/sparse/unary_kernel.h
+++ b/paddle/phi/kernels/sparse/unary_kernel.h
@@ -14,6 +14,8 @@
 
 #pragma once
 
+#include "paddle/phi/common/int_array.h"
+#include "paddle/phi/core/ddim.h"
 #include "paddle/phi/core/sparse_coo_tensor.h"
 #include "paddle/phi/core/sparse_csr_tensor.h"
 
@@ -155,5 +157,43 @@ SparseCooTensor ReluCsr(const Context& dev_ctx, const SparseCooTensor& x) {
   return csr;
 }
 
+template <typename T, typename Context>
+void ReshapeCooKernel(const Context& dev_ctx,
+                      const SparseCooTensor& x,
+                      const phi::IntArray& shape,
+                      SparseCooTensor* out);
+
+template <typename T, typename Context>
+void ReshapeCsrKernel(const Context& dev_ctx,
+                      const SparseCsrTensor& x,
+                      const phi::IntArray& shape,
+                      SparseCsrTensor* out);
+
+template <typename T, typename Context>
+SparseCooTensor ReshapeCoo(const Context& dev_ctx,
+                           const SparseCooTensor& x,
+                           const phi::IntArray& shape) {
+  SparseCooTensor coo;
+  ReshapeCooKernel<T, Context>(dev_ctx, x, shape, &coo);
+  return coo;
+}
+
+template <typename T, typename Context>
+SparseCsrTensor ReshapeCsr(const Context& dev_ctx,
+                           const SparseCsrTensor& x,
+                           const phi::IntArray& shape) {
+  PADDLE_ENFORCE_LE(
+      2,
+      shape.size(),
+      phi::errors::InvalidArgument(
+          "size of shape must be 2 or 3 when x is a SparseCsrTensor"));
+  PADDLE_ENFORCE_GE(
+      3,
+      shape.size(),
+      phi::errors::InvalidArgument(
+          "size of shape must be 2 or 3 when x is a SparseCsrTensor"));
+  SparseCsrTensor csr;
+  ReshapeCsrKernel<T, Context>(dev_ctx, x, shape, &csr);
+  return csr;
+}
+
 }  // namespace sparse
 }  // namespace phi
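The two enforces bracket the target rank: `PADDLE_ENFORCE_LE(2, shape.size())` together with `PADDLE_ENFORCE_GE(3, shape.size())` pins `shape.size()` to 2 or 3, matching the ranks the CSR format supports. Note that this check lives in the C++ `ReshapeCsr` convenience wrapper, not in the kernel itself. The equivalent predicate, as a plain Python sketch:

```python
def valid_csr_target_rank(shape):
    # 2 <= len(shape) <= 3, mirroring the PADDLE_ENFORCE_LE/GE pair
    return 2 <= len(shape) <= 3

assert valid_csr_target_rank([6, 4])
assert valid_csr_target_rank([2, 3, 2])
assert not valid_csr_target_rank([2, 3, 2, 2])
```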
+ """ + mask = np.random.randint(0, 2, x_shape) + np_x = np.random.randint(-100, 100, x_shape) * mask + + # check cpu kernel + dense_x = paddle.to_tensor(np_x, place=paddle.CPUPlace()) + dense_x.stop_gradient = False + dense_out = paddle.reshape(dense_x, new_shape) + + if format == "coo": + sp_x = paddle.to_tensor(np_x, + place=paddle.CPUPlace()).to_sparse_coo( + len(x_shape)) + else: + sp_x = paddle.to_tensor(np_x, + place=paddle.CPUPlace()).to_sparse_csr() + sp_x.stop_gradient = False + sp_out = paddle.incubate.sparse.reshape(sp_x, new_shape) + + np.testing.assert_allclose(sp_out.to_dense().numpy(), + dense_out.numpy(), + rtol=1e-05) + + dense_out.backward() + sp_out.backward() + np.testing.assert_allclose(sp_x.grad.to_dense().numpy(), + dense_x.grad.numpy() * + np_x.astype('bool').astype('int'), + rtol=1e-05) + + # check gpu kernel + if paddle.device.is_compiled_with_cuda(): + dense_x = paddle.to_tensor(np_x, place=paddle.CUDAPlace(0)) + dense_x.stop_gradient = False + dense_out = paddle.reshape(dense_x, new_shape) + + if format == "coo": + sp_x = paddle.to_tensor( + np_x, place=paddle.CUDAPlace(0)).to_sparse_coo(len(x_shape)) + else: + sp_x = paddle.to_tensor( + np_x, place=paddle.CUDAPlace(0)).to_sparse_csr() + sp_x.stop_gradient = False + sp_out = paddle.incubate.sparse.reshape(sp_x, new_shape) + + np.testing.assert_allclose(sp_out.to_dense().numpy(), + dense_out.numpy(), + rtol=1e-05) + + dense_out.backward() + sp_out.backward() + np.testing.assert_allclose(sp_x.grad.to_dense().numpy(), + dense_x.grad.numpy() * + np_x.astype('bool').astype('int'), + rtol=1e-05) + + def test_reshape_2d(self): + self.check_result([2, 5], [ + 10, + ], 'coo') + self.check_result([12, 5], [15, 4], 'coo') + + self.check_result([10, 5], [2, 25], 'csr') + self.check_result([9, 8], [18, 4], 'csr') + + def test_reshape_3d(self): + self.check_result([6, 2, 3], [6, 2, 3], 'coo') + self.check_result([6, 2, 3], [2, 3, 3, 2], 'coo') + self.check_result([6, 2, 3], [1, 18, 2], 'coo') + self.check_result([6, 2, 3], [2, 9, 2], 'coo') + self.check_result([6, 2, 3], [2, 1, 18], 'coo') + self.check_result([6, 2, 3], [1, 2, 2, 3, 3], 'coo') + + self.check_result([6, 2, 3], [6, 2, 3], 'csr') + self.check_result([6, 2, 3], [6, 3, 2], 'csr') + self.check_result([6, 2, 3], [2, 6, 3], 'csr') + self.check_result([6, 2, 3], [3, 6, 2], 'csr') + self.check_result([6, 2, 3], [4, 9, 1], 'csr') + self.check_result([6, 2, 3], [12, 1, 3], 'csr') + + def test_reshape_nd(self): + self.check_result([8, 3, 4, 4, 5, 3], [24, 8, 10, 3], 'coo') + self.check_result([3, 4, 4, 5, 7], [1, 12, 2, 5, 14], 'coo') + + def test_reshape_with_zero_or_minus_one_in_new_shape(self): + self.check_result([6, 2, 3], [-1, 0, 3], 'coo') + self.check_result([6, 2, 3], [2, 3, 0, -1], 'coo') + self.check_result([6, 2, 3], [1, -1, 2], 'coo') + self.check_result([6, 2, 3], [-1, 9, 2], 'coo') + self.check_result([6, 2, 3], [2, -1, 18], 'coo') + self.check_result([6, 2, 3], [1, 0, 2, -1, 3], 'coo') + + self.check_result([6, 2, 3], [0, 0, -1], 'csr') + self.check_result([6, 2, 3], [-1, 3, 2], 'csr') + self.check_result([6, 2, 3], [2, -1, 0], 'csr') + self.check_result([6, 2, 3], [-1, 6, 2], 'csr') + self.check_result([6, 2, 3], [-1, 9, 1], 'csr') + self.check_result([6, 2, 3], [-1, 1, 3], 'csr') + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/incubate/sparse/__init__.py b/python/paddle/incubate/sparse/__init__.py index 581310fbbd9b6..8b6866fa4da0b 100644 --- a/python/paddle/incubate/sparse/__init__.py +++ 
diff --git a/python/paddle/incubate/sparse/__init__.py b/python/paddle/incubate/sparse/__init__.py
index 581310fbbd9b6..8b6866fa4da0b 100644
--- a/python/paddle/incubate/sparse/__init__.py
+++ b/python/paddle/incubate/sparse/__init__.py
@@ -35,6 +35,7 @@
 from .unary import rad2deg
 from .unary import expm1
 from .unary import transpose
+from .unary import reshape
 
 from .binary import mv
 from .binary import matmul
@@ -50,35 +51,9 @@
 from . import nn
 
 __all__ = [
-    'sparse_coo_tensor',
-    'sparse_csr_tensor',
-    'sin',
-    'tan',
-    'asin',
-    'atan',
-    'sinh',
-    'tanh',
-    'asinh',
-    'atanh',
-    'sqrt',
-    'square',
-    'log1p',
-    'abs',
-    'pow',
-    'cast',
-    'neg',
-    'deg2rad',
-    'rad2deg',
-    'expm1',
-    'mv',
-    'matmul',
-    'masked_matmul',
-    'addmm',
-    'add',
-    'subtract',
-    'transpose',
-    'multiply',
-    'divide',
-    'coalesce',
-    'is_same_shape',
+    'sparse_coo_tensor', 'sparse_csr_tensor', 'sin', 'tan', 'asin', 'atan',
+    'sinh', 'tanh', 'asinh', 'atanh', 'sqrt', 'square', 'log1p', 'abs', 'pow',
+    'cast', 'neg', 'deg2rad', 'rad2deg', 'expm1', 'mv', 'matmul',
+    'masked_matmul', 'addmm', 'add', 'subtract', 'transpose', 'multiply',
+    'divide', 'coalesce', 'is_same_shape', 'reshape'
 ]
diff --git a/python/paddle/incubate/sparse/unary.py b/python/paddle/incubate/sparse/unary.py
index f12a76fe93f84..eac098b2bfc5b 100644
--- a/python/paddle/incubate/sparse/unary.py
+++ b/python/paddle/incubate/sparse/unary.py
@@ -639,3 +639,60 @@ def expm1(x, name=None):
         out = paddle.incubate.sparse.expm1(sparse_x)
     """
     return _C_ops.sparse_expm1(x)
+
+
+@dygraph_only
+def reshape(x, shape, name=None):
+    """
+    Changes the shape of ``x`` without changing its value; ``x`` must be a SparseCooTensor or a SparseCsrTensor.
+    Currently this function can only reshape the sparse dims of ``x``, but the ``shape`` argument must still specify
+    the full shape of the reshaped tensor.
+
+    Note that if x is a SparseCsrTensor, then len(shape) must be 2 or 3.
+
+    There are some tricks when specifying the target shape.
+
+    - 1. -1 means the value of this dimension is inferred from the total element number of x and the remaining dimensions. Thus one and only one dimension can be set to -1.
+
+    - 2. 0 means the actual dimension value is going to be copied from the corresponding dimension of x. The indices of 0 in the target shape cannot exceed the rank of x.
+
+    Here are some examples to explain it.
+
+    - 1. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape is [6, 8], the reshape operator will transform x into a 2-D tensor with shape [6, 8] and leave x's data unchanged.
+
+    - 2. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape is [2, 3, -1, 2], the reshape operator will transform x into a 4-D tensor with shape [2, 3, 4, 2] and leave x's data unchanged. In this case, one dimension of the target shape is set to -1, and the value of this dimension is inferred from the total element number of x and the remaining dimensions.
+
+    - 3. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape is [-1, 0, 3, 2], the reshape operator will transform x into a 4-D tensor with shape [2, 4, 3, 2] and leave x's data unchanged. In this case, besides -1, 0 means the actual dimension value is going to be copied from the corresponding dimension of x.
+
+    Args:
+        x (Tensor): The input sparse tensor with data type ``float32``, ``float64``, ``int32``, ``int64`` or ``bool``.
+        shape (list|tuple): Define the target shape. At most one dimension of the target shape can be -1.
+            The data type is ``int32``.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor: A reshaped Tensor with the same data type as ``x``.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            x_shape = [6, 2, 3]
+            new_shape = [1, 0, 2, -1, 3]
+            format = "coo"
+
+            dense_x = paddle.randint(-100, 100, x_shape) * paddle.randint(0, 2, x_shape)
+
+            if format == "coo":
+                sp_x = dense_x.to_sparse_coo(len(x_shape))
+            else:
+                sp_x = dense_x.to_sparse_csr()
+            sp_out = paddle.incubate.sparse.reshape(sp_x, new_shape)
+
+            print(sp_out)
+            # the shape of sp_out is [1, 2, 2, 3, 3]
+
+    """
+    return _C_ops.sparse_reshape(x, shape)
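As a complement to the docstring example, a hedged end-to-end sketch showing gradients flowing through the sparse reshape, mirroring the unit test above:

```python
import paddle

dense_x = paddle.to_tensor([[0., 1., 2.], [3., 0., 4.]])
sp_x = dense_x.to_sparse_coo(2)
sp_x.stop_gradient = False

sp_out = paddle.incubate.sparse.reshape(sp_x, [3, 2])
sp_out.backward()

# gradient exists only at the stored nonzero positions of sp_x
print(sp_x.grad.to_dense().numpy())  # [[0. 1. 1.], [1. 0. 1.]]
```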