Skip to content

Commit

Permalink
polish
Browse files Browse the repository at this point in the history
  • Loading branch information
tiancaishaonvjituizi committed Apr 26, 2022
1 parent 5c3b6bb commit 4054f7c
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 19 deletions.
13 changes: 7 additions & 6 deletions paddle/phi/kernels/cpu/cum_kernel.cc
Expand Up @@ -48,14 +48,14 @@ void ComputeImp(Device d,
}
}

template <typename T, typename Context, typename Op>
template <typename T, typename Context, typename Reducer>
void ScanKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
bool flatten,
bool exclusive,
bool reverse,
Op op,
Reducer reducer,
DenseTensor* out) {
auto out_dims = out->dims();

Expand Down Expand Up @@ -98,7 +98,7 @@ void ScanKernel(const Context& dev_ctx,
/* axis= */ 0,
reverse,
exclusive,
op);
reducer);
} else {
ComputeImp(place,
Eigen::DSizes<IndexT, 2>(mid, post),
Expand All @@ -107,7 +107,7 @@ void ScanKernel(const Context& dev_ctx,
/* axis= */ 0,
reverse,
exclusive,
op);
reducer);
}
} else {
if (post == 1) {
Expand All @@ -118,7 +118,7 @@ void ScanKernel(const Context& dev_ctx,
/* axis= */ 1,
reverse,
exclusive,
op);
reducer);
} else {
ComputeImp(place,
Eigen::DSizes<IndexT, 3>(pre, mid, post),
Expand All @@ -127,7 +127,7 @@ void ScanKernel(const Context& dev_ctx,
/* axis= */ 1,
reverse,
exclusive,
op);
reducer);
}
}
}
Expand All @@ -146,6 +146,7 @@ void CumsumKernel(const Context& dev_ctx,
dev_ctx, x, axis, flatten, exclusive, reverse, reducer, out);
}

// Copied from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/scan_ops.h
template <typename T>
struct LogSumExp {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a,
Expand Down
13 changes: 0 additions & 13 deletions paddle/phi/kernels/gpu/cum_kernel.cu
Expand Up @@ -141,14 +141,6 @@ struct LogAddExp {
}
};

// Stateless binary functor that multiplies its two operands.
//
// operator() is a template on the operand type T, takes both operands by
// const reference and returns `lhs * rhs` by value. It carries the
// __host__ __device__ qualifiers, so the same functor type can be invoked
// from either side.
struct Prod {
  template <typename T>
  __host__ __device__ __forceinline__ T operator()(const T& lhs,
                                                   const T& rhs) const {
    return lhs * rhs;
  }
};

// Primary template of the Identity trait, keyed on an element type T and an
// operation tag type Op. It is only declared here, never defined: a
// (T, Op) pair is usable only if a specialization for that operation exists.
template <typename T, typename Op>
struct Identity;

Expand All @@ -157,11 +149,6 @@ struct Identity<T, cub::Sum> {
static constexpr T value = 0;
};

// Identity specialization for the Prod operation tag.
// `value` is initialized from the literal 1, i.e. the element that leaves the
// other operand unchanged under multiplication (for arithmetic T).
// NOTE(review): for non-builtin T this relies on T being constexpr
// copy-initializable from an int literal — confirm for the types used.
template <typename T>
struct Identity<T, Prod> {
static constexpr T value = 1;
};

template <typename T>
struct Identity<T, LogAddExp> {
static constexpr T value = std::numeric_limits<T>::lowest();
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/kernels/impl/logcumsumexp_grad_impl.h
Expand Up @@ -47,6 +47,7 @@ void LogcumsumexpGradKernel(const Context& dev_ctx,
bool exclusive,
bool reverse,
DenseTensor* d_x) {
// Reference: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/math_grad.py
reverse = !reverse;
dev_ctx.template Alloc<T>(d_x);

Expand Down

0 comments on commit 4054f7c

Please sign in to comment.