From 4054f7c1118a7f8185c76a852afab5d73c39d3a6 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Tue, 26 Apr 2022 14:46:05 +0800 Subject: [PATCH] polish --- paddle/phi/kernels/cpu/cum_kernel.cc | 13 +++++++------ paddle/phi/kernels/gpu/cum_kernel.cu | 13 ------------- paddle/phi/kernels/impl/logcumsumexp_grad_impl.h | 1 + 3 files changed, 8 insertions(+), 19 deletions(-) diff --git a/paddle/phi/kernels/cpu/cum_kernel.cc b/paddle/phi/kernels/cpu/cum_kernel.cc index 846d148bf1397..e44bc7b783190 100644 --- a/paddle/phi/kernels/cpu/cum_kernel.cc +++ b/paddle/phi/kernels/cpu/cum_kernel.cc @@ -48,14 +48,14 @@ void ComputeImp(Device d, } } -template <typename T, typename Context, typename Op> +template <typename T, typename Context, typename Reducer> void ScanKernel(const Context& dev_ctx, const DenseTensor& x, int axis, bool flatten, bool exclusive, bool reverse, - Op op, + Reducer reducer, DenseTensor* out) { auto out_dims = out->dims(); @@ -98,7 +98,7 @@ void ScanKernel(const Context& dev_ctx, /* axis= */ 0, reverse, exclusive, - op); + reducer); } else { ComputeImp(place, Eigen::DSizes<IndexT, 2>(mid, post), @@ -107,7 +107,7 @@ void ScanKernel(const Context& dev_ctx, /* axis= */ 0, reverse, exclusive, - op); + reducer); } } else { if (post == 1) { @@ -118,7 +118,7 @@ void ScanKernel(const Context& dev_ctx, /* axis= */ 1, reverse, exclusive, - op); + reducer); } else { ComputeImp(place, Eigen::DSizes<IndexT, 3>(pre, mid, post), @@ -127,7 +127,7 @@ void ScanKernel(const Context& dev_ctx, /* axis= */ 1, reverse, exclusive, - op); + reducer); } } } @@ -146,6 +146,7 @@ void CumsumKernel(const Context& dev_ctx, dev_ctx, x, axis, flatten, exclusive, reverse, reducer, out); } +// Copied from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/scan_ops.h template <typename T> struct LogSumExp { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a, diff --git a/paddle/phi/kernels/gpu/cum_kernel.cu b/paddle/phi/kernels/gpu/cum_kernel.cu index db222cdaf0ffd..3846f832b6fa1 100644 --- a/paddle/phi/kernels/gpu/cum_kernel.cu +++ 
b/paddle/phi/kernels/gpu/cum_kernel.cu @@ -141,14 +141,6 @@ struct LogAddExp { } }; -struct Prod { - template <typename T> - __host__ __device__ __forceinline__ T operator()(const T& a, - const T& b) const { - return a * b; - } -}; - template <typename T, typename Op> struct Identity; @@ -157,11 +149,6 @@ struct Identity<T, cub::Sum> { static constexpr T value = 0; }; -template <typename T> -struct Identity<T, Prod> { - static constexpr T value = 1; -}; - template <typename T> struct Identity<T, LogAddExp> { static constexpr T value = std::numeric_limits<T>::lowest(); diff --git a/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h b/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h index f62b027a2aa9d..7e3cd7295e5f9 100644 --- a/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h +++ b/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h @@ -47,6 +47,7 @@ void LogcumsumexpGradKernel(const Context& dev_ctx, bool exclusive, bool reverse, DenseTensor* d_x) { + // Reference: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/math_grad.py reverse = !reverse; dev_ctx.template Alloc<T>(d_x);