Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Sparse】add new API/OP(csr->csr) of SparseTensor softmax #43475

Merged
merged 2 commits into from Jun 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
91 changes: 91 additions & 0 deletions paddle/phi/kernels/sparse/cpu/softmax_grad_kernel.cc
@@ -0,0 +1,91 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/sparse/softmax_grad_kernel.h"

#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/cpu_vec.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"

namespace plt = paddle::platform;

namespace phi {
namespace sparse {

template <typename T, typename Context>
void SoftmaxCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& out,
const SparseCsrTensor& dout,
int axis,
SparseCsrTensor* dx) {
PADDLE_ENFORCE_EQ(axis,
-1,
phi::errors::Unimplemented(
"SparseCsrTensor only support axis=-1 for softmax, "
"which is faster when reading data by row (axis=-1)"));
EmptyLikeCsrKernel<T, Context>(dev_ctx, dout, dx);

auto out_dim = out.dims();
int rows = 1;
for (int i = 0; i < out_dim.size() - 1; ++i) {
rows *= out_dim[i];
}

const DenseTensor& out_crows = out.non_zero_crows();
const DenseTensor& out_values = out.non_zero_elements();
const DenseTensor& dout_values = dout.non_zero_elements();
DenseTensor* dx_values = dx->mutable_non_zero_elements();

int row_first = 0;
int row_nnz = 0;
const T* out_data = out_values.data<T>();
const T* dout_data = dout_values.data<T>();
T* dx_data = dx_values->data<T>();

// dx = (dout - sum(dout * out)) * out
PD_VISIT_INTEGRAL_TYPES(
out.non_zero_crows().dtype(), "SoftmaxCsrGradKernel", ([&] {
const data_t* out_crows_data = out_crows.data<data_t>();
for (int i = 0; i < rows; ++i) {
row_first = static_cast<int>(out_crows_data[i]);
row_nnz = static_cast<int>(out_crows_data[i + 1] - out_crows_data[i]);

out_data = out_data + row_first;
dout_data = dout_data + row_first;
dx_data = dx_data + row_first;

T sum = 0;
phi::funcs::vec_mul_reduce<T, plt::avx>(
row_nnz, dout_data, out_data, &sum);
phi::funcs::vec_add_bias<T, plt::avx>(
row_nnz, static_cast<T>(-1) * sum, dout_data, dx_data);
phi::funcs::vec_mul<T, plt::avx>(row_nnz, dx_data, out_data, dx_data);
}
}));
}

} // namespace sparse
} // namespace phi

PD_REGISTER_KERNEL(softmax_csr_grad,
CPU,
ALL_LAYOUT,
phi::sparse::SoftmaxCsrGradKernel,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}
92 changes: 92 additions & 0 deletions paddle/phi/kernels/sparse/cpu/softmax_kernel.cc
@@ -0,0 +1,92 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/sparse/softmax_kernel.h"

#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/cpu_vec.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"

namespace plt = paddle::platform;

namespace phi {
namespace sparse {

template <typename T, typename Context>
void SoftmaxCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
int axis,
SparseCsrTensor* out) {
PADDLE_ENFORCE_EQ(axis,
-1,
phi::errors::Unimplemented(
"SparseCsrTensor only support axis=-1 for softmax, "
"which is faster when reading data by row (axis=-1)"));
EmptyLikeCsrKernel<T, Context>(dev_ctx, x, out);

auto x_dim = x.dims();
int row_number = 1;
for (int i = 0; i < x_dim.size() - 1; ++i) {
row_number *= x_dim[i];
}

const DenseTensor& x_crows = x.non_zero_crows();
const DenseTensor& x_values = x.non_zero_elements();
DenseTensor* out_values = out->mutable_non_zero_elements();

int row_first = 0;
int row_nnz = 0;
T row_max_val = 0;
const T* x_data = x_values.data<T>();
T* out_data = out_values->data<T>();

// out = exp(x-x_max) / sum( exp(x-x_max ))
PD_VISIT_INTEGRAL_TYPES(
x.non_zero_crows().dtype(), "CsrSoftmaxKernel", ([&] {
const data_t* x_crows_data = x_crows.data<data_t>();
for (int i = 0; i < row_number; ++i) {
row_first = static_cast<int>(x_crows_data[i]);
row_nnz = static_cast<int>(x_crows_data[i + 1] - x_crows_data[i]);

x_data = x_data + row_first;
out_data = out_data + row_first;

row_max_val = *std::max_element(x_data, x_data + row_nnz);
phi::funcs::vec_add_bias<T, plt::avx>(
row_nnz, static_cast<T>(-1) * row_max_val, x_data, out_data);

phi::funcs::vec_exp<T>(row_nnz, out_data, out_data);

T sum = 0;
phi::funcs::vec_sum<T, plt::avx>(row_nnz, out_data, &sum);
phi::funcs::vec_scal<T, plt::avx>(
row_nnz, static_cast<T>(1) / sum, out_data, out_data);
}
}));
}

} // namespace sparse
} // namespace phi

PD_REGISTER_KERNEL(softmax_csr,
CPU,
ALL_LAYOUT,
phi::sparse::SoftmaxCsrKernel,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}
78 changes: 78 additions & 0 deletions paddle/phi/kernels/sparse/empty_kernel.cc
@@ -0,0 +1,78 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/sparse/empty_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void EmptyLikeCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
SparseCsrTensor* out) {
const DenseTensor& x_crows = x.non_zero_crows();
const DenseTensor& x_cols = x.non_zero_cols();
const DenseTensor& x_values = x.non_zero_elements();

DenseTensor* out_crows = out->mutable_non_zero_crows();
DenseTensor* out_cols = out->mutable_non_zero_cols();
DenseTensor* out_values = out->mutable_non_zero_elements();

out->set_dims(x.dims());
phi::Copy(dev_ctx, x_crows, dev_ctx.GetPlace(), false, out_crows);
phi::Copy(dev_ctx, x_cols, dev_ctx.GetPlace(), false, out_cols);

out_values->Resize(x_values.dims());
dev_ctx.template Alloc<T>(out_values);
}

} // namespace sparse
} // namespace phi

PD_REGISTER_KERNEL(empty_like_csr,
CPU,
ALL_LAYOUT,
phi::sparse::EmptyLikeCsrKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(empty_like_csr,
GPU,
ALL_LAYOUT,
phi::sparse::EmptyLikeCsrKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}
#endif
28 changes: 28 additions & 0 deletions paddle/phi/kernels/sparse/empty_kernel.h
@@ -0,0 +1,28 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/core/sparse_csr_tensor.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void EmptyLikeCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
SparseCsrTensor* out);

} // namespace sparse
} // namespace phi
102 changes: 102 additions & 0 deletions paddle/phi/kernels/sparse/gpu/softmax_grad_kernel.cu
@@ -0,0 +1,102 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/softmax_grad_kernel.h"

namespace phi {
namespace sparse {

template <typename T, typename IntT = int>
__global__ void SoftmaxGradGpuKernel(const IntT* out_crows,
const T* out_values,
const T* dout_values,
T* dx_values,
int row_number) {
// dx = (dout - sum(dout * out)) * out
int row = blockIdx.x * blockDim.y + threadIdx.y;
int non_zero_idx = threadIdx.x;
if (row >= row_number) return;
int row_first = static_cast<int>(out_crows[row]);
int row_nnz = static_cast<int>(out_crows[row + 1] - out_crows[row]);
if (row_nnz == 0) return;

int kIteration = (row_nnz + warpSize - 1) / warpSize;

T mul_result = 0;
for (int i = 0; i < kIteration; ++i) {
int idx = non_zero_idx + i * warpSize;
if (idx >= row_nnz) break;

mul_result += out_values[row_first + idx] * dout_values[row_first + idx];
}
T sum = phi::funcs::warpReduceSum<T>(mul_result, 0xFFFFFFFF);

for (int i = 0; i < kIteration; ++i) {
int idx = non_zero_idx + i * warpSize;
if (idx >= row_nnz) break;

dx_values[row_first + idx] =
(dout_values[row_first + idx] - sum) * out_values[row_first + idx];
}
}

template <typename T, typename Context>
void SoftmaxCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& out,
const SparseCsrTensor& dout,
int axis,
SparseCsrTensor* dx) {
PADDLE_ENFORCE_EQ(axis,
-1,
phi::errors::Unimplemented(
"SparseCsrTensor only support axis=-1 for softmax, "
"which is faster when reading data by row (axis=-1)"));
EmptyLikeCsrKernel<T, Context>(dev_ctx, dout, dx);

auto out_dim = out.dims();
int row_number = 1;
for (int i = 0; i < out_dim.size() - 1; ++i) {
row_number *= out_dim[i];
}

dim3 grid((row_number + 3) / 4);
dim3 block(32, 4);

PD_VISIT_INTEGRAL_TYPES(
out.non_zero_crows().dtype(), "SoftmaxCsrGradKernel", ([&] {
SoftmaxGradGpuKernel<T, data_t><<<grid, block, 0, dev_ctx.stream()>>>(
out.non_zero_crows().data<data_t>(),
out.non_zero_elements().data<T>(),
dout.non_zero_elements().data<T>(),
dx->mutable_non_zero_elements()->data<T>(),
row_number);
}));
}

} // namespace sparse
} // namespace phi

PD_REGISTER_KERNEL(softmax_csr_grad,
GPU,
ALL_LAYOUT,
phi::sparse::SoftmaxCsrGradKernel,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}