From 19a7524febaf732108d6faad26724b265b56c298 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Fri, 10 Jun 2022 16:36:14 +0800 Subject: [PATCH] [Hackathon No.28] implement logcumsumexp (#42267) --- .../operators/{cumsum_op.cc => cum_op.cc} | 74 ++++- paddle/phi/infermeta/unary.cc | 12 +- paddle/phi/infermeta/unary.h | 12 +- paddle/phi/kernels/cpu/cum_kernel.cc | 271 +++++++++++++++++ paddle/phi/kernels/cpu/cumsum_kernel.cc | 143 --------- .../kernels/cpu/logcumsumexp_grad_kernel.cc | 28 ++ .../kernels/{cumsum_kernel.h => cum_kernel.h} | 9 + .../gpu/{cumsum_kernel.cu => cum_kernel.cu} | 119 ++++++-- .../kernels/gpu/logcumsumexp_grad_kernel.cu | 27 ++ .../phi/kernels/impl/logcumsumexp_grad_impl.h | 86 ++++++ paddle/phi/kernels/logcumsumexp_grad_kernel.h | 31 ++ paddle/phi/ops/compat/logcumsumexp_sig.cc | 39 +++ python/paddle/__init__.py | 2 + .../fluid/tests/unittests/CMakeLists.txt | 1 + .../tests/unittests/test_logcumsumexp_op.py | 272 ++++++++++++++++++ python/paddle/tensor/__init__.py | 2 + python/paddle/tensor/math.py | 75 ++++- python/paddle/utils/code_gen/api.yaml | 11 +- python/paddle/utils/code_gen/backward.yaml | 10 + 19 files changed, 1036 insertions(+), 188 deletions(-) rename paddle/fluid/operators/{cumsum_op.cc => cum_op.cc} (51%) create mode 100644 paddle/phi/kernels/cpu/cum_kernel.cc delete mode 100644 paddle/phi/kernels/cpu/cumsum_kernel.cc create mode 100644 paddle/phi/kernels/cpu/logcumsumexp_grad_kernel.cc rename paddle/phi/kernels/{cumsum_kernel.h => cum_kernel.h} (75%) rename paddle/phi/kernels/gpu/{cumsum_kernel.cu => cum_kernel.cu} (79%) create mode 100644 paddle/phi/kernels/gpu/logcumsumexp_grad_kernel.cu create mode 100644 paddle/phi/kernels/impl/logcumsumexp_grad_impl.h create mode 100644 paddle/phi/kernels/logcumsumexp_grad_kernel.h create mode 100644 paddle/phi/ops/compat/logcumsumexp_sig.cc create mode 100644 python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py diff --git a/paddle/fluid/operators/cumsum_op.cc b/paddle/fluid/operators/cum_op.cc similarity index 51% rename from paddle/fluid/operators/cumsum_op.cc rename to paddle/fluid/operators/cum_op.cc index dbb703e7e874d..be001c43086cf 100644 --- a/paddle/fluid/operators/cumsum_op.cc +++ b/paddle/fluid/operators/cum_op.cc @@ -49,7 +49,7 @@ class CumsumOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( The cumulative sum of the elements along a given axis. By default, the first element of the result is the same of the first element of -the input. If exlusive is true, the first element of the result is 0. +the input. If exclusive is true, the first element of the result is 0. )DOC"); } }; @@ -74,17 +74,87 @@ class CumsumGradMaker : public framework::SingleGradOpMaker { } }; +class LogcumsumexpOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "Input of logcumsumexp operator"); + AddOutput("Out", "Output of logcumsumexp operator"); + AddAttr("axis", + "The dimension to accumulate along. -1 means the last " + "dimension [default -1].") + .SetDefault(-1); + AddAttr("flatten", + "Whether to compute the logcumsumexp over the flattened array. " + "[default false].") + .SetDefault(false); + AddAttr("exclusive", + "Whether to perform exclusive logcumsumexp. [default false].") + .SetDefault(false); + AddAttr("reverse", + "If true, the logcumsumexp is performed in the reversed direction. " + "[default false].") + .SetDefault(false); + AddComment(R"DOC( +Returns the logarithm of the cumulative summation of the exponentiation of elements of input along the given axis. +By default, the first element of the result is the same of the first element of +the input. If exclusive is true, the first element of the result is the lowest finite value of the dtype of output tensor. +)DOC"); + } +}; + +class LogcumsumexpGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "logcumsumexp"); + OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "logcumsumexp"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@GRAD", "logcumsumexp"); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } +}; + +template +class LogcumsumexpGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr grad_op) const override { + grad_op->SetType("logcumsumexp_grad"); + grad_op->SetInput("X", this->Input("X")); + grad_op->SetInput("Out", this->Output("Out")); + grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + grad_op->SetAttr("axis", BOOST_GET_CONST(int, this->GetAttr("axis"))); + grad_op->SetAttr("flatten", + BOOST_GET_CONST(bool, this->GetAttr("flatten"))); + grad_op->SetAttr("exclusive", + BOOST_GET_CONST(bool, this->GetAttr("exclusive"))); + grad_op->SetAttr("reverse", + BOOST_GET_CONST(bool, this->GetAttr("reverse"))); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; using CPU = paddle::platform::CPUDeviceContext; DECLARE_INFER_SHAPE_FUNCTOR(cumsum, CumsumInferShapeFunctor, - PD_INFER_META(phi::CumsumInferMeta)); + PD_INFER_META(phi::CumInferMeta)); +DECLARE_INFER_SHAPE_FUNCTOR(logcumsumexp, LogcumsumexpInferShapeFunctor, + PD_INFER_META(phi::CumInferMeta)); REGISTER_OPERATOR(cumsum, ops::CumOp, ops::CumsumOpMaker, ops::CumsumGradMaker, ops::CumsumGradMaker, CumsumInferShapeFunctor); +REGISTER_OPERATOR(logcumsumexp, ops::CumOp, ops::LogcumsumexpOpMaker, + ops::LogcumsumexpGradMaker, + ops::LogcumsumexpGradMaker, + LogcumsumexpInferShapeFunctor); +REGISTER_OPERATOR(logcumsumexp_grad, ops::LogcumsumexpGradOp); REGISTER_OP_VERSION(cumsum).AddCheckpoint( R"ROC( diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 0beb7223f212a..bc41a24c44562 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -235,12 +235,12 @@ void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out) { out->set_layout(x.layout()); } -void CumsumInferMeta(const MetaTensor& x, - int axis, - bool flatten, - bool exclusive, - bool reverse, - MetaTensor* out) { +void CumInferMeta(const MetaTensor& x, + int axis, + bool flatten, + bool exclusive, + bool reverse, + MetaTensor* out) { auto x_dims = x.dims(); if (flatten) { out->set_dims(phi::make_ddim({phi::product(x_dims)})); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index a288b9371016f..a0cad3e628e3f 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -60,12 +60,12 @@ void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out); void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out); -void CumsumInferMeta(const MetaTensor& x, - int axis, - bool flatten, - bool exclusive, - bool reverse, - MetaTensor* out); +void CumInferMeta(const MetaTensor& x, + int axis, + bool flatten, + bool exclusive, + bool reverse, + MetaTensor* out); void DiagInferMeta(const MetaTensor& x, int offset, diff --git a/paddle/phi/kernels/cpu/cum_kernel.cc b/paddle/phi/kernels/cpu/cum_kernel.cc new file mode 100644 index 0000000000000..cd171cc8fc5fc --- /dev/null +++ b/paddle/phi/kernels/cpu/cum_kernel.cc @@ -0,0 +1,271 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/cum_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" + +namespace phi { + +template +void ComputeImp(Device d, + const Dim& dims, + X x, + Out out, + int axis, + bool reverse, + bool exclusive, + Reducer reducer) { + if (!reverse) { + out.reshape(dims).device(d) = + x.reshape(dims).scan(axis, reducer, exclusive); + } else { + std::array rev; + rev.fill(false); + rev[axis] = reverse; + out.reshape(dims).device(d) = x.reshape(dims) + .reverse(rev) + .scan(axis, reducer, exclusive) + .reverse(rev); + } +} + +template +void ScanKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + bool flatten, + bool exclusive, + bool reverse, + Reducer reducer, + DenseTensor* out) { + auto out_dims = out->dims(); + + PADDLE_ENFORCE_EQ( + axis < out_dims.size() && axis >= (0 - out_dims.size()), + true, + phi::errors::OutOfRange( + "Attr(axis) is out of range, It's expected " + "to be in range of [-%d, %d]. But received Attr(axis) = %d.", + out_dims.size(), + out_dims.size() - 1, + axis)); + if (axis < 0) { + axis += out_dims.size(); + } + + dev_ctx.template Alloc(out); + + int pre = 1; + int post = 1; + int mid = out_dims[axis]; + for (int i = 0; i < axis; ++i) { + pre *= out_dims[i]; + } + for (int i = axis + 1; i < out_dims.size(); ++i) { + post *= out_dims[i]; + } + + auto x0 = EigenVector::Flatten(x); + auto out0 = EigenVector::Flatten(*out); + auto& place = *dev_ctx.eigen_device(); + + using IndexT = Eigen::DenseIndex; + if (pre == 1) { + if (post == 1) { + ComputeImp(place, + Eigen::DSizes(mid), + x0, + out0, + /* axis= */ 0, + reverse, + exclusive, + reducer); + } else { + ComputeImp(place, + Eigen::DSizes(mid, post), + x0, + out0, + /* axis= */ 0, + reverse, + exclusive, + reducer); + } + } else { + if (post == 1) { + ComputeImp(place, + Eigen::DSizes(pre, mid), + x0, + out0, + /* axis= */ 1, + reverse, + exclusive, + reducer); + } else { + ComputeImp(place, + Eigen::DSizes(pre, mid, post), + x0, + out0, + /* axis= */ 1, + reverse, + exclusive, + reducer); + } + } +} + +template +void CumsumKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + bool flatten, + bool exclusive, + bool reverse, + DenseTensor* out) { + using Reducer = Eigen::internal::SumReducer; + auto reducer = Reducer(); + ScanKernel( + dev_ctx, x, axis, flatten, exclusive, reverse, reducer, out); +} + +template +struct LogSumExp { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a, + const T& b) const { + auto mi = Eigen::internal::scalar_min_op()(a, b); + auto ma = Eigen::internal::scalar_max_op()(a, b); + + auto sub = Eigen::internal::scalar_difference_op(); + auto add = Eigen::internal::scalar_sum_op(); + auto exp = Eigen::internal::scalar_exp_op(); + auto log1p = Eigen::internal::scalar_log1p_op(); + auto cmp_lt = + Eigen::internal::scalar_cmp_op(); + + auto logsumexp = add(log1p(exp(sub(mi, ma))), ma); + return cmp_lt(ma, Eigen::NumTraits::lowest()) ? ma : logsumexp; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const T& a, + const T& b) const { + auto mi = Eigen::internal::pmin(a, b); + auto ma = Eigen::internal::pmax(a, b); + using Eigen::internal::padd; + using Eigen::internal::pcmp_lt; + using Eigen::internal::pexp; + using Eigen::internal::plog1p; + using Eigen::internal::pset1; + using Eigen::internal::psub; + + auto logsumexp = padd(plog1p(pexp(psub(mi, ma))), ma); + return pselect( + pcmp_lt(ma, pset1(Eigen::NumTraits::lowest())), ma, logsumexp); + } +}; + +template +struct LogSumExpReducer { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const { + LogSumExp logsumexp; + *accum = logsumexp(*accum, t); + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p, + Packet* accum) const { + LogSumExp logsumexp; + *accum = logsumexp.packetOp(*accum, p); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const { + return Eigen::NumTraits::lowest(); + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const { + return Eigen::internal::pset1(initialize()); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const { + return accum; + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet + finalizePacket(const Packet& vaccum) const { + return vaccum; + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T + finalizeBoth(const T saccum, const Packet& vaccum) const { + auto max_reducer = Eigen::internal::MaxReducer(); + auto sum_reducer = Eigen::internal::SumReducer(); + auto exp = Eigen::internal::scalar_exp_op(); + auto cmp_lt = + Eigen::internal::scalar_cmp_op(); + auto log = Eigen::internal::scalar_log_op(); + auto add = Eigen::internal::scalar_sum_op(); + + using Eigen::internal::pexp; + using Eigen::internal::psub; + + // `ma = max(x1, ..., xn)` + // If the max of all of the `xi` is `-infinity` then the result is + // -infinity. If the max is larger than `-infinity` then it's safe to use + // for normalization even if the other elements are `-infinity`. + // + // `logsumexp(x1, ..., xn) = ma + log (exp(x1 - ma) + ... + exp(xn - ma))` + auto ma = max_reducer.finalizeBoth(saccum, vaccum); + auto logsumexp = add(log(sum_reducer.finalizeBoth( + exp(saccum - ma), pexp(psub(vaccum, pset1(ma))))), + ma); + return cmp_lt(ma, Eigen::NumTraits::lowest()) ? initialize() : logsumexp; + } +}; + +template +void LogcumsumexpKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + bool flatten, + bool exclusive, + bool reverse, + DenseTensor* out) { + using Reducer = LogSumExpReducer; + auto reducer = Reducer(); + ScanKernel( + dev_ctx, x, axis, flatten, exclusive, reverse, reducer, out); +} + +} // namespace phi + +PD_REGISTER_KERNEL(cumsum, + CPU, + ALL_LAYOUT, + phi::CumsumKernel, + float, + double, + int16_t, + int, + int64_t) {} + +PD_REGISTER_KERNEL( + logcumsumexp, CPU, ALL_LAYOUT, phi::LogcumsumexpKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/cumsum_kernel.cc b/paddle/phi/kernels/cpu/cumsum_kernel.cc deleted file mode 100644 index d32e18479aae9..0000000000000 --- a/paddle/phi/kernels/cpu/cumsum_kernel.cc +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/kernels/cumsum_kernel.h" - -#include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/funcs/eigen/common.h" -#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" - -namespace phi { - -struct CumsumFunctor { - template - const typename X::TensorScanSumOp operator()(X x, - int axis, - bool exclusive) const { - return x.cumsum(axis, exclusive); - } -}; - -template -void ComputeImp(Device d, - const Dim& dims, - X x, - Out out, - int axis, - bool reverse, - bool exclusive) { - if (!reverse) { - out.reshape(dims).device(d) = - CumsumFunctor()(x.reshape(dims), axis, exclusive); - } else { - std::array rev; - rev.fill(false); - rev[axis] = reverse; - out.reshape(dims).device(d) = - CumsumFunctor()(x.reshape(dims).reverse(rev), axis, exclusive) - .reverse(rev); - } -} - -template -void CumsumKernel(const Context& dev_ctx, - const DenseTensor& x, - int axis, - bool flatten, - bool exclusive, - bool reverse, - DenseTensor* out) { - auto out_dims = out->dims(); - - PADDLE_ENFORCE_EQ( - axis < out_dims.size() && axis >= (0 - out_dims.size()), - true, - phi::errors::OutOfRange( - "Attr(axis) is out of range, It's expected " - "to be in range of [-%d, %d]. But received Attr(axis) = %d.", - out_dims.size(), - out_dims.size() - 1, - axis)); - if (axis < 0) { - axis += out_dims.size(); - } - - dev_ctx.template Alloc(out); - - int pre = 1; - int post = 1; - int mid = out_dims[axis]; - for (int i = 0; i < axis; ++i) { - pre *= out_dims[i]; - } - for (int i = axis + 1; i < out_dims.size(); ++i) { - post *= out_dims[i]; - } - - auto x0 = EigenVector::Flatten(x); - auto out0 = EigenVector::Flatten(*out); - auto& place = *dev_ctx.eigen_device(); - - using IndexT = Eigen::DenseIndex; - if (pre == 1) { - if (post == 1) { - ComputeImp(place, - Eigen::DSizes(mid), - x0, - out0, - /* axis= */ 0, - reverse, - exclusive); - } else { - ComputeImp(place, - Eigen::DSizes(mid, post), - x0, - out0, - /* axis= */ 0, - reverse, - exclusive); - } - } else { - if (post == 1) { - ComputeImp(place, - Eigen::DSizes(pre, mid), - x0, - out0, - /* axis= */ 1, - reverse, - exclusive); - } else { - ComputeImp(place, - Eigen::DSizes(pre, mid, post), - x0, - out0, - /* axis= */ 1, - reverse, - exclusive); - } - } -} - -} // namespace phi - -PD_REGISTER_KERNEL(cumsum, - CPU, - ALL_LAYOUT, - phi::CumsumKernel, - float, - double, - int16_t, - int, - int64_t) {} diff --git a/paddle/phi/kernels/cpu/logcumsumexp_grad_kernel.cc b/paddle/phi/kernels/cpu/logcumsumexp_grad_kernel.cc new file mode 100644 index 0000000000000..17f28b411bcdd --- /dev/null +++ b/paddle/phi/kernels/cpu/logcumsumexp_grad_kernel.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/logcumsumexp_grad_kernel.h" + +#include + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/logcumsumexp_grad_impl.h" + +PD_REGISTER_KERNEL(logcumsumexp_grad, + CPU, + ALL_LAYOUT, + phi::LogcumsumexpGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cumsum_kernel.h b/paddle/phi/kernels/cum_kernel.h similarity index 75% rename from paddle/phi/kernels/cumsum_kernel.h rename to paddle/phi/kernels/cum_kernel.h index f105c94d559d8..38cdbd7787baf 100644 --- a/paddle/phi/kernels/cumsum_kernel.h +++ b/paddle/phi/kernels/cum_kernel.h @@ -27,4 +27,13 @@ void CumsumKernel(const Context& dev_ctx, bool reverse, DenseTensor* out); +template +void LogcumsumexpKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + bool flatten, + bool exclusive, + bool reverse, + DenseTensor* out); + } // namespace phi diff --git a/paddle/phi/kernels/gpu/cumsum_kernel.cu b/paddle/phi/kernels/gpu/cum_kernel.cu similarity index 79% rename from paddle/phi/kernels/gpu/cumsum_kernel.cu rename to paddle/phi/kernels/gpu/cum_kernel.cu index 460aa37f8f995..ad86fd9ba49df 100644 --- a/paddle/phi/kernels/gpu/cumsum_kernel.cu +++ b/paddle/phi/kernels/gpu/cum_kernel.cu @@ -17,7 +17,7 @@ #include #include -#include "paddle/phi/kernels/cumsum_kernel.h" +#include "paddle/phi/kernels/cum_kernel.h" #ifdef __NVCC__ #include #endif @@ -82,19 +82,20 @@ __global__ void MatrixRowReverse(const T* matrix_data, } } -template +template struct BlockPrefixCallbackOp { // Running prefix - T running_total; - // Constructor - __device__ BlockPrefixCallbackOp(T running_total) - : running_total(running_total) {} + T running_total_; + Op op_; + + __device__ BlockPrefixCallbackOp(T running_total, Op op) + : running_total_(running_total), op_(op) {} + // Callback operator to be entered by the first warp of threads in the block. - // Thread-0 is responsible for returning a value for seeding the block-wide - // scan. + // tid 0 is responsible for returning a value for seeding the block-wide scan. __device__ T operator()(T block_aggregate) { - T old_prefix = running_total; - running_total = old_prefix + block_aggregate; + T old_prefix = running_total_; + running_total_ = op_(old_prefix, block_aggregate); return old_prefix; } }; @@ -129,13 +130,36 @@ __global__ void MatrixTranspose(T* odata, } } -template +struct LogAddExp { + template + __host__ __device__ __forceinline__ T operator()(const T& a, + const T& b) const { + return std::log(1 + std::exp(std::min(a, b) - std::max(a, b))) + + std::max(a, b); + } +}; + +template +struct Identity; + +template +struct Identity { + static constexpr T value = 0; +}; + +template +struct Identity { + static constexpr T value = std::numeric_limits::lowest(); +}; + +template __global__ void BlockScanKernel(T* d_out, const T* d_in, int inner_size, int outer_size, int scan_size, - bool exclusive) { + bool exclusive, + Op op) { // Specialize BlockLoad, BlockStore, and BlockRadixSort collective types typedef cub:: BlockLoad @@ -154,7 +178,7 @@ __global__ void BlockScanKernel(T* d_out, int bx = blockIdx.x; int by = blockIdx.y; - BlockPrefixCallbackOp prefix_op(0); + BlockPrefixCallbackOp prefix_op(Identity::value, op); T block_aggregate = static_cast(0); // Obtain this block's segment of consecutive keys (blocked across threads) @@ -176,12 +200,11 @@ __global__ void BlockScanKernel(T* d_out, __syncthreads(); if (exclusive) { - T init_value = static_cast(0); BlockScanT(temp_storage.scan) - .ExclusiveScan(thread_keys, thread_keys, cub::Sum(), prefix_op); + .ExclusiveScan(thread_keys, thread_keys, op, prefix_op); } else { BlockScanT(temp_storage.scan) - .InclusiveScan(thread_keys, thread_keys, cub::Sum(), prefix_op); + .InclusiveScan(thread_keys, thread_keys, op, prefix_op); } __syncthreads(); @@ -190,14 +213,15 @@ __global__ void BlockScanKernel(T* d_out, } } -template -void CumsumKernel(const Context& dev_ctx, - const DenseTensor& x, - int axis, - bool flatten, - bool exclusive, - bool reverse, - DenseTensor* out) { +template +void ScanKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + bool flatten, + bool exclusive, + bool reverse, + Op op, + DenseTensor* out) { auto out_dims = out->dims(); auto size = x.numel(); @@ -219,7 +243,7 @@ void CumsumKernel(const Context& dev_ctx, // Use thrust for parallel acceleration when the input size is equal to the // length of the ‘axis’ dimension. - if (size == out_dims[axis]) { + if (std::is_same::value && size == out_dims[axis]) { #ifdef __HIPCC__ const auto& policy = thrust::hip::par.on(dev_ctx.stream()); #else @@ -247,6 +271,7 @@ void CumsumKernel(const Context& dev_ctx, return; } + size_t height = 1; size_t width = 1; for (size_t i = 0; i <= axis; i++) { @@ -299,17 +324,18 @@ void CumsumKernel(const Context& dev_ctx, } } if (!transpose && !reverse) { - BlockScanKernel<<>>( - out_data, in_data, outer_size, inner_size, scan_size, exclusive); + BlockScanKernel<<>>( + out_data, in_data, outer_size, inner_size, scan_size, exclusive, op); } else { - BlockScanKernel + BlockScanKernel <<>>(next_out_data, next_in_data, outer_size, inner_size, scan_size, - exclusive); + exclusive, + op); } swap_ptr(next_in_data, next_out_data); if (reverse) { @@ -325,6 +351,34 @@ void CumsumKernel(const Context& dev_ctx, } } +template +void CumsumKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + bool flatten, + bool exclusive, + bool reverse, + DenseTensor* out) { + using Op = cub::Sum; + auto op = Op(); + ScanKernel( + dev_ctx, x, axis, flatten, exclusive, reverse, op, out); +} + +template +void LogcumsumexpKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + bool flatten, + bool exclusive, + bool reverse, + DenseTensor* out) { + using Op = LogAddExp; + auto op = Op(); + ScanKernel( + dev_ctx, x, axis, flatten, exclusive, reverse, op, out); +} + } // namespace phi PD_REGISTER_KERNEL(cumsum, @@ -336,3 +390,10 @@ PD_REGISTER_KERNEL(cumsum, int16_t, int, int64_t) {} + +PD_REGISTER_KERNEL(logcumsumexp, + GPU, + ALL_LAYOUT, + phi::LogcumsumexpKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/logcumsumexp_grad_kernel.cu b/paddle/phi/kernels/gpu/logcumsumexp_grad_kernel.cu new file mode 100644 index 0000000000000..43744210e32b7 --- /dev/null +++ b/paddle/phi/kernels/gpu/logcumsumexp_grad_kernel.cu @@ -0,0 +1,27 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/logcumsumexp_grad_impl.h" +#include "paddle/phi/kernels/logcumsumexp_grad_kernel.h" + +PD_REGISTER_KERNEL(logcumsumexp_grad, + GPU, + ALL_LAYOUT, + phi::LogcumsumexpGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h b/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h new file mode 100644 index 0000000000000..602f2248902cc --- /dev/null +++ b/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h @@ -0,0 +1,86 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cum_kernel.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +namespace phi { + +template +struct LogGradPositiveFunctor { + HOSTDEVICE T operator()(const T& x) const { + const T kMin = std::numeric_limits::lowest(); + return x > 0 ? std::log(x) : kMin; + } +}; + +template +struct LogGradNegativeFunctor { + HOSTDEVICE T operator()(const T& x) const { + const T kMin = std::numeric_limits::lowest(); + return x < 0 ? std::log(-x) : kMin; + } +}; + +template +void LogcumsumexpGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& d_out, + int axis, + bool flatten, + bool exclusive, + bool reverse, + DenseTensor* d_x) { + reverse = !reverse; + dev_ctx.template Alloc(d_x); + + auto eigen_x = EigenVector::Flatten(x); + auto eigen_out = EigenVector::Flatten(out); + auto eigen_d_out = EigenVector::Flatten(d_out); + auto& place = *dev_ctx.eigen_device(); + + DenseTensor output_pos; + output_pos.Resize(d_out.dims()); + dev_ctx.template Alloc(&output_pos); + auto eigen_output_pos = EigenVector::Flatten(output_pos); + DenseTensor output_neg; + output_neg.Resize(d_out.dims()); + dev_ctx.template Alloc(&output_neg); + auto eigen_output_neg = EigenVector::Flatten(output_neg); + DenseTensor tmp; + tmp.Resize(d_out.dims()); + dev_ctx.template Alloc(&tmp); + auto eigen_tmp = EigenVector::Flatten(tmp); + + eigen_tmp.device(place) = + eigen_d_out.unaryExpr(LogGradPositiveFunctor()) - eigen_out; + LogcumsumexpKernel( + dev_ctx, tmp, axis, flatten, exclusive, reverse, &output_pos); + eigen_output_pos.device(place) = (eigen_output_pos + eigen_x).exp(); + + eigen_tmp.device(place) = + eigen_d_out.unaryExpr(LogGradNegativeFunctor()) - eigen_out; + LogcumsumexpKernel( + dev_ctx, tmp, axis, flatten, exclusive, reverse, &output_neg); + eigen_output_neg.device(place) = (eigen_output_neg + eigen_x).exp(); + + auto eigen_d_x = EigenVector::Flatten(*d_x); + eigen_d_x.device(place) = eigen_output_pos - eigen_output_neg; +} +} // namespace phi diff --git a/paddle/phi/kernels/logcumsumexp_grad_kernel.h b/paddle/phi/kernels/logcumsumexp_grad_kernel.h new file mode 100644 index 0000000000000..e78a79550657e --- /dev/null +++ b/paddle/phi/kernels/logcumsumexp_grad_kernel.h @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void LogcumsumexpGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& d_out, + int axis, + bool flatten, + bool exclusive, + bool reverse, + DenseTensor* d_x); +} diff --git a/paddle/phi/ops/compat/logcumsumexp_sig.cc b/paddle/phi/ops/compat/logcumsumexp_sig.cc new file mode 100644 index 0000000000000..2c790903b6333 --- /dev/null +++ b/paddle/phi/ops/compat/logcumsumexp_sig.cc @@ -0,0 +1,39 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature LogcumsumexpOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("logcumsumexp", + {"X"}, + {"axis", "flatten", "exclusive", "reverse"}, + {"Out"}); +} + +KernelSignature LogcumsumexpGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("logcumsumexp_grad", + {"X", "Out", "Out@GRAD"}, + {"axis", "flatten", "exclusive", "reverse"}, + {"X@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(logcumsumexp, phi::LogcumsumexpOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(logcumsumexp_grad, + phi::LogcumsumexpGradOpArgumentMapping); diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 75ec75cc43100..b2a94e62a1e0b 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -193,6 +193,7 @@ from .tensor.math import cosh # noqa: F401 from .tensor.math import cumsum # noqa: F401 from .tensor.math import cumprod # noqa: F401 +from .tensor.math import logcumsumexp # noqa: F401 from .tensor.math import logit # noqa: F401 from .tensor.math import exp # noqa: F401 from .tensor.math import expm1 # noqa: F401 @@ -407,6 +408,7 @@ 'eye', 'cumsum', 'cumprod', + 'logcumsumexp', 'logit', 'sign', 'is_empty', diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index b391837e54671..34971cf11941f 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -684,6 +684,7 @@ endif() foreach(TEST_OP ${TEST_OPS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP}) endforeach(TEST_OP) +set_tests_properties(test_logcumsumexp_op PROPERTIES TIMEOUT 30) py_test_modules(test_adam_op_multi_thread MODULES test_adam_op ENVS FLAGS_inner_op_parallelism=4) if(WITH_GPU diff --git a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py new file mode 100644 index 0000000000000..ebc350d13c673 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py @@ -0,0 +1,272 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +from typing import Optional +import unittest +import itertools +import numpy as np +import paddle +import paddle.fluid.core as core +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard +from paddle.fluid.framework import _test_eager_guard +from op_test import OpTest + + +def np_naive_logcumsumexp(x: np.ndarray, axis: Optional[int] = None): + return np.log(np.cumsum(np.exp(x), axis=axis)) + + +def np_logcumsumexp(x: np.ndarray, + axis: Optional[int] = None, + flatten: Optional[bool] = None, + reverse: bool = False, + exclusive: bool = False): + # `flatten` aligns with c++ op + if flatten: + assert axis in [0, None] + axis = None + + x = np.copy(x) + + if axis is None: + x = x.flatten() + axis = 0 + + if reverse: + x = np.flip(x, axis) + + dimensions = [range(dim) for dim in x.shape[:axis]] + + if exclusive: + x = np.roll(x, 1, axis) + for prefix_dim in itertools.product(*dimensions): + x[prefix_dim][0] = np.finfo(x.dtype).min + + for prefix_dim in itertools.product(*dimensions): + arr = x[prefix_dim] + for dim in range(1, arr.shape[0]): + arr[dim] = np.logaddexp(arr[dim - 1], arr[dim]) + + if reverse: + x = np.flip(x, axis) + + return x + + +def np_logcumsumexp_grad( + x: np.ndarray, + dout: np.ndarray, + axis: Optional[int] = None, + flatten: Optional[bool] = None, + reverse: bool = False, + exclusive: bool = False, +): + out = np_logcumsumexp(x, axis, flatten, reverse, exclusive) + log_grad_positive = np.where(dout > 0, np.log(dout), np.finfo(x.dtype).min) + log_grad_negative = np.where(dout < 0, np.log(-dout), np.finfo(x.dtype).min) + + output_pos = np.exp( + np_logcumsumexp(log_grad_positive - out, + axis=axis, + flatten=flatten, + reverse=not reverse, + exclusive=exclusive).reshape(x.shape) + x) + output_neg = np.exp( + np_logcumsumexp(log_grad_negative - out, + axis=axis, + flatten=flatten, + reverse=not reverse, + exclusive=exclusive).reshape(x.shape) + x) + + return output_pos - output_neg + + +class TestLogcumsumexp(unittest.TestCase): + + def run_imperative(self): + data_np = np.arange(12, dtype=np.float32).reshape(3, 4) + data = paddle.to_tensor(data_np) + + y = paddle.logcumsumexp(data) + z = np_logcumsumexp(data_np) + self.assertTrue(np.allclose(z, y.numpy())) + + y = paddle.logcumsumexp(data, axis=0) + z = np_logcumsumexp(data_np, axis=0) + self.assertTrue(np.allclose(z, y.numpy())) + + y = paddle.logcumsumexp(data, axis=-1) + z = np_logcumsumexp(data_np, axis=-1) + self.assertTrue(np.allclose(z, y.numpy())) + + y = paddle.logcumsumexp(data, dtype='float32') + self.assertTrue(y.dtype == core.VarDesc.VarType.FP32) + + y = paddle.logcumsumexp(data, axis=-2) + z = np_logcumsumexp(data_np, axis=-2) + self.assertTrue(np.allclose(z, y.numpy())) + + with self.assertRaises(IndexError): + y = paddle.logcumsumexp(data, axis=-3) + + with self.assertRaises(IndexError): + y = paddle.logcumsumexp(data, axis=2) + + data_np = np.arange(10000, 10024, dtype=np.float32) + data = paddle.to_tensor(data_np) + y = paddle.logcumsumexp(data) + z = np_naive_logcumsumexp(data_np) + # check that naive algorithm overflows + self.assertTrue(all(z == np.inf)) + z = np_logcumsumexp(data_np) + # check that our algorithm doesn't overflow + self.assertTrue(all(z != np.inf)) + self.assertTrue(np.allclose(z, y.numpy())) + + def run_static(self, use_gpu=False): + with fluid.program_guard(fluid.Program()): + data_np = np.random.random((5, 4)).astype(np.float32) + x = paddle.static.data('X', [5, 4]) + y = paddle.logcumsumexp(x) + y2 = paddle.logcumsumexp(x, axis=0) + y3 = paddle.logcumsumexp(x, axis=-1) + y4 = paddle.logcumsumexp(x, dtype='float64') + y5 = paddle.logcumsumexp(x, axis=-2) + + place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + out = exe.run(feed={'X': data_np}, + fetch_list=[ + y.name, + y2.name, + y3.name, + y4.name, + y5.name, + ]) + + z = np_logcumsumexp(data_np) + self.assertTrue(np.allclose(z, out[0])) + z = np_logcumsumexp(data_np, axis=0) + self.assertTrue(np.allclose(z, out[1])) + z = np_logcumsumexp(data_np, axis=-1) + self.assertTrue(np.allclose(z, out[2])) + self.assertTrue(out[3].dtype == np.float64) + z = np_logcumsumexp(data_np, axis=-2) + self.assertTrue(np.allclose(z, out[4])) + + def test_cpu(self): + paddle.disable_static(paddle.fluid.CPUPlace()) + self.run_imperative() + paddle.enable_static() + + self.run_static() + + def test_gpu(self): + if not fluid.core.is_compiled_with_cuda(): + return + paddle.disable_static(paddle.fluid.CUDAPlace(0)) + self.run_imperative() + paddle.enable_static() + + self.run_static(use_gpu=True) + + def test_name(self): + with fluid.program_guard(fluid.Program()): + x = paddle.static.data('x', [3, 4]) + y = paddle.logcumsumexp(x, name='out') + self.assertTrue('out' in y.name) + + def test_type_error(self): + with fluid.program_guard(fluid.Program()): + + with self.assertRaises(TypeError): + data_np = np.random.random((100, 100), dtype=np.int32) + x = paddle.static.data('X', [100, 100], dtype='int32') + y = paddle.logcumsumexp(x) + + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + out = exe.run(feed={'X': data_np}, fetch_list=[y.name]) + + +class BaseTestCases: + + class BaseOpTest(OpTest): + + def setUp(self): + self.op_type = "logcumsumexp" + input, attrs = self.input_and_attrs() + self.inputs = {'X': input} + self.attrs = attrs + self.outputs = {'Out': np_logcumsumexp(input, **attrs)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], + 'Out', + user_defined_grads=[ + np_logcumsumexp_grad(self.inputs['X'], + 1 / self.inputs['X'].size, + **self.attrs) + ]) + + def input_and_attrs(self): + raise NotImplementedError() + + +class TestLogcumsumexpOp1(BaseTestCases.BaseOpTest): + + def input_and_attrs(self): + return np.arange(100, dtype=np.float64).reshape(10, 10), { + 'axis': 0, + 'flatten': True, + 'reverse': True + } + + +class TestLogcumsumexpOp2(BaseTestCases.BaseOpTest): + + def input_and_attrs(self): + return np.arange(100, dtype=np.float64).reshape(10, 10), { + 'axis': 1, + 'reverse': True + } + + +class TestLogcumsumexpOp3(BaseTestCases.BaseOpTest): + + def input_and_attrs(self): + return np.arange(100, dtype=np.float64).reshape(10, 10), {'axis': 1} + + +class TestLogcumsumexpOp4(BaseTestCases.BaseOpTest): + + def input_and_attrs(self): + return np.arange(100, dtype=np.float64).reshape(10, 10), { + 'axis': 0, + 'flatten': True, + 'reverse': True, + 'exclusive': True + } + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 3ea3ba4982599..08b0af26bd46e 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -139,6 +139,7 @@ from .math import cosh # noqa: F401 from .math import cumsum # noqa: F401 from .math import cumprod # noqa: F401 +from .math import logcumsumexp # noqa: F401 from .math import logit # noqa: F401 from .math import exp # noqa: F401 from .math import exp_ # noqa: F401 @@ -310,6 +311,7 @@ 'cosh', 'cumsum', 'cumprod', + 'logcumsumexp', 'logit', 'exp', 'exp_', diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 2ef324395b26a..4611cbb20c96a 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -2909,7 +2909,7 @@ def cumsum(x, axis=None, dtype=None, name=None): The cumulative sum of the elements along a given axis. **Note**: - The first element of the result is the same of the first element of the input. + The first element of the result is the same as the first element of the input. Args: x (Tensor): The input tensor needed to be cumsumed. @@ -2970,6 +2970,79 @@ def cumsum(x, axis=None, dtype=None, name=None): _cum_sum_ = generate_layer_fn('cumsum') return _cum_sum_(**kwargs) + +def logcumsumexp(x, axis=None, dtype=None, name=None): + r""" + The logarithm of the cumulative summation of the exponentiation of the elements along a given axis. + + For summation index j given by `axis` and other indices i, the result is + + .. math:: + + logcumsumexp(x)_{ij} = log \sum_{i=0}^{j}exp(x_{ij}) + + Note: + The first element of the result is the same as the first element of the input. + + Args: + x (Tensor): The input tensor. + axis (int, optional): The dimension to do the operation along. -1 means the last dimension. The default (None) is to compute the cumsum over the flattened array. + dtype (str, optional): The data type of the output tensor, can be float32, float64. If specified, the input tensor is casted to dtype before the operation is performed. This is useful for preventing data type overflows. The default value is None. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor, the result of logcumsumexp operator. + + Examples: + .. code-block:: python + + import paddle + + data = paddle.arange(12, dtype='float64') + data = paddle.reshape(data, (3, 4)) + + y = paddle.logcumsumexp(data) + # [ 0. 1.3132617 2.4076061 3.4401898 4.4519143 5.4561934 + # 6.4577627 7.4583397 8.458551 9.45863 10.458658 11.458669 ] + + y = paddle.logcumsumexp(data, axis=0) + # [[ 0. 1. 2. 3. ] + # [ 4.01815 5.01815 6.01815 7.01815 ] + # [ 8.018479 9.018479 10.018479 11.018479]] + + y = paddle.logcumsumexp(data, axis=-1) + # [[ 0. 1.3132617 2.4076061 3.4401898] + # [ 4. 5.3132615 6.407606 7.44019 ] + # [ 8. 9.313262 10.407606 11.440189 ]] + + y = paddle.logcumsumexp(data, dtype='float64') + print(y.dtype) + # paddle.float64 + """ + if axis is None: + flatten = True + else: + flatten = False + if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype): + x = cast(x, dtype) + + if in_dygraph_mode(): + if axis is None: axis = -1 + return _C_ops.final_state_logcumsumexp(x, axis, flatten, False, False) + if _in_legacy_dygraph(): + if axis is None: + return _C_ops.logcumsumexp(x, 'flatten', flatten) + else: + return _C_ops.logcumsumexp(x, 'axis', axis, 'flatten', flatten) + + check_variable_and_dtype(x, 'x', ['float32', 'float64'], "logcumsumexp") + + helper = LayerHelper('logcumsumexp', **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op(type='logcumsumexp', inputs={'X': x}, outputs={'Out': out}, attrs={'axis': axis, 'flatten': flatten}) + return out + + def cumprod(x, dim=None, dtype=None, name=None): """ Compute the cumulative product of the input tensor x along a given dimension dim. diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index fd000567c507b..dfb5f6acedc01 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -482,7 +482,7 @@ args : (Tensor x, int axis, bool flatten, bool exclusive, bool reverse) output : Tensor(out) infer_meta : - func : CumsumInferMeta + func : CumInferMeta kernel : func : cumsum backward : cumsum_grad @@ -1259,6 +1259,15 @@ func : log_softmax backward : log_softmax_grad +- api : logcumsumexp + args : (Tensor x, int axis, bool flatten, bool exclusive, bool reverse) + output : Tensor(out) + infer_meta : + func : CumInferMeta + kernel : + func : logcumsumexp + backward : logcumsumexp_grad + # logical_and - api : logical_and args : (Tensor x, Tensor y) diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index 81641ac19f7b5..0d14560d605e1 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -1137,6 +1137,16 @@ kernel : func : log_softmax_grad +- backward_api : logcumsumexp_grad + forward : logcumsumexp(Tensor x, int axis, bool flatten, bool exclusive, bool reverse) -> Tensor(out) + infer_meta : + func : UnchangedInferMeta + param : [x] + args : (Tensor x, Tensor out, Tensor out_grad, int axis, bool flatten, bool exclusive, bool reverse) + output : Tensor(x_grad) + kernel : + func : logcumsumexp_grad + - backward_api : logit_grad forward : logit (Tensor x, float eps = 1e-6f) -> Tensor(out) args : (Tensor x, Tensor out_grad, float eps)