// transpose_grad_kernel.cc (78 lines, 70 loc, 2.67 KB)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"

#include <stdexcept>
#include <string>
#include <vector>

#include "paddle/phi/kernels/sparse/unary_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"
namespace phi {
namespace sparse {
// Returns the inverse of the permutation `perm`: grad_perm[perm[i]] == i.
//
// Transposing by the returned permutation undoes a transpose by `perm`, which
// is exactly what the backward pass of a transpose needs.
//
// Negative axes are accepted in the range [-rank, -1] and normalized by adding
// rank (rank == perm.size()).
//
// Throws std::invalid_argument if an axis is outside [-rank, rank) or appears
// more than once. Without these checks, such input would write out of bounds
// or silently produce a wrong permutation.
std::vector<int> get_cpu_grad_perm(const std::vector<int>& perm) {
  const int rank = static_cast<int>(perm.size());
  // -1 marks an output slot that has not been assigned yet; it lets us detect
  // repeated axes in `perm`.
  std::vector<int> grad_perm(perm.size(), -1);
  for (int i = 0; i < rank; ++i) {
    int axis = perm[i];
    if (axis < 0) {
      axis += rank;
    }
    if (axis < 0 || axis >= rank) {
      throw std::invalid_argument(
          "transpose grad: perm[" + std::to_string(i) + "] = " +
          std::to_string(perm[i]) + " is out of range for rank " +
          std::to_string(rank));
    }
    if (grad_perm[axis] != -1) {
      throw std::invalid_argument(
          "transpose grad: axis " + std::to_string(axis) +
          " appears more than once in perm");
    }
    grad_perm[axis] = i;
  }
  return grad_perm;
}
// Backward pass of the sparse COO transpose on CPU.
//
// get_cpu_grad_perm yields the inverse of `perm`. Transposing the incoming
// gradient `dout` by that inverse with the forward kernel undoes the forward
// transpose, which gives the gradient w.r.t. the original input.
//
// dev_ctx: device context, forwarded unchanged to TransposeCooKernel.
// dout:    gradient w.r.t. the forward transpose's output.
// perm:    the permutation that was used by the forward transpose.
// dx:      output tensor, filled by TransposeCooKernel.
template <typename T, typename Context>
void TransposeCooGradKernel(const Context& dev_ctx,
                            const SparseCooTensor& dout,
                            const std::vector<int>& perm,
                            SparseCooTensor* dx) {
  const std::vector<int> inverse_perm = get_cpu_grad_perm(perm);
  TransposeCooKernel<T, Context>(dev_ctx, dout, inverse_perm, dx);
}
// Backward pass of the sparse CSR transpose on CPU.
//
// The forward transpose is undone by transposing with the inverse
// permutation, which get_cpu_grad_perm computes. Applying it to `dout` through
// TransposeCsrKernel therefore yields the gradient w.r.t. the original input.
//
// dev_ctx: device context, forwarded unchanged to TransposeCsrKernel.
// dout:    gradient w.r.t. the forward transpose's output.
// perm:    the permutation that was used by the forward transpose.
// dx:      output tensor, filled by TransposeCsrKernel.
template <typename T, typename Context>
void TransposeCsrGradKernel(const Context& dev_ctx,
                            const SparseCsrTensor& dout,
                            const std::vector<int>& perm,
                            SparseCsrTensor* dx) {
  const std::vector<int> inverse_perm = get_cpu_grad_perm(perm);
  TransposeCsrKernel<T, Context>(dev_ctx, dout, inverse_perm, dx);
}
} // namespace sparse
} // namespace phi
// Registers TransposeCooGradKernel as the CPU implementation of
// "transpose_coo_grad" for any layout, instantiated for each element type
// listed below.
PD_REGISTER_KERNEL(transpose_coo_grad,
CPU,
ALL_LAYOUT,
phi::sparse::TransposeCooGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
// Registers TransposeCsrGradKernel as the CPU implementation of
// "transpose_csr_grad" for any layout, instantiated for the same element
// types as the COO variant.
PD_REGISTER_KERNEL(transpose_csr_grad,
CPU,
ALL_LAYOUT,
phi::sparse::TransposeCsrGradKernel,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}