complete 2D CSR sparse transpose op
zrr1999 committed Sep 8, 2022
1 parent c6c33ff commit f36cc0e
Showing 2 changed files with 124 additions and 16 deletions.
4 changes: 2 additions & 2 deletions paddle/phi/core/sparse_csr_tensor.h
@@ -190,7 +190,7 @@ class SparseCsrTensor : public TensorBase,
[0, 0, 4, 0],
[0, 5, 0, 6]]
dims_ = (4, 4)
non_zero_elements_ = [1, 2, 3, 4, 5 ,6]
non_zero_elements_ = [1, 2, 3, 4, 5, 6]
non_zero_crows_ = [0, 1, 3, 4, 6]
non_zero_cols_ = [1, 0, 3, 2, 1, 3]
*/
@@ -209,7 +209,7 @@ class SparseCsrTensor : public TensorBase,
[0, 0, 4, 0],
[0, 5, 0, 0]]]
dims_ = (2, 4, 4)
non_zero_elements_ = [1, 2, 3, 4, 5 ,6, 1, 2, 3, 4, 5]
non_zero_elements_ = [1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5]
non_zero_crows_ = [0, 1, 3, 4, 6, 0, 1, 2, 4, 5]
non_zero_cols_ = [1, 0, 3, 2, 1, 3, 1, 0, 3, 2, 1]
*/
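As a reading aid, here is a minimal standalone sketch (not part of this diff; plain std::vector and float are chosen only for brevity) that decodes the 2-D example above back into the dense 4x4 matrix:

#include <cstdio>
#include <vector>

int main() {
  // Arrays copied from the comment in sparse_csr_tensor.h above.
  const std::vector<int> crows = {0, 1, 3, 4, 6};    // row i occupies [crows[i], crows[i+1])
  const std::vector<int> cols = {1, 0, 3, 2, 1, 3};  // column index of each stored value
  const std::vector<float> values = {1, 2, 3, 4, 5, 6};
  float dense[4][4] = {};
  for (int i = 0; i < 4; ++i) {
    for (int j = crows[i]; j < crows[i + 1]; ++j) {
      dense[i][cols[j]] = values[j];
    }
  }
  for (const auto& row : dense) {  // prints the matrix shown in the comment
    for (float v : row) std::printf("%g ", v);
    std::printf("\n");
  }
  return 0;
}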
136 changes: 122 additions & 14 deletions paddle/phi/kernels/sparse/impl/unary_kernel_impl.h
@@ -13,7 +13,7 @@
// limitations under the License.

#pragma once

#include <unordered_set>
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
@@ -219,46 +219,154 @@ void TransposeCooKernel(const Context& dev_ctx,
DenseTensor* out_indices = out->mutable_indices();
DenseTensor* out_values = out->mutable_non_zero_elements();

int64_t* x_indices_data = x_indices.data<int64_t>();
const int64_t* x_indices_data = x_indices.data<int64_t>();
int64_t* out_indices_data = out_indices->data<int64_t>();
int64_t x_nnz = x.nnz();
std::vector<int> shape;
for (int64_t i = 0; i < dims.size(); ++i) {
for (int64_t j = 0; j < x_nnz; ++j) {
out_indices_data[j + i * x_nnz] = x_indices_data[j + dims[i] * x_nnz];
}
shape.push_back()
}

DDim out_ddim(x.dims());
out_ddim.transpose(dims);

DDim out_dims(x.dims());
out_dims.transpose(dims);
phi::Copy(dev_ctx, x_values, dev_ctx.GetPlace(), false, out_values);
out->Resize(out_ddim, x.sparse_dim(), x_nnz);
out->Resize(out_dims, x.sparse_dim(), x_nnz);
}
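// How TransposeCooKernel above works, on a small example (illustration only,
// not part of the commit): a COO tensor stores its indices as an ndim x nnz
// matrix, so a transpose just permutes the rows of that matrix by `dims` and
// copies the values unchanged. For a 2-D tensor with
//   indices = [[0, 1, 1],   // row of each non-zero
//              [2, 0, 3]]   // column of each non-zero
// and dims = {1, 0}, the transposed indices are
//   indices = [[2, 0, 3],
//              [0, 1, 1]]
// while out_dims becomes x.dims() permuted by `dims`.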

template <typename T, typename Context>
void TransposeCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const std::vector<int>& dims,
SparseCsrTensor* out) {
out->set_dims(x.dims());

int n_dim = dims.size();
DDim out_dims(x.dims());
out_dims.transpose(dims);
out->set_dims(out_dims);
out->Resize(out_dims, x.nnz());
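// out_dims is x.dims() permuted according to `dims`; the transposed tensor has
// the same number of non-zeros as x, so Resize is given x.nnz().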
const DenseTensor& x_crows = x.crows();
const DenseTensor& x_cols = x.cols();
const DenseTensor& x_values = x.non_zero_elements();
DenseTensor* out_crows = out->mutable_crows();
DenseTensor* out_cols = out->mutable_cols();
DenseTensor* out_values = out->mutable_non_zero_elements();

*out_crows = x_crows;
*out_cols = x_cols;
// return a copy of x
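// (dims == {0, 1} in 2-D, or {0, 1, 2} in 3-D, is the identity permutation)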
if (dims[0] == 0 && dims[1] == 1 && (n_dim == 2 || dims[2] == 2)) {
*out_crows = x_crows;
*out_cols = x_cols;
phi::Copy(dev_ctx, x_values, dev_ctx.GetPlace(), false, out_values);
return;
}

int* out_crows_data = out_crows->data<int>();
int* out_cols_data = out_cols->data<int>();
T* out_values_data = out_values->data<T>();
const int* x_crows_data = x_crows.data<int>();
const int* x_cols_data = x_cols.data<int>();
const T* x_values_data = x_values.data<T>();

phi::Copy(dev_ctx, x_values, dev_ctx.GetPlace(), false, out_values);
out->Resize(phi::make_ddim(shape), x_values.dims()[0]);
if (n_dim == 2) { // dims == {1, 0}
// compute out_crows_data by x_cols_data
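// (counting-sort construction: histogram the column indices of x to get the
// number of entries per output row, then prefix-sum the counts so that
// out_crows_data[i] is the offset of the first entry of output row i; a
// self-contained sketch of this technique follows after the kernel below)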
for (int i = 0; i < out_dims[0]; ++i) {
out_crows_data[i] = 0;
}
out_crows_data[out_dims[0]] = x.nnz();
for (int i = 0; i < x.nnz(); ++i) {
int j = x_cols_data[i];
out_crows_data[j + 1]++;
}
for (int i = 1; i < out_dims[0]; ++i) {
out_crows_data[i] += out_crows_data[i - 1];
}
// compute out_cols_data and out_values_data by out_crows_data and x
std::unordered_set<int> cols_ptr;
for (int i = 0; i < x.dims()[0]; ++i) {
int start = x_crows_data[i];
int end = x_crows_data[i + 1];
for (int j = start; j < end; ++j) {
int jj = x_cols_data[j];
int jjj = out_crows_data[jj];
int jjj_ptr = jjj + cols_ptr.count();
out_cols_data[jjj_ptr] = i;
out_values_data[jjj_ptr] = x_values_data[j];
cols_ptr.insert(jjj);
}
}
} else { // n_dim == 3
for (int k = 0; k < out_dims[0]; ++k) {
if (dims[0] == 0) { // dims == {0, 2, 1}
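// dims == {0, 2, 1} keeps the batch dimension in place, so each of the
// out_dims[0] batches is an independent 2-D transpose; the per-batch input and
// output pointers are advanced at the end of each iteration of the k loop.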
int out_n_rows = out_dims[1];
// compute out_crows_data by x_cols_data
for (int i = 0; i < out_n_rows; ++i) {
out_crows_data[i] = 0;
}
out_crows_data[out_n_rows] = x_crows_data[x.dims()[1]];
for (int i = 0; i < out_crows_data[out_n_rows]; ++i) {
int j = x_cols_data[i];
out_crows_data[j + 1]++;
}
for (int i = 1; i < out_n_rows; ++i) {
out_crows_data[i] += out_crows_data[i - 1];
}
// compute out_cols_data and out_values_data by out_crows_data and x
std::unordered_set<int> cols_ptr;
for (int i = 0; i < x.dims()[1]; ++i) {
int start = x_crows_data[i];
int end = x_crows_data[i + 1];
for (int j = start; j < end; ++j) {
int jj = x_cols_data[j];
int jjj = out_crows_data[jj];
int jjj_ptr = jjj + cols_ptr.count();
out_cols_data[jjj_ptr] = i;
out_values_data[jjj_ptr] = x_values_data[j];
cols_ptr.insert(jjj);
}
}
// x offset
x_crows_data += x.dims()[1] + 1;
x_cols_data += x_crows_data[x.dims()[1]];
x_values_data += x_crows_data[x.dims()[1]];
} else if (dims[0] == 1) {
int out_n_rows = out_dims[1];
// compute out_crows_data by x_cols_data
for (int i = 0; i < out_n_rows; ++i) {
out_crows_data[i] = 0;
}
// out_crows_data[out_n_rows] = x_crows_data[x.dims()[1]];
int x_cols_offset = 0;
int out_cols_offset = 0;
for (int i = 0; i < x.dims()[0]; ++i) {
int x_crows_index = i * (x.dims()[1] + 1);
int start = x_crows_data[x_crows_index];
int end = x_crows_data[x_crows_index + 1];
out_crows_data[i] = end - start;
for (int j = start; j < end; ++j) {
out_cols_data[j - start] = x_cols_data[x_cols_offset + j];
out_values_data[j - start] = x_values_data[x_cols_offset + j];
x_cols_offset += x_crows_data[x_crows_index + x.dims()[1]];
out_cols_offset += out_crows_data[... + out_dims[1]];
}
}

for (int i = 0; i < out_crows_data[out_n_rows]; ++i) {
int j = x_cols_data[i];
out_crows_data[j + 1]++;
}
for (int i = 1; i < out_n_rows; ++i) {
out_crows_data[i] += out_crows_data[i - 1];
}

// x offset
x_crows_data += 1;
} else {
}
// out offset
out_crows_data += out_dims[1] + 1;
out_cols_data += x_crows_data[out_dims[1]];
out_values_data += x_crows_data[out_dims[1]];
}
}
}

} // namespace sparse
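The 2-D branch of TransposeCsrKernel is a counting sort over column indices. Below is a minimal self-contained sketch of that technique (histogram, prefix sum, scatter) on plain std::vector data; the helper name TransposeCsr2D and the int/float types are illustrative choices, not Paddle API:

#include <cstdio>
#include <vector>

// Sketch only: transpose an n_rows x n_cols CSR matrix.
void TransposeCsr2D(int n_rows, int n_cols,
                    const std::vector<int>& crows,
                    const std::vector<int>& cols,
                    const std::vector<float>& values,
                    std::vector<int>* out_crows,
                    std::vector<int>* out_cols,
                    std::vector<float>* out_values) {
  const int nnz = static_cast<int>(values.size());
  out_crows->assign(n_cols + 1, 0);
  out_cols->assign(nnz, 0);
  out_values->assign(nnz, 0.f);
  // 1) histogram: number of non-zeros in each output row (= input column)
  for (int i = 0; i < nnz; ++i) (*out_crows)[cols[i] + 1]++;
  // 2) prefix sum: (*out_crows)[j] becomes the first slot of output row j
  for (int j = 1; j <= n_cols; ++j) (*out_crows)[j] += (*out_crows)[j - 1];
  // 3) scatter: walk the input rows in order and drop each entry into its
  //    output row, tracking how far each output row has been filled
  std::vector<int> fill(n_cols, 0);
  for (int i = 0; i < n_rows; ++i) {
    for (int j = crows[i]; j < crows[i + 1]; ++j) {
      const int pos = (*out_crows)[cols[j]] + fill[cols[j]]++;
      (*out_cols)[pos] = i;
      (*out_values)[pos] = values[j];
    }
  }
}

int main() {
  // The 4x4 example from sparse_csr_tensor.h above.
  std::vector<int> out_crows, out_cols;
  std::vector<float> out_values;
  TransposeCsr2D(4, 4, {0, 1, 3, 4, 6}, {1, 0, 3, 2, 1, 3}, {1, 2, 3, 4, 5, 6},
                 &out_crows, &out_cols, &out_values);
  for (int c : out_crows) std::printf("%d ", c);      // prints: 0 1 3 4 6
  std::printf("\n");
  for (int c : out_cols) std::printf("%d ", c);       // prints: 1 0 3 2 1 3
  std::printf("\n");
  for (float v : out_values) std::printf("%g ", v);   // prints: 2 1 5 4 3 6
  std::printf("\n");
  return 0;
}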
