// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <cmath>
#include <limits>

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cumsum_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"

namespace phi {

// log(x) for positive x; otherwise the lowest representable value, so
// that a later exp() underflows to zero and the entry contributes nothing.
template <typename T>
struct LogGradPositiveFunctor {
  HOSTDEVICE T operator()(const T& x) const {
    const T kMin = std::numeric_limits<T>::lowest();
    return x > 0 ? std::log(x) : kMin;
  }
};

// log(-x) for negative x; otherwise the lowest representable value.
template <typename T>
struct LogGradNegativeFunctor {
  HOSTDEVICE T operator()(const T& x) const {
    const T kMin = std::numeric_limits<T>::lowest();
    return x < 0 ? std::log(-x) : kMin;
  }
};
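
// Gradient of logcumsumexp: given x, out = logcumsumexp(x) along `axis`,
// and the upstream gradient d_out, writes d(loss)/d(x) into d_x.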
template <typename T, typename Context>
void LogcumsumexpGradKernel(const Context& dev_ctx,
                            const DenseTensor& x,
                            const DenseTensor& out,
                            const DenseTensor& d_out,
                            int axis,
                            bool flatten,
                            bool exclusive,
                            bool reverse,
                            DenseTensor* d_x) {
  // Reference:
  // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/math_grad.py
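  // The gradient splits d_out into its positive and negative parts and
  // accumulates each in log space for numerical stability:
  //   d_x = exp(logcumsumexp(log(d_out^+) - out, reversed) + x)
  //       - exp(logcumsumexp(log(-(d_out^-)) - out, reversed) + x)
  // where d_out^+ / d_out^- keep only the positive / negative entries.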
  // The backward cumulation runs opposite to the forward direction.
  reverse = !reverse;
  dev_ctx.template Alloc<T>(d_x);

  auto eigen_x = EigenMatrix<T>::From(x);
  auto eigen_out = EigenMatrix<T>::From(out);
  auto eigen_d_out = EigenMatrix<T>::From(d_out);
  auto& place = *dev_ctx.eigen_device();

  // Work buffers: the positive and negative contributions to d_x, plus a
  // scratch tensor fed to LogcumsumexpKernel.
  DenseTensor output_pos;
  output_pos.Resize(d_out.dims());
  dev_ctx.template Alloc<T>(&output_pos);
  auto eigen_output_pos = EigenMatrix<T>::From(output_pos);
  DenseTensor output_neg;
  output_neg.Resize(d_out.dims());
  dev_ctx.template Alloc<T>(&output_neg);
  auto eigen_output_neg = EigenMatrix<T>::From(output_neg);
  DenseTensor tmp;
  tmp.Resize(d_out.dims());
  dev_ctx.template Alloc<T>(&tmp);
  auto eigen_tmp = EigenMatrix<T>::From(tmp);

  // Positive part: exp(logcumsumexp(log(d_out^+) - out, reversed) + x).
  eigen_tmp.device(place) =
      eigen_d_out.unaryExpr(LogGradPositiveFunctor<T>()) - eigen_out;
  LogcumsumexpKernel<T, Context>(
      dev_ctx, tmp, axis, flatten, exclusive, reverse, &output_pos);
  eigen_output_pos.device(place) = (eigen_output_pos + eigen_x).exp();

  // Negative part: exp(logcumsumexp(log(-(d_out^-)) - out, reversed) + x).
  eigen_tmp.device(place) =
      eigen_d_out.unaryExpr(LogGradNegativeFunctor<T>()) - eigen_out;
  LogcumsumexpKernel<T, Context>(
      dev_ctx, tmp, axis, flatten, exclusive, reverse, &output_neg);
  eigen_output_neg.device(place) = (eigen_output_neg + eigen_x).exp();

  // d_x is the difference of the two parts.
  auto eigen_d_x = EigenMatrix<T>::From(*d_x);
  eigen_d_x.device(place) = eigen_output_pos - eigen_output_neg;
}
} // namespace phi