
Commit

fix bugs
zrr1999 committed Sep 25, 2022
1 parent 8bf4799 commit 535cb46
Showing 6 changed files with 114 additions and 84 deletions.
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/sparse_backward.yaml
@@ -393,8 +393,8 @@
    func : TransposeGradInferMeta
    param : [out_grad, perm]
  kernel :
-   func : transpose_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
-          transpose_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
+   func : transpose_coo_grad {sparse_coo -> sparse_coo},
+          transpose_csr_grad {sparse_csr -> sparse_csr}

- backward_op : values_grad
forward : values_coo(Tensor x) -> Tensor(out)
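The signature change above removes the redundant forward input from the sparse transpose grad kernels: the gradient of a transpose is just the incoming gradient transposed by the inverse permutation, so only out_grad and perm are needed. A minimal standalone sketch of that inversion (illustrative names, not the Paddle API; it mirrors the reversed_axis loop deleted from unary.cc below):

#include <vector>

// Sketch: invert a permutation so a grad kernel can reuse the forward
// transpose. InvertPermutation is a hypothetical helper, not Paddle code.
std::vector<int> InvertPermutation(const std::vector<int>& perm) {
  std::vector<int> inv(perm.size());
  for (size_t i = 0; i < perm.size(); ++i) {
    inv[perm[i]] = static_cast<int>(i);  // output axis i came from input axis perm[i]
  }
  return inv;  // e.g. perm = {1, 2, 0} gives inv = {2, 0, 1}
}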
64 changes: 0 additions & 64 deletions paddle/phi/infermeta/sparse/unary.cc
@@ -32,69 +32,5 @@ void ValuesInferMeta(const MetaTensor& x, MetaTensor* out) {
out->set_layout(x.layout());
}

void TransposeInferMeta(const MetaTensor& x,
const std::vector<int>& axis,
MetaTensor* out) {
auto x_dims = x.dims();
size_t x_rank = x_dims.size();
size_t axis_size = axis.size();

PADDLE_ENFORCE_EQ(
x_rank,
axis_size,
errors::InvalidArgument("The input tensor's dimension "
"should be equal to the axis's size. "
"But received input tensor's dimension is %d, "
"axis's size is %d",
x_rank,
axis_size));

std::vector<int> count(axis_size, 0);
for (size_t i = 0; i < axis_size; i++) {
PADDLE_ENFORCE_GE(
axis[i],
0,
errors::InvalidArgument("The axis should be greater than or equal to 0."
"But received %d of axis[%d]",
axis[i],
i));

PADDLE_ENFORCE_EQ(
axis[i] < static_cast<int>(axis_size) && ++count[axis[i]] == 1,
true,
errors::InvalidArgument(
"Each element of Attribute axis should "
"be a unique value range from 0 to (dims - 1), "
"where the dims is the axis's size, "
"unique value means this axis value can appear only once. "
"But received axis[%d] is %d, axis_size is %d, "
"count[axis[%d]] is %d",
i,
axis[i],
axis_size,
i,
count[axis[i]]));
}

phi::DDim out_dims(x_dims);
for (size_t i = 0; i < axis_size; ++i) {
out_dims[i] = x_dims[axis[i]];
}

out->set_dims(out_dims);
out->set_dtype(x.dtype());
}

void TransposeGradInferMeta(const MetaTensor& x,
const std::vector<int>& axis,
MetaTensor* out) {
std::vector<int> reversed_axis(axis);
for (size_t i = 0; i < axis.size(); i++) {
reversed_axis[axis[i]] = i;
}

TransposeInferMeta(x, reversed_axis, out);
}

} // namespace sparse
} // namespace phi
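For reference, the deleted TransposeInferMeta validated the permutation and then permuted the dims entrywise. A self-contained sketch of that shape rule (hypothetical helper, kept here only to document the shapes the tests below expect):

#include <vector>

// Shape rule from the deleted TransposeInferMeta: out_dims[i] = x_dims[axis[i]].
std::vector<int> TransposedDims(const std::vector<int>& x_dims,
                                const std::vector<int>& axis) {
  std::vector<int> out_dims(axis.size());
  for (size_t i = 0; i < axis.size(); ++i) {
    out_dims[i] = x_dims[axis[i]];
  }
  return out_dims;  // x_dims {3, 2, 2} with axis {2, 1, 0} yields {2, 2, 3}
}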
9 changes: 2 additions & 7 deletions paddle/phi/kernels/sparse/cpu/transpose_kernel.cc
@@ -16,11 +16,10 @@

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"
#include "paddle/phi/kernels/sparse/impl/unary_kernel_impl.h"

namespace phi {
namespace sparse {
@@ -54,17 +53,14 @@ void TransposeCsrKernel(const Context& dev_ctx,
const std::vector<int>& perm,
SparseCsrTensor* out) {
unsigned int n_dim = perm.size();

const DenseTensor& x_crows = x.crows();
const DenseTensor& x_cols = x.cols();
-const DenseTensor& x_values = x.non_zero_elements();

+const DenseTensor& x_values = x.values();
// return a copy of x
if (perm[0] == 0 && perm[1] == 1 && (n_dim == 2 || perm[2] == 2)) {
out->SetMember(x_crows, x_cols, x_values, x.dims());
return;
}

// create out sparse tensor
DDim out_dims = x.dims().transpose(perm);
DenseTensor out_crows;
@@ -77,7 +73,6 @@ void TransposeCsrKernel(const Context& dev_ctx,
DenseTensor out_cols = EmptyLike<int64_t, Context>(dev_ctx, x.cols());
DenseTensor out_values = EmptyLike<T, Context>(dev_ctx, x.values());
out->SetMember(out_crows, out_cols, out_values, out_dims);

// transpose by two stages
if (perm[0] == 1 && perm[1] == 2) { // perm == {1, 2, 0}
SparseCsrTensor temp;
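Two details of the CSR kernel above are easy to miss. First, the early-return branch recognizes the identity permutation ({0, 1} in 2-D, {0, 1, 2} in 3-D) and returns a tensor sharing x's crows/cols/values buffers via SetMember. A generalized form of that check (a sketch, not the committed code):

#include <vector>

// Sketch: perm is the identity iff perm[i] == i for all i; the kernel's
// hand-written condition is this check specialized to ranks 2 and 3.
bool IsIdentityPerm(const std::vector<int>& perm) {
  for (size_t i = 0; i < perm.size(); ++i) {
    if (perm[i] != static_cast<int>(i)) return false;
  }
  return true;
}

Second, the "two stages" comment refers to decomposing a 3-D permutation such as {1, 2, 0} into two simpler transposes applied through a temporary SparseCsrTensor.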
1 change: 0 additions & 1 deletion paddle/phi/kernels/sparse/gpu/transpose_grad_kernel.cu
@@ -42,7 +42,6 @@ void TransposeCooGradKernel(const Context& dev_ctx,

template <typename T, typename Context>
void TransposeCsrGradKernel(const Context& dev_ctx,
-const SparseCsrTensor& x,
const SparseCsrTensor& dout,
const std::vector<int>& perm,
SparseCsrTensor* dx) {
3 changes: 2 additions & 1 deletion paddle/phi/kernels/sparse/gpu/transpose_kernel.cu
@@ -16,9 +16,10 @@

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_kernel_impl.h"

namespace phi {
namespace sparse {

117 changes: 108 additions & 9 deletions paddle/phi/tests/kernels/test_sparse_transpose_dev_api.cc
@@ -22,16 +22,16 @@ limitations under the License. */
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"
#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/kernels/transpose_grad_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"

namespace phi {
namespace tests {

-TEST(DEV_API, sparse_transpose) {
+TEST(DEV_API, sparse_transpose_coo) {
std::vector<float> data = {0, -1, 0, 2, 0, 0, -3, 0, 4, 5, 0, 0};
phi::CPUContext dev_ctx_cpu;
dev_ctx_cpu.SetAllocator(
@@ -45,21 +45,120 @@

DenseTensor dense_x = phi::Empty(
dev_ctx_cpu,
-DenseTensorMeta(DataType::FLOAT32, {3, 2, 2}, DataLayout::NCHW));
+DenseTensorMeta(
+DataType::FLOAT32, phi::make_ddim({3, 2, 2}), DataLayout::NCHW));
memcpy(dense_x.data<float>(), data.data(), data.size() * sizeof(float));
auto sparse_coo = sparse::DenseToCoo<float>(dev_ctx_cpu, dense_x, 3);
auto sparse_out =
sparse::TransposeCoo<float>(dev_ctx_cpu, sparse_coo, {2, 1, 0});
DenseTensor dense_out = phi::Empty(
dev_ctx_cpu,
-DenseTensorMeta(DataType::FLOAT32, {2, 2, 3}, DataLayout::NCHW));
+DenseTensorMeta(
+DataType::FLOAT32, phi::make_ddim({2, 2, 3}), DataLayout::NCHW));
TransposeKernel<float>(dev_ctx_cpu, dense_x, {2, 1, 0}, &dense_out);

// backward
DenseTensor dense_grad_x = phi::EmptyLike<float>(dev_ctx_cpu, dense_out);
TransposeGradKernel<float>(dev_ctx_cpu, dense_out, {2, 1, 0}, &dense_grad_x);
SparseCooTensor sparse_grad_x;
sparse::EmptyLikeCooKernel<float>(dev_ctx_cpu, sparse_coo, &sparse_grad_x);

SparseCooTensor sparse_out_grad(
sparse_coo.indices(), sparse_coo.values(), {2, 2, 3});
sparse::TransposeCooGradKernel<float>(
dev_ctx_cpu, sparse_out_grad, {2, 1, 0}, &sparse_grad_x);
}
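A note on the backward check above: it reuses sparse_coo's indices and values as a stand-in out_grad with the transposed shape {2, 2, 3}, then calls TransposeCooGradKernel with the same perm — valid here because {2, 1, 0} is its own inverse.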

TEST(DEV_API, sparse_transpose_csr_case1) {
std::vector<float> data = {0, -1, 0, 2, 0, 0, -3, 0, 4, 5, 0, 0};
phi::CPUContext dev_ctx_cpu;
dev_ctx_cpu.SetAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx_cpu.SetHostAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());

DenseTensor dense_x = phi::Empty(
dev_ctx_cpu,
DenseTensorMeta(
DataType::FLOAT32, phi::make_ddim({3, 2, 2}), DataLayout::NCHW));
memcpy(dense_x.data<float>(), data.data(), data.size() * sizeof(float));
auto sparse_csr = sparse::DenseToCsr<float>(dev_ctx_cpu, dense_x);

auto sparse_out =
sparse::TransposeCsr<float>(dev_ctx_cpu, sparse_csr, {2, 1, 0});
DenseTensor dense_out = phi::Empty(
dev_ctx_cpu,
DenseTensorMeta(
DataType::FLOAT32, phi::make_ddim({2, 2, 3}), DataLayout::NCHW));
TransposeKernel<float>(dev_ctx_cpu, dense_x, {2, 1, 0}, &dense_out);

for (int i = 0; i < dense_out.numel(); ++i) {
ASSERT_EQ(
dense_out.data<float>()[i],
sparse::CsrToDense<float>(dev_ctx_cpu, sparse_out).data<float>()[i]);
}
// backward
DenseTensor dense_grad_x = phi::EmptyLike<float>(dev_ctx_cpu, dense_out);
TransposeGradKernel<float>(dev_ctx_cpu, dense_out, {2, 1, 0}, &dense_grad_x);
SparseCsrTensor sparse_grad_x;
sparse::EmptyLikeCsrKernel<float>(dev_ctx_cpu, sparse_csr, &sparse_grad_x);
sparse::TransposeCsrGradKernel<float>(
dev_ctx_cpu, sparse_out, {2, 1, 0}, &sparse_grad_x);
}

TEST(DEV_API, sparse_transpose_csr_case2) {
std::vector<float> data = {0, -1, 0, 2, 0, 0, -3, 0, 4, 5, 0, 0};
phi::CPUContext dev_ctx_cpu;
dev_ctx_cpu.SetAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx_cpu.SetHostAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());

DenseTensor dense_x = phi::Empty(
dev_ctx_cpu,
DenseTensorMeta(
DataType::FLOAT32, phi::make_ddim({3, 2, 2}), DataLayout::NCHW));
memcpy(dense_x.data<float>(), data.data(), data.size() * sizeof(float));
auto sparse_csr = sparse::DenseToCsr<float>(dev_ctx_cpu, dense_x);

auto sparse_out =
sparse::TransposeCsr<float>(dev_ctx_cpu, sparse_csr, {1, 2, 0});
DenseTensor dense_out = phi::Empty(
dev_ctx_cpu,
DenseTensorMeta(
DataType::FLOAT32, phi::make_ddim({2, 2, 3}), DataLayout::NCHW));
TransposeKernel<float>(dev_ctx_cpu, dense_x, {1, 2, 0}, &dense_out);
}

TEST(DEV_API, sparse_transpose_csr_case3) {
std::vector<float> data = {0, -1, 0, 2, 0, 0, -3, 0, 4, 5, 0, 0};
phi::CPUContext dev_ctx_cpu;
dev_ctx_cpu.SetAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx_cpu.SetHostAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());

DenseTensor dense_x = phi::Empty(
dev_ctx_cpu,
DenseTensorMeta(
DataType::FLOAT32, phi::make_ddim({3, 4}), DataLayout::NCHW));
memcpy(dense_x.data<float>(), data.data(), data.size() * sizeof(float));
auto sparse_csr = sparse::DenseToCsr<float>(dev_ctx_cpu, dense_x);

auto sparse_out =
sparse::TransposeCsr<float>(dev_ctx_cpu, sparse_csr, {1, 0});
DenseTensor dense_out = phi::Empty(
dev_ctx_cpu,
DenseTensorMeta(
DataType::FLOAT32, phi::make_ddim({4, 3}), DataLayout::NCHW));
TransposeKernel<float>(dev_ctx_cpu, dense_x, {1, 0}, &dense_out);
}
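Cases 2 and 3 exercise the two-stage permutation {1, 2, 0} and the 2-D transpose {1, 0}, but unlike case 1 they stop after computing both results. A comparison loop in the style of case 1 could close that gap (a sketch reusing the CsrToDense converter from case 1's assertion):

// Sketch: verify the sparse result against the dense reference, as case 1 does.
DenseTensor dense_from_sparse =
    sparse::CsrToDense<float>(dev_ctx_cpu, sparse_out);
for (int i = 0; i < dense_out.numel(); ++i) {
  ASSERT_EQ(dense_out.data<float>()[i], dense_from_sparse.data<float>()[i]);
}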

} // namespace tests