From 5c3b6bb12e674c6e535e703864683286c41b01d3 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Tue, 26 Apr 2022 14:36:18 +0800 Subject: [PATCH 01/21] implement logcumsumexp --- paddle/fluid/operators/cumsum_op.cc | 68 ++++- paddle/phi/infermeta/unary.cc | 12 +- paddle/phi/infermeta/unary.h | 12 +- paddle/phi/kernels/cpu/cum_kernel.cc | 275 ++++++++++++++++++ paddle/phi/kernels/cpu/cumsum_kernel.cc | 143 --------- .../kernels/cpu/logcumsumexp_grad_kernel.cc | 26 ++ paddle/phi/kernels/cumsum_kernel.h | 9 + .../gpu/{cumsum_kernel.cu => cum_kernel.cu} | 171 +++++++---- .../kernels/gpu/logcumsumexp_grad_kernel.cu | 26 ++ .../phi/kernels/impl/logcumsumexp_grad_impl.h | 86 ++++++ python/paddle/__init__.py | 2 + .../tests/unittests/test_logcumsumexp_op.py | 154 ++++++++++ python/paddle/tensor/__init__.py | 2 + python/paddle/tensor/math.py | 28 ++ python/paddle/utils/code_gen/api.yaml | 11 +- python/paddle/utils/code_gen/backward.yaml | 10 + 16 files changed, 814 insertions(+), 221 deletions(-) create mode 100644 paddle/phi/kernels/cpu/cum_kernel.cc delete mode 100644 paddle/phi/kernels/cpu/cumsum_kernel.cc create mode 100644 paddle/phi/kernels/cpu/logcumsumexp_grad_kernel.cc rename paddle/phi/kernels/gpu/{cumsum_kernel.cu => cum_kernel.cu} (74%) create mode 100644 paddle/phi/kernels/gpu/logcumsumexp_grad_kernel.cu create mode 100644 paddle/phi/kernels/impl/logcumsumexp_grad_impl.h create mode 100644 python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py diff --git a/paddle/fluid/operators/cumsum_op.cc b/paddle/fluid/operators/cumsum_op.cc index 11633fb0b8703..592adf8971c67 100644 --- a/paddle/fluid/operators/cumsum_op.cc +++ b/paddle/fluid/operators/cumsum_op.cc @@ -26,6 +26,19 @@ class CumOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; }; +class CumGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "logsumexp"); + OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "logsumexp"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@GRAD", "logsumexp"); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } +}; + class CumsumOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -74,17 +87,70 @@ class CumsumGradMaker : public framework::SingleGradOpMaker { } }; +class LogcumsumexpOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "Input of logcumsumexp operator"); + AddOutput("Out", "Output of logcumsumexp operator"); + AddAttr("axis", + "The dimension to accumulate along. -1 means the last " + "dimension [default -1].") + .SetDefault(-1); + AddAttr("flatten", + "Whether to compute the logcumsumexp over the flattened array. " + "[default false].") + .SetDefault(false); + AddAttr("exclusive", + "Whether to perform exclusive logcumsumexp. [default false].") + .SetDefault(false); + AddAttr("reverse", + "If true, the logcumsumexp is performed in the reversed direction. " + "[default false].") + .SetDefault(false); + AddComment(R"DOC( +Returns the logarithm of the cumulative summation of the exponentiation of elements of input along the given axis. +By default, the first element of the result is the same of the first element of +the input. 
If exlusive is true, the first element of the result is the minimum value of dtype. +)DOC"); + } +}; + +template +class LogcumsumexpGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr grad_op) const override { + grad_op->SetType("logcumsumexp_grad"); + grad_op->SetInput("X", this->OutputGrad("Out")); + grad_op->SetOutput("Out", this->InputGrad("X")); + grad_op->SetAttr("axis", BOOST_GET_CONST(int, this->GetAttr("axis"))); + grad_op->SetAttr("flatten", + BOOST_GET_CONST(bool, this->GetAttr("flatten"))); + grad_op->SetAttr("reverse", + BOOST_GET_CONST(bool, this->GetAttr("reverse"))); + grad_op->SetAttr("exclusive", + BOOST_GET_CONST(bool, this->GetAttr("exclusive"))); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; using CPU = paddle::platform::CPUDeviceContext; DECLARE_INFER_SHAPE_FUNCTOR(cumsum, CumsumInferShapeFunctor, - PD_INFER_META(phi::CumsumInferMeta)); + PD_INFER_META(phi::CumInferMeta)); REGISTER_OPERATOR(cumsum, ops::CumOp, ops::CumsumOpMaker, ops::CumsumGradMaker, ops::CumsumGradMaker, CumsumInferShapeFunctor); +REGISTER_OPERATOR(logcumsumexp, ops::CumOp, ops::LogcumsumexpOpMaker, + ops::LogcumsumexpGradMaker, + ops::LogcumsumexpGradMaker, + CumsumInferShapeFunctor); +REGISTER_OPERATOR(logcumsumexp_grad, ops::CumGradOp); REGISTER_OP_VERSION(cumsum) .AddCheckpoint( diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 400c56db3efc2..937cbb6e64eff 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -234,12 +234,12 @@ void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out) { out->set_layout(x.layout()); } -void CumsumInferMeta(const MetaTensor& x, - int axis, - bool flatten, - bool exclusive, - bool reverse, - MetaTensor* out) { +void CumInferMeta(const MetaTensor& x, + int axis, + bool flatten, + bool exclusive, + bool reverse, + MetaTensor* out) { auto x_dims = x.dims(); if (flatten) { out->set_dims(phi::make_ddim({phi::product(x_dims)})); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index c67eb2068d8bf..87eb1478442a3 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -60,12 +60,12 @@ void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out); void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out); -void CumsumInferMeta(const MetaTensor& x, - int axis, - bool flatten, - bool exclusive, - bool reverse, - MetaTensor* out); +void CumInferMeta(const MetaTensor& x, + int axis, + bool flatten, + bool exclusive, + bool reverse, + MetaTensor* out); void DiagInferMeta(const MetaTensor& x, int offset, diff --git a/paddle/phi/kernels/cpu/cum_kernel.cc b/paddle/phi/kernels/cpu/cum_kernel.cc new file mode 100644 index 0000000000000..846d148bf1397 --- /dev/null +++ b/paddle/phi/kernels/cpu/cum_kernel.cc @@ -0,0 +1,275 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/cumsum_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" + +namespace phi { + +template +void ComputeImp(Device d, + const Dim& dims, + X x, + Out out, + int axis, + bool reverse, + bool exclusive, + Reducer reducer) { + if (!reverse) { + out.reshape(dims).device(d) = + x.reshape(dims).scan(axis, reducer, exclusive); + } else { + std::array rev; + rev.fill(false); + rev[axis] = reverse; + out.reshape(dims).device(d) = x.reshape(dims) + .reverse(rev) + .scan(axis, reducer, exclusive) + .reverse(rev); + } +} + +template +void ScanKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + bool flatten, + bool exclusive, + bool reverse, + Op op, + DenseTensor* out) { + auto out_dims = out->dims(); + + PADDLE_ENFORCE_EQ( + axis < out_dims.size() && axis >= (0 - out_dims.size()), + true, + phi::errors::OutOfRange( + "Attr(axis) is out of range, It's expected " + "to be in range of [-%d, %d]. But received Attr(axis) = %d.", + out_dims.size(), + out_dims.size() - 1, + axis)); + if (axis < 0) { + axis += out_dims.size(); + } + + dev_ctx.template Alloc(out); + + int pre = 1; + int post = 1; + int mid = out_dims[axis]; + for (int i = 0; i < axis; ++i) { + pre *= out_dims[i]; + } + for (int i = axis + 1; i < out_dims.size(); ++i) { + post *= out_dims[i]; + } + + auto x0 = EigenVector::Flatten(x); + auto out0 = EigenVector::Flatten(*out); + auto& place = *dev_ctx.eigen_device(); + + using IndexT = Eigen::DenseIndex; + if (pre == 1) { + if (post == 1) { + ComputeImp(place, + Eigen::DSizes(mid), + x0, + out0, + /* axis= */ 0, + reverse, + exclusive, + op); + } else { + ComputeImp(place, + Eigen::DSizes(mid, post), + x0, + out0, + /* axis= */ 0, + reverse, + exclusive, + op); + } + } else { + if (post == 1) { + ComputeImp(place, + Eigen::DSizes(pre, mid), + x0, + out0, + /* axis= */ 1, + reverse, + exclusive, + op); + } else { + ComputeImp(place, + Eigen::DSizes(pre, mid, post), + x0, + out0, + /* axis= */ 1, + reverse, + exclusive, + op); + } + } +} + +template +void CumsumKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + bool flatten, + bool exclusive, + bool reverse, + DenseTensor* out) { + using Reducer = Eigen::internal::SumReducer; + auto reducer = Reducer(); + ScanKernel( + dev_ctx, x, axis, flatten, exclusive, reverse, reducer, out); +} + +template +struct LogSumExp { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a, + const T& b) const { + auto mi = Eigen::internal::scalar_min_op()(a, b); + auto ma = Eigen::internal::scalar_max_op()(a, b); + + auto sub = Eigen::internal::scalar_difference_op(); + auto add = Eigen::internal::scalar_sum_op(); + auto exp = Eigen::internal::scalar_exp_op(); + auto log1p = Eigen::internal::scalar_log1p_op(); + auto cmp_lt = + Eigen::internal::scalar_cmp_op(); + + auto logsumexp = add(log1p(exp(sub(mi, ma))), ma); + return cmp_lt(ma, Eigen::NumTraits::lowest()) ? 
ma : logsumexp; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const T& a, + const T& b) const { + auto mi = Eigen::internal::pmin(a, b); + auto ma = Eigen::internal::pmax(a, b); + using Eigen::internal::padd; + using Eigen::internal::pcmp_lt; + using Eigen::internal::pexp; + using Eigen::internal::plog1p; + using Eigen::internal::pset1; + using Eigen::internal::psub; + + auto logsumexp = padd(plog1p(pexp(psub(mi, ma))), ma); + return pselect( + pcmp_lt(ma, pset1(Eigen::NumTraits::lowest())), ma, logsumexp); + } +}; + +template +struct LogSumExpReducer { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const { + LogSumExp logsumexp; + *accum = logsumexp(*accum, t); + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p, + Packet* accum) const { + LogSumExp logsumexp; + *accum = logsumexp.packetOp(*accum, p); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const { + return -Eigen::NumTraits::infinity(); + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const { + return Eigen::internal::pset1(initialize()); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const { + return accum; + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet + finalizePacket(const Packet& vaccum) const { + return vaccum; + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T + finalizeBoth(const T saccum, const Packet& vaccum) const { + auto max_reducer = Eigen::internal::MaxReducer(); + auto sum_reducer = Eigen::internal::SumReducer(); + auto exp = Eigen::internal::scalar_exp_op(); + auto cmp_lt = + Eigen::internal::scalar_cmp_op(); + auto log = Eigen::internal::scalar_log_op(); + auto add = Eigen::internal::scalar_sum_op(); + + using Eigen::internal::pexp; + using Eigen::internal::psub; + + // `ma = max(x1, ..., xn)` + // If the max of all of the `xi` is `-infinity` then the result is + // -infinity. If the max is larger than `-infinity` then it's safe to use + // for normalization even if the other elements are `-infinity`. + // + // `logsumexp(x1, ..., xn) = ma + log (exp(x1 - ma) + ... + exp(xn - ma))` + auto ma = max_reducer.finalizeBoth(saccum, vaccum); + auto logsumexp = add(log(sum_reducer.finalizeBoth( + exp(saccum - ma), pexp(psub(vaccum, pset1(ma))))), + ma); + return cmp_lt(ma, Eigen::NumTraits::lowest()) ? initialize() : logsumexp; + } +}; + +template +void LogcumsumexpKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + bool flatten, + bool exclusive, + bool reverse, + DenseTensor* out) { + using Reducer = LogSumExpReducer; + auto reducer = Reducer(); + ScanKernel( + dev_ctx, x, axis, flatten, exclusive, reverse, reducer, out); +} + +} // namespace phi + +PD_REGISTER_KERNEL(cumsum, + CPU, + ALL_LAYOUT, + phi::CumsumKernel, + float, + double, + int16_t, + int, + int64_t) {} + +PD_REGISTER_KERNEL(logcumsumexp, + CPU, + ALL_LAYOUT, + phi::LogcumsumexpKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cpu/cumsum_kernel.cc b/paddle/phi/kernels/cpu/cumsum_kernel.cc deleted file mode 100644 index d32e18479aae9..0000000000000 --- a/paddle/phi/kernels/cpu/cumsum_kernel.cc +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/kernels/cumsum_kernel.h" - -#include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/funcs/eigen/common.h" -#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" - -namespace phi { - -struct CumsumFunctor { - template - const typename X::TensorScanSumOp operator()(X x, - int axis, - bool exclusive) const { - return x.cumsum(axis, exclusive); - } -}; - -template -void ComputeImp(Device d, - const Dim& dims, - X x, - Out out, - int axis, - bool reverse, - bool exclusive) { - if (!reverse) { - out.reshape(dims).device(d) = - CumsumFunctor()(x.reshape(dims), axis, exclusive); - } else { - std::array rev; - rev.fill(false); - rev[axis] = reverse; - out.reshape(dims).device(d) = - CumsumFunctor()(x.reshape(dims).reverse(rev), axis, exclusive) - .reverse(rev); - } -} - -template -void CumsumKernel(const Context& dev_ctx, - const DenseTensor& x, - int axis, - bool flatten, - bool exclusive, - bool reverse, - DenseTensor* out) { - auto out_dims = out->dims(); - - PADDLE_ENFORCE_EQ( - axis < out_dims.size() && axis >= (0 - out_dims.size()), - true, - phi::errors::OutOfRange( - "Attr(axis) is out of range, It's expected " - "to be in range of [-%d, %d]. But received Attr(axis) = %d.", - out_dims.size(), - out_dims.size() - 1, - axis)); - if (axis < 0) { - axis += out_dims.size(); - } - - dev_ctx.template Alloc(out); - - int pre = 1; - int post = 1; - int mid = out_dims[axis]; - for (int i = 0; i < axis; ++i) { - pre *= out_dims[i]; - } - for (int i = axis + 1; i < out_dims.size(); ++i) { - post *= out_dims[i]; - } - - auto x0 = EigenVector::Flatten(x); - auto out0 = EigenVector::Flatten(*out); - auto& place = *dev_ctx.eigen_device(); - - using IndexT = Eigen::DenseIndex; - if (pre == 1) { - if (post == 1) { - ComputeImp(place, - Eigen::DSizes(mid), - x0, - out0, - /* axis= */ 0, - reverse, - exclusive); - } else { - ComputeImp(place, - Eigen::DSizes(mid, post), - x0, - out0, - /* axis= */ 0, - reverse, - exclusive); - } - } else { - if (post == 1) { - ComputeImp(place, - Eigen::DSizes(pre, mid), - x0, - out0, - /* axis= */ 1, - reverse, - exclusive); - } else { - ComputeImp(place, - Eigen::DSizes(pre, mid, post), - x0, - out0, - /* axis= */ 1, - reverse, - exclusive); - } - } -} - -} // namespace phi - -PD_REGISTER_KERNEL(cumsum, - CPU, - ALL_LAYOUT, - phi::CumsumKernel, - float, - double, - int16_t, - int, - int64_t) {} diff --git a/paddle/phi/kernels/cpu/logcumsumexp_grad_kernel.cc b/paddle/phi/kernels/cpu/logcumsumexp_grad_kernel.cc new file mode 100644 index 0000000000000..7d18bec424030 --- /dev/null +++ b/paddle/phi/kernels/cpu/logcumsumexp_grad_kernel.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/logcumsumexp_grad_impl.h" + +PD_REGISTER_KERNEL(logcumsumexp_grad, + CPU, + ALL_LAYOUT, + phi::LogcumsumexpGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cumsum_kernel.h b/paddle/phi/kernels/cumsum_kernel.h index f105c94d559d8..38cdbd7787baf 100644 --- a/paddle/phi/kernels/cumsum_kernel.h +++ b/paddle/phi/kernels/cumsum_kernel.h @@ -27,4 +27,13 @@ void CumsumKernel(const Context& dev_ctx, bool reverse, DenseTensor* out); +template +void LogcumsumexpKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + bool flatten, + bool exclusive, + bool reverse, + DenseTensor* out); + } // namespace phi diff --git a/paddle/phi/kernels/gpu/cumsum_kernel.cu b/paddle/phi/kernels/gpu/cum_kernel.cu similarity index 74% rename from paddle/phi/kernels/gpu/cumsum_kernel.cu rename to paddle/phi/kernels/gpu/cum_kernel.cu index e04f2b5f87658..db222cdaf0ffd 100644 --- a/paddle/phi/kernels/gpu/cumsum_kernel.cu +++ b/paddle/phi/kernels/gpu/cum_kernel.cu @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/kernels/cumsum_kernel.h" - #include #include #include #include + +#include "paddle/phi/kernels/cumsum_kernel.h" #ifdef __NVCC__ #include #endif @@ -84,19 +84,20 @@ __global__ void MatrixRowReverse(const T* matrix_data, } } -template +template struct BlockPrefixCallbackOp { // Running prefix - T running_total; - // Constructor - __device__ BlockPrefixCallbackOp(T running_total) - : running_total(running_total) {} + T running_total_; + Op op_; + + __device__ BlockPrefixCallbackOp(T running_total, Op op) + : running_total_(running_total), op_(op) {} + // Callback operator to be entered by the first warp of threads in the block. - // Thread-0 is responsible for returning a value for seeding the block-wide - // scan. + // tid 0 is responsible for returning a value for seeding the block-wide scan. 
__device__ T operator()(T block_aggregate) { - T old_prefix = running_total; - running_total = old_prefix + block_aggregate; + T old_prefix = running_total_; + running_total_ = op_(old_prefix, block_aggregate); return old_prefix; } }; @@ -131,13 +132,49 @@ __global__ void MatrixTranspose(T* odata, } } -template +struct LogAddExp { + template + __host__ __device__ __forceinline__ T operator()(const T& a, + const T& b) const { + return std::log(1 + std::exp(std::min(a, b) - std::max(a, b))) + + std::max(a, b); + } +}; + +struct Prod { + template + __host__ __device__ __forceinline__ T operator()(const T& a, + const T& b) const { + return a * b; + } +}; + +template +struct Identity; + +template +struct Identity { + static constexpr T value = 0; +}; + +template +struct Identity { + static constexpr T value = 1; +}; + +template +struct Identity { + static constexpr T value = std::numeric_limits::lowest(); +}; + +template __global__ void BlockScanKernel(T* d_out, const T* d_in, int inner_size, int outer_size, int scan_size, - bool exclusive) { + bool exclusive, + Op op) { // Specialize BlockLoad, BlockStore, and BlockRadixSort collective types typedef cub:: BlockLoad @@ -156,7 +193,7 @@ __global__ void BlockScanKernel(T* d_out, int bx = blockIdx.x; int by = blockIdx.y; - BlockPrefixCallbackOp prefix_op(0); + BlockPrefixCallbackOp prefix_op(Identity::value, op); T block_aggregate = static_cast(0); // Obtain this block's segment of consecutive keys (blocked across threads) @@ -178,12 +215,11 @@ __global__ void BlockScanKernel(T* d_out, __syncthreads(); if (exclusive) { - T init_value = static_cast(0); BlockScanT(temp_storage.scan) - .ExclusiveScan(thread_keys, thread_keys, cub::Sum(), prefix_op); + .ExclusiveScan(thread_keys, thread_keys, op, prefix_op); } else { BlockScanT(temp_storage.scan) - .InclusiveScan(thread_keys, thread_keys, cub::Sum(), prefix_op); + .InclusiveScan(thread_keys, thread_keys, op, prefix_op); } __syncthreads(); @@ -192,14 +228,15 @@ __global__ void BlockScanKernel(T* d_out, } } -template -void CumsumKernel(const Context& dev_ctx, - const DenseTensor& x, - int axis, - bool flatten, - bool exclusive, - bool reverse, - DenseTensor* out) { +template +void ScanKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + bool flatten, + bool exclusive, + bool reverse, + Op op, + DenseTensor* out) { auto out_dims = out->dims(); auto size = x.numel(); @@ -219,36 +256,6 @@ void CumsumKernel(const Context& dev_ctx, T* out_data = dev_ctx.template Alloc(out); const T* in_data = x.data(); - // Use thrust for parallel acceleration when the input size is equal to the - // length of the ‘axis’ dimension. 
- if (size == out_dims[axis]) { -#ifdef __HIPCC__ - const auto& policy = thrust::hip::par.on(dev_ctx.stream()); -#else - const auto& policy = thrust::cuda::par.on(dev_ctx.stream()); -#endif - if (reverse) { - thrust::reverse_iterator> reversed_in( - thrust::device_pointer_cast(in_data) + size); - thrust::reverse_iterator> reversed_out( - thrust::device_pointer_cast(out_data) + size); - if (exclusive) { - thrust::exclusive_scan( - policy, reversed_in, reversed_in + size, reversed_out); - } else { - thrust::inclusive_scan( - policy, reversed_in, reversed_in + size, reversed_out); - } - } else { - if (exclusive) { - thrust::exclusive_scan(policy, in_data, in_data + size, out_data); - } else { - thrust::inclusive_scan(policy, in_data, in_data + size, out_data); - } - } - return; - } - size_t height = 1; size_t width = 1; for (size_t i = 0; i <= axis; i++) { @@ -300,17 +307,18 @@ void CumsumKernel(const Context& dev_ctx, } } if (!transpose && !reverse) { - BlockScanKernel<<>>( - out_data, in_data, outer_size, inner_size, scan_size, exclusive); + BlockScanKernel<<>>( + out_data, in_data, outer_size, inner_size, scan_size, exclusive, op); } else { - BlockScanKernel<<>>( - next_out_data, - next_in_data, - outer_size, - inner_size, - scan_size, - exclusive); + BlockScanKernel + <<>>(next_out_data, + next_in_data, + outer_size, + inner_size, + scan_size, + exclusive, + op); } swap_ptr(next_in_data, next_out_data); if (reverse) { @@ -326,6 +334,34 @@ void CumsumKernel(const Context& dev_ctx, } } +template +void CumsumKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + bool flatten, + bool exclusive, + bool reverse, + DenseTensor* out) { + using Op = cub::Sum; + auto op = Op(); + ScanKernel( + dev_ctx, x, axis, flatten, exclusive, reverse, op, out); +} + +template +void LogcumsumexpKernel(const Context& dev_ctx, + const DenseTensor& x, + int axis, + bool flatten, + bool exclusive, + bool reverse, + DenseTensor* out) { + using Op = LogAddExp; + auto op = Op(); + ScanKernel( + dev_ctx, x, axis, flatten, exclusive, reverse, op, out); +} + } // namespace phi PD_REGISTER_KERNEL(cumsum, @@ -337,3 +373,10 @@ PD_REGISTER_KERNEL(cumsum, int16_t, int, int64_t) {} + +PD_REGISTER_KERNEL(logcumsumexp, + GPU, + ALL_LAYOUT, + phi::LogcumsumexpKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/logcumsumexp_grad_kernel.cu b/paddle/phi/kernels/gpu/logcumsumexp_grad_kernel.cu new file mode 100644 index 0000000000000..cee562696cb40 --- /dev/null +++ b/paddle/phi/kernels/gpu/logcumsumexp_grad_kernel.cu @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/logcumsumexp_grad_impl.h" + +PD_REGISTER_KERNEL(logcumsumexp_grad, + GPU, + ALL_LAYOUT, + phi::LogcumsumexpGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h b/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h new file mode 100644 index 0000000000000..f62b027a2aa9d --- /dev/null +++ b/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h @@ -0,0 +1,86 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cumsum_kernel.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +namespace phi { + +template +struct LogGradPositiveFunctor { + HOSTDEVICE T operator()(const T& x) const { + const T kMin = std::numeric_limits::lowest(); + return x > 0 ? std::log(x) : kMin; + } +}; + +template +struct LogGradNegativeFunctor { + HOSTDEVICE T operator()(const T& x) const { + const T kMin = std::numeric_limits::lowest(); + return x < 0 ? std::log(-x) : kMin; + } +}; + +template +void LogcumsumexpGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& d_out, + int axis, + bool flatten, + bool exclusive, + bool reverse, + DenseTensor* d_x) { + reverse = !reverse; + dev_ctx.template Alloc(d_x); + + auto eigen_x = EigenMatrix::From(x); + auto eigen_out = EigenMatrix::From(out); + auto eigen_d_out = EigenMatrix::From(d_out); + auto& place = *dev_ctx.eigen_device(); + + DenseTensor output_pos; + output_pos.Resize(d_out.dims()); + dev_ctx.template Alloc(&output_pos); + auto eigen_output_pos = EigenMatrix::From(output_pos); + DenseTensor output_neg; + output_neg.Resize(d_out.dims()); + dev_ctx.template Alloc(&output_neg); + auto eigen_output_neg = EigenMatrix::From(output_neg); + DenseTensor tmp; + tmp.Resize(d_out.dims()); + dev_ctx.template Alloc(&tmp); + auto eigen_tmp = EigenMatrix::From(tmp); + + eigen_tmp.device(place) = + eigen_d_out.unaryExpr(LogGradPositiveFunctor()) - eigen_out; + LogcumsumexpKernel( + dev_ctx, tmp, axis, flatten, exclusive, reverse, &output_pos); + eigen_output_pos.device(place) = (eigen_output_pos + eigen_x).exp(); + + eigen_tmp.device(place) = + eigen_d_out.unaryExpr(LogGradNegativeFunctor()) - eigen_out; + LogcumsumexpKernel( + dev_ctx, tmp, axis, flatten, exclusive, reverse, &output_neg); + eigen_output_neg.device(place) = (eigen_output_neg + eigen_x).exp(); + + auto eigen_d_x = EigenMatrix::From(*d_x); + eigen_d_x.device(place) = eigen_output_pos - eigen_output_neg; +} +} // namespace phi diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index cb0135d9b4c29..8ef007a1a1bef 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -191,6 +191,7 @@ from .tensor.math import cosh # noqa: F401 from .tensor.math 
import cumsum # noqa: F401 from .tensor.math import cumprod # noqa: F401 +from .tensor.math import logcumsumexp # noqa: F401 from .tensor.math import logit # noqa: F401 from .tensor.math import exp # noqa: F401 from .tensor.math import expm1 # noqa: F401 @@ -403,6 +404,7 @@ 'eye', 'cumsum', 'cumprod', + 'logcumsumexp', 'logit', 'sign', 'is_empty', diff --git a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py new file mode 100644 index 0000000000000..00b0b961aa30e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py @@ -0,0 +1,154 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +from typing import Optional +import unittest +import itertools +import numpy as np +import paddle +import paddle.fluid.core as core +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard +from paddle.fluid.framework import _test_eager_guard + + +def np_logcumsumexp(x: np.ndarray, axis: Optional[int]=None, reverse: bool=False, exclusive: bool=False): + x = np.copy(x) + + if axis is None: + x = x.flatten() + axis = 0 + + if reverse: + x = np.flip(x, axis) + + dimensions = [range(dim) for dim in x.shape[:axis]] + + if exclusive: + x = np.roll(x, 1, axis) + for prefix_dim in itertools.product(*dimensions): + x[prefix_dim][0] = np.finfo(x.dtype).min + + for prefix_dim in itertools.product(*dimensions): + arr = x[prefix_dim] + for dim in range(1, arr.shape[0]): + arr[dim] = np.logaddexp(arr[dim - 1], arr[dim]) + + if reverse: + x = np.flip(x, axis) + + return x + + +class TestLogcumsumexpOp(unittest.TestCase): + def run_cases(self): + data_np = np.arange(12, dtype=np.float32).reshape(3, 4) + data = paddle.to_tensor(data_np) + + y = paddle.logcumsumexp(data) + z = np_logcumsumexp(data_np) + self.assertTrue(np.allclose(z, y.numpy())) + + y = paddle.logcumsumexp(data, axis=0) + z = np_logcumsumexp(data_np, axis=0) + self.assertTrue(np.allclose(z, y.numpy())) + + y = paddle.logcumsumexp(data, axis=-1) + z = np_logcumsumexp(data_np, axis=-1) + self.assertTrue(np.allclose(z, y.numpy())) + + y = paddle.logcumsumexp(data, dtype='float64') + self.assertTrue(y.dtype == core.VarDesc.VarType.FP64) + + y = paddle.logcumsumexp(data, axis=-2) + z = np_logcumsumexp(data_np, axis=-2) + self.assertTrue(np.allclose(z, y.numpy())) + + with self.assertRaises(IndexError): + y = paddle.logcumsumexp(data, axis=-3) + + with self.assertRaises(IndexError): + y = paddle.logcumsumexp(data, axis=2) + + def run_static(self, use_gpu=False): + with fluid.program_guard(fluid.Program()): + data_np = np.random.random((100, 100)).astype(np.float32) + x = paddle.static.data('X', [100, 100]) + y = paddle.logcumsumexp(x) + y2 = paddle.logcumsumexp(x, axis=0) + y3 = paddle.logcumsumexp(x, axis=-1) + y4 = paddle.logcumsumexp(x, dtype='float64') + y5 = paddle.logcumsumexp(x, axis=-2) + + place = fluid.CUDAPlace(0) if use_gpu 
else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + out = exe.run(feed={'X': data_np}, + fetch_list=[ + y.name, y2.name, y3.name, y4.name, y5.name, + ]) + + z = np_logcumsumexp(data_np) + self.assertTrue(np.allclose(z, out[0])) + z = np_logcumsumexp(data_np, axis=0) + self.assertTrue(np.allclose(z, out[1])) + z = np_logcumsumexp(data_np, axis=-1) + self.assertTrue(np.allclose(z, out[2])) + self.assertTrue(out[3].dtype == np.float64) + z = np_logcumsumexp(data_np, axis=-2) + self.assertTrue(np.allclose(z, out[4])) + + def test_cpu(self): + paddle.disable_static(paddle.fluid.CPUPlace()) + self.run_cases() + paddle.enable_static() + + self.run_static() + + def test_gpu(self): + if not fluid.core.is_compiled_with_cuda(): + return + paddle.disable_static(paddle.fluid.CUDAPlace(0)) + self.run_cases() + paddle.enable_static() + + self.run_static(use_gpu=True) + + def test_name(self): + with fluid.program_guard(fluid.Program()): + x = paddle.static.data('x', [3, 4]) + y = paddle.logcumsumexp(x, name='out') + self.assertTrue('out' in y.name) + + def test_type_error(self): + with fluid.program_guard(fluid.Program()): + + with self.assertRaises(TypeError): + data_np = np.random.random((100, 100), dtype=np.int32) + x = paddle.static.data('X', [100, 100], dtype='int32') + y = paddle.logcumsumexp(x) + + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + out = exe.run(feed={'X': data_np}, + fetch_list=[ + y.name + ]) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 5f0fb4336e014..05a9ee6bd4e39 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -138,6 +138,7 @@ from .math import cosh # noqa: F401 from .math import cumsum # noqa: F401 from .math import cumprod # noqa: F401 +from .math import logcumsumexp # noqa: F401 from .math import logit # noqa: F401 from .math import exp # noqa: F401 from .math import exp_ # noqa: F401 @@ -306,6 +307,7 @@ 'cosh', 'cumsum', 'cumprod', + 'logcumsumexp', 'logit', 'exp', 'exp_', diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 59206eca81d4f..ea88704a4864d 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -2969,6 +2969,34 @@ def cumsum(x, axis=None, dtype=None, name=None): _cum_sum_ = generate_layer_fn('cumsum') return _cum_sum_(**kwargs) + +def logcumsumexp(x, axis=None, dtype=None, name=None): + if axis is None: + flatten = True + else: + flatten = False + if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype): + x = cast(x, dtype) + + if in_dygraph_mode(): + if axis is None: axis = -1 + return _C_ops.final_state_logcumsumexp(x, axis, flatten, False, False) + if _in_legacy_dygraph(): + if axis is None: + return _C_ops.logcumsumexp(x, 'flatten', flatten) + else: + return _C_ops.logcumsumexp(x, 'axis', axis, 'flatten', flatten) + + check_variable_and_dtype(x, 'x', ['float32', 'float64'], "logcumsumexp") + locals_var = locals().copy() + kwargs = dict() + for name, val in locals_var.items(): + if val is not None: + kwargs[name] = val + _logcumsumexp_ = generate_layer_fn('logcumsumexp') + return _logcumsumexp_(**kwargs) + + def cumprod(x, dim=None, dtype=None, name=None): """ Compute the cumulative product of the input tensor x along a given dimension dim. 
diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index d401e7c5190fe..b07d2262e9703 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -454,11 +454,20 @@ args : (Tensor x, int axis, bool flatten, bool exclusive, bool reverse) output : Tensor(out) infer_meta : - func : CumsumInferMeta + func : CumInferMeta kernel : func : cumsum backward : cumsum_grad +- api : logcumsumexp + args : (Tensor x, int axis, bool flatten, bool exclusive, bool reverse) + output : Tensor(out) + infer_meta : + func : CumInferMeta + kernel : + func : logcumsumexp + backward : logcumsumexp_grad + - api : deformable_conv args : (Tensor x, Tensor offset, Tensor filter, Tensor mask, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step) output : Tensor(out) diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index 3b47470139b90..41cc33cae7c08 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -354,6 +354,16 @@ output : Tensor(x_grad) invoke : cumsum(out_grad, axis, flatten, exclusive, !reverse) +- backward_api : logcumsumexp_grad + forward : logcumsumexp(Tensor x, int axis, bool flatten, bool exclusive, bool reverse) -> Tensor(out) + infer_meta : + func : UnchangedInferMeta + param : [x] + args : (Tensor x, Tensor out, Tensor out_grad, int axis, bool flatten, bool exclusive, bool reverse) + output : Tensor(x_grad) + kernel : + func : logcumsumexp_grad + - backward_api : deformable_conv_grad forward : deformable_conv(Tensor x, Tensor offset, Tensor filter, Tensor mask, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step) -> Tensor(out) args : (Tensor x, Tensor offset, Tensor filter, Tensor mask, Tensor out_grad, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step) From 4054f7c1118a7f8185c76a852afab5d73c39d3a6 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Tue, 26 Apr 2022 14:46:05 +0800 Subject: [PATCH 02/21] polish --- paddle/phi/kernels/cpu/cum_kernel.cc | 13 +++++++------ paddle/phi/kernels/gpu/cum_kernel.cu | 13 ------------- paddle/phi/kernels/impl/logcumsumexp_grad_impl.h | 1 + 3 files changed, 8 insertions(+), 19 deletions(-) diff --git a/paddle/phi/kernels/cpu/cum_kernel.cc b/paddle/phi/kernels/cpu/cum_kernel.cc index 846d148bf1397..e44bc7b783190 100644 --- a/paddle/phi/kernels/cpu/cum_kernel.cc +++ b/paddle/phi/kernels/cpu/cum_kernel.cc @@ -48,14 +48,14 @@ void ComputeImp(Device d, } } -template +template void ScanKernel(const Context& dev_ctx, const DenseTensor& x, int axis, bool flatten, bool exclusive, bool reverse, - Op op, + Reducer reducer, DenseTensor* out) { auto out_dims = out->dims(); @@ -98,7 +98,7 @@ void ScanKernel(const Context& dev_ctx, /* axis= */ 0, reverse, exclusive, - op); + reducer); } else { ComputeImp(place, Eigen::DSizes(mid, post), @@ -107,7 +107,7 @@ void ScanKernel(const Context& dev_ctx, /* axis= */ 0, reverse, exclusive, - op); + reducer); } } else { if (post == 1) { @@ -118,7 +118,7 @@ void ScanKernel(const Context& dev_ctx, /* axis= */ 1, reverse, exclusive, - op); + reducer); } else { ComputeImp(place, Eigen::DSizes(pre, mid, post), @@ -127,7 +127,7 @@ void ScanKernel(const Context& dev_ctx, /* axis= */ 1, reverse, exclusive, - op); + reducer); } } } @@ -146,6 +146,7 @@ void CumsumKernel(const Context& dev_ctx, dev_ctx, x, 
axis, flatten, exclusive, reverse, reducer, out); } +// Copied from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/scan_ops.h template struct LogSumExp { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a, diff --git a/paddle/phi/kernels/gpu/cum_kernel.cu b/paddle/phi/kernels/gpu/cum_kernel.cu index db222cdaf0ffd..3846f832b6fa1 100644 --- a/paddle/phi/kernels/gpu/cum_kernel.cu +++ b/paddle/phi/kernels/gpu/cum_kernel.cu @@ -141,14 +141,6 @@ struct LogAddExp { } }; -struct Prod { - template - __host__ __device__ __forceinline__ T operator()(const T& a, - const T& b) const { - return a * b; - } -}; - template struct Identity; @@ -157,11 +149,6 @@ struct Identity { static constexpr T value = 0; }; -template -struct Identity { - static constexpr T value = 1; -}; - template struct Identity { static constexpr T value = std::numeric_limits::lowest(); diff --git a/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h b/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h index f62b027a2aa9d..7e3cd7295e5f9 100644 --- a/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h +++ b/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h @@ -47,6 +47,7 @@ void LogcumsumexpGradKernel(const Context& dev_ctx, bool exclusive, bool reverse, DenseTensor* d_x) { + // Reference: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/math_grad.py reverse = !reverse; dev_ctx.template Alloc(d_x); From 1f98cc7ad7a9a923cd973b923d14cb758d9fa352 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Tue, 3 May 2022 11:32:42 +0800 Subject: [PATCH 03/21] fix ci --- paddle/fluid/operators/{cumsum_op.cc => cum_op.cc} | 0 paddle/phi/kernels/cpu/cum_kernel.cc | 2 +- paddle/phi/kernels/{cumsum_kernel.h => cum_kernel.h} | 0 paddle/phi/kernels/gpu/cum_kernel.cu | 2 +- paddle/phi/kernels/impl/logcumsumexp_grad_impl.h | 2 +- 5 files changed, 3 insertions(+), 3 deletions(-) rename paddle/fluid/operators/{cumsum_op.cc => cum_op.cc} (100%) rename paddle/phi/kernels/{cumsum_kernel.h => cum_kernel.h} (100%) diff --git a/paddle/fluid/operators/cumsum_op.cc b/paddle/fluid/operators/cum_op.cc similarity index 100% rename from paddle/fluid/operators/cumsum_op.cc rename to paddle/fluid/operators/cum_op.cc diff --git a/paddle/phi/kernels/cpu/cum_kernel.cc b/paddle/phi/kernels/cpu/cum_kernel.cc index e44bc7b783190..3519acd7c8053 100644 --- a/paddle/phi/kernels/cpu/cum_kernel.cc +++ b/paddle/phi/kernels/cpu/cum_kernel.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/phi/kernels/cumsum_kernel.h" +#include "paddle/phi/kernels/cum_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" diff --git a/paddle/phi/kernels/cumsum_kernel.h b/paddle/phi/kernels/cum_kernel.h similarity index 100% rename from paddle/phi/kernels/cumsum_kernel.h rename to paddle/phi/kernels/cum_kernel.h diff --git a/paddle/phi/kernels/gpu/cum_kernel.cu b/paddle/phi/kernels/gpu/cum_kernel.cu index 016e3dcff10db..7ec60fa11cfb8 100644 --- a/paddle/phi/kernels/gpu/cum_kernel.cu +++ b/paddle/phi/kernels/gpu/cum_kernel.cu @@ -17,7 +17,7 @@ #include #include -#include "paddle/phi/kernels/cumsum_kernel.h" +#include "paddle/phi/kernels/cum_kernel.h" #ifdef __NVCC__ #include #endif diff --git a/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h b/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h index 7e3cd7295e5f9..3a5b5d2d15c80 100644 --- a/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h +++ b/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h @@ -16,7 +16,7 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/cumsum_kernel.h" +#include "paddle/phi/kernels/cum_kernel.h" #include "paddle/phi/kernels/funcs/eigen/common.h" namespace phi { From 518c75a2f24d7f6266d792096da2a0aac1bbba30 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Tue, 3 May 2022 15:51:00 +0800 Subject: [PATCH 04/21] reformat --- paddle/phi/kernels/cpu/cum_kernel.cc | 11 ++++------ .../phi/kernels/impl/logcumsumexp_grad_impl.h | 3 ++- paddle/utils/variant.h | 9 +++++++++ .../tests/unittests/test_logcumsumexp_op.py | 18 ++++++++++------- python/paddle/utils/code_gen/api.yaml | 18 ++++++++--------- python/paddle/utils/code_gen/backward.yaml | 20 +++++++++---------- 6 files changed, 45 insertions(+), 34 deletions(-) diff --git a/paddle/phi/kernels/cpu/cum_kernel.cc b/paddle/phi/kernels/cpu/cum_kernel.cc index 3519acd7c8053..3cb406f8b2907 100644 --- a/paddle/phi/kernels/cpu/cum_kernel.cc +++ b/paddle/phi/kernels/cpu/cum_kernel.cc @@ -146,7 +146,8 @@ void CumsumKernel(const Context& dev_ctx, dev_ctx, x, axis, flatten, exclusive, reverse, reducer, out); } -// Copied from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/scan_ops.h +// Copied from +// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/scan_ops.h template struct LogSumExp { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a, @@ -268,9 +269,5 @@ PD_REGISTER_KERNEL(cumsum, int, int64_t) {} -PD_REGISTER_KERNEL(logcumsumexp, - CPU, - ALL_LAYOUT, - phi::LogcumsumexpKernel, - float, - double) {} +PD_REGISTER_KERNEL( + logcumsumexp, CPU, ALL_LAYOUT, phi::LogcumsumexpKernel, float, double) {} diff --git a/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h b/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h index 3a5b5d2d15c80..9f82aca739ea3 100644 --- a/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h +++ b/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h @@ -47,7 +47,8 @@ void LogcumsumexpGradKernel(const Context& dev_ctx, bool exclusive, bool reverse, DenseTensor* d_x) { - // Reference: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/math_grad.py + // Reference: + // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/math_grad.py reverse = !reverse; dev_ctx.template Alloc(d_x); diff --git a/paddle/utils/variant.h b/paddle/utils/variant.h index a7546d094c2ff..7b11ae1bee88c 100644 --- 
a/paddle/utils/variant.h +++ b/paddle/utils/variant.h @@ -13,6 +13,11 @@ #pragma once +#if defined(__GNUC__) && !defined(__clang__) && __GNUC__ >= 9 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-copy" +#endif + /* variant synopsis @@ -2828,3 +2833,7 @@ struct hash { }; } // namespace std + +#if defined(__GNUC__) && !defined(__clang__) && __GNUC__ >= 9 +#pragma GCC diagnostic pop +#endif diff --git a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py index 00b0b961aa30e..21038b3e52730 100644 --- a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py +++ b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py @@ -25,9 +25,12 @@ from paddle.fluid.framework import _test_eager_guard -def np_logcumsumexp(x: np.ndarray, axis: Optional[int]=None, reverse: bool=False, exclusive: bool=False): +def np_logcumsumexp(x: np.ndarray, + axis: Optional[int]=None, + reverse: bool=False, + exclusive: bool=False): x = np.copy(x) - + if axis is None: x = x.flatten() axis = 0 @@ -98,7 +101,11 @@ def run_static(self, use_gpu=False): exe.run(fluid.default_startup_program()) out = exe.run(feed={'X': data_np}, fetch_list=[ - y.name, y2.name, y3.name, y4.name, y5.name, + y.name, + y2.name, + y3.name, + y4.name, + y5.name, ]) z = np_logcumsumexp(data_np) @@ -144,10 +151,7 @@ def test_type_error(self): place = fluid.CUDAPlace(0) exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) - out = exe.run(feed={'X': data_np}, - fetch_list=[ - y.name - ]) + out = exe.run(feed={'X': data_np}, fetch_list=[y.name]) if __name__ == '__main__': diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 50c81a9ea6a19..103f8e942967e 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -459,15 +459,6 @@ func : cumsum backward : cumsum_grad -- api : logcumsumexp - args : (Tensor x, int axis, bool flatten, bool exclusive, bool reverse) - output : Tensor(out) - infer_meta : - func : CumInferMeta - kernel : - func : logcumsumexp - backward : logcumsumexp_grad - - api : deformable_conv args : (Tensor x, Tensor offset, Tensor filter, Tensor mask, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step) output : Tensor(out) @@ -1199,6 +1190,15 @@ func : log_softmax backward : log_softmax_grad +- api : logcumsumexp + args : (Tensor x, int axis, bool flatten, bool exclusive, bool reverse) + output : Tensor(out) + infer_meta : + func : CumInferMeta + kernel : + func : logcumsumexp + backward : logcumsumexp_grad + # logical_and - api : logical_and args : (Tensor x, Tensor y) diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index 1435f90166b8a..941f2a2faface 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -367,16 +367,6 @@ output : Tensor(x_grad) invoke : cumsum(out_grad, axis, flatten, exclusive, !reverse) -- backward_api : logcumsumexp_grad - forward : logcumsumexp(Tensor x, int axis, bool flatten, bool exclusive, bool reverse) -> Tensor(out) - infer_meta : - func : UnchangedInferMeta - param : [x] - args : (Tensor x, Tensor out, Tensor out_grad, int axis, bool flatten, bool exclusive, bool reverse) - output : Tensor(x_grad) - kernel : - func : logcumsumexp_grad - - backward_api : deformable_conv_grad forward : deformable_conv(Tensor x, Tensor offset, Tensor filter, Tensor mask, int[] 
strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step) -> Tensor(out) args : (Tensor x, Tensor offset, Tensor filter, Tensor mask, Tensor out_grad, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step) @@ -944,6 +934,16 @@ kernel : func : log_softmax_grad +- backward_api : logcumsumexp_grad + forward : logcumsumexp(Tensor x, int axis, bool flatten, bool exclusive, bool reverse) -> Tensor(out) + infer_meta : + func : UnchangedInferMeta + param : [x] + args : (Tensor x, Tensor out, Tensor out_grad, int axis, bool flatten, bool exclusive, bool reverse) + output : Tensor(x_grad) + kernel : + func : logcumsumexp_grad + - backward_api : logit_grad forward : logit (Tensor x, float eps = 1e-6f) -> Tensor(out) args : (Tensor x, Tensor out_grad, float eps) From e94f42c001a0450a5fca2c3cc476e8d5ad0be048 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Tue, 3 May 2022 21:48:45 +0800 Subject: [PATCH 05/21] update --- paddle/fluid/operators/cum_op.cc | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/operators/cum_op.cc b/paddle/fluid/operators/cum_op.cc index 592adf8971c67..7043d47a26b1e 100644 --- a/paddle/fluid/operators/cum_op.cc +++ b/paddle/fluid/operators/cum_op.cc @@ -26,19 +26,6 @@ class CumOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; }; -class CumGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "logsumexp"); - OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "logsumexp"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", - "Out@GRAD", "logsumexp"); - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); - } -}; - class CumsumOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -115,6 +102,19 @@ the input. 
If exlusive is true, the first element of the result is the minimum v } }; +class LogcumsumexpGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "logcumsumexp"); + OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "logcumsumexp"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@GRAD", "logcumsumexp"); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } +}; + template class LogcumsumexpGradMaker : public framework::SingleGradOpMaker { public: @@ -142,6 +142,8 @@ namespace ops = paddle::operators; using CPU = paddle::platform::CPUDeviceContext; DECLARE_INFER_SHAPE_FUNCTOR(cumsum, CumsumInferShapeFunctor, PD_INFER_META(phi::CumInferMeta)); +DECLARE_INFER_SHAPE_FUNCTOR(logcumsumexp, LogcumsumexpInferShapeFunctor, + PD_INFER_META(phi::CumInferMeta)); REGISTER_OPERATOR(cumsum, ops::CumOp, ops::CumsumOpMaker, ops::CumsumGradMaker, ops::CumsumGradMaker, @@ -149,8 +151,8 @@ REGISTER_OPERATOR(cumsum, ops::CumOp, ops::CumsumOpMaker, REGISTER_OPERATOR(logcumsumexp, ops::CumOp, ops::LogcumsumexpOpMaker, ops::LogcumsumexpGradMaker, ops::LogcumsumexpGradMaker, - CumsumInferShapeFunctor); -REGISTER_OPERATOR(logcumsumexp_grad, ops::CumGradOp); + LogcumsumexpInferShapeFunctor); +REGISTER_OPERATOR(logcumsumexp_grad, ops::LogcumsumexpGradOp); REGISTER_OP_VERSION(cumsum) .AddCheckpoint( From 8c680e697e0d4ec93511edeb226ec1b8f08efe23 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Mon, 9 May 2022 11:16:06 +0800 Subject: [PATCH 06/21] address reviews --- paddle/fluid/operators/cum_op.cc | 4 +- paddle/phi/kernels/cpu/cum_kernel.cc | 2 - paddle/phi/kernels/gpu/cum_kernel.cu | 31 ++++++++++ paddle/utils/variant.h | 9 --- .../tests/unittests/test_logcumsumexp_op.py | 57 +++++++++++++++++-- python/paddle/tensor/math.py | 41 +++++++++++++ 6 files changed, 127 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/operators/cum_op.cc b/paddle/fluid/operators/cum_op.cc index 7043d47a26b1e..c4e906c25d837 100644 --- a/paddle/fluid/operators/cum_op.cc +++ b/paddle/fluid/operators/cum_op.cc @@ -49,7 +49,7 @@ class CumsumOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( The cumulative sum of the elements along a given axis. By default, the first element of the result is the same of the first element of -the input. If exlusive is true, the first element of the result is 0. +the input. If exclusive is true, the first element of the result is 0. )DOC"); } }; @@ -97,7 +97,7 @@ class LogcumsumexpOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Returns the logarithm of the cumulative summation of the exponentiation of elements of input along the given axis. By default, the first element of the result is the same of the first element of -the input. If exlusive is true, the first element of the result is the minimum value of dtype. +the input. If exclusive is true, the first element of the result is the the lowest finite value of the dtype of output tensor. 
)DOC"); } }; diff --git a/paddle/phi/kernels/cpu/cum_kernel.cc b/paddle/phi/kernels/cpu/cum_kernel.cc index 3cb406f8b2907..85a6ea5d8be1b 100644 --- a/paddle/phi/kernels/cpu/cum_kernel.cc +++ b/paddle/phi/kernels/cpu/cum_kernel.cc @@ -146,8 +146,6 @@ void CumsumKernel(const Context& dev_ctx, dev_ctx, x, axis, flatten, exclusive, reverse, reducer, out); } -// Copied from -// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/scan_ops.h template struct LogSumExp { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a, diff --git a/paddle/phi/kernels/gpu/cum_kernel.cu b/paddle/phi/kernels/gpu/cum_kernel.cu index 7ec60fa11cfb8..59cd4eb7abc59 100644 --- a/paddle/phi/kernels/gpu/cum_kernel.cu +++ b/paddle/phi/kernels/gpu/cum_kernel.cu @@ -241,6 +241,37 @@ void ScanKernel(const Context& dev_ctx, T* out_data = dev_ctx.template Alloc(out); const T* in_data = x.data(); + // Use thrust for parallel acceleration when the input size is equal to the + // length of the ‘axis’ dimension. + if (std::is_same::value && size == out_dims[axis]) { +#ifdef __HIPCC__ + const auto& policy = thrust::hip::par.on(dev_ctx.stream()); +#else + const auto& policy = thrust::cuda::par.on(dev_ctx.stream()); +#endif + if (reverse) { + thrust::reverse_iterator> reversed_in( + thrust::device_pointer_cast(in_data) + size); + thrust::reverse_iterator> reversed_out( + thrust::device_pointer_cast(out_data) + size); + if (exclusive) { + thrust::exclusive_scan( + policy, reversed_in, reversed_in + size, reversed_out); + } else { + thrust::inclusive_scan( + policy, reversed_in, reversed_in + size, reversed_out); + } + } else { + if (exclusive) { + thrust::exclusive_scan(policy, in_data, in_data + size, out_data); + } else { + thrust::inclusive_scan(policy, in_data, in_data + size, out_data); + } + } + return; + } + + size_t height = 1; size_t width = 1; for (size_t i = 0; i <= axis; i++) { diff --git a/paddle/utils/variant.h b/paddle/utils/variant.h index 7b11ae1bee88c..a7546d094c2ff 100644 --- a/paddle/utils/variant.h +++ b/paddle/utils/variant.h @@ -13,11 +13,6 @@ #pragma once -#if defined(__GNUC__) && !defined(__clang__) && __GNUC__ >= 9 -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wdeprecated-copy" -#endif - /* variant synopsis @@ -2833,7 +2828,3 @@ struct hash { }; } // namespace std - -#if defined(__GNUC__) && !defined(__clang__) && __GNUC__ >= 9 -#pragma GCC diagnostic pop -#endif diff --git a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py index 21038b3e52730..28313674007dd 100644 --- a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py +++ b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py @@ -23,6 +23,11 @@ import paddle.fluid as fluid from paddle.fluid import compiler, Program, program_guard from paddle.fluid.framework import _test_eager_guard +from op_test import OpTest + + +def np_naive_logcumsumexp(x: np.ndarray, axis: Optional[int]=None): + return np.log(np.cumsum(np.exp(x), axis=axis)) def np_logcumsumexp(x: np.ndarray, @@ -56,8 +61,8 @@ def np_logcumsumexp(x: np.ndarray, return x -class TestLogcumsumexpOp(unittest.TestCase): - def run_cases(self): +class TestLogcumsumexp(unittest.TestCase): + def run_imperative(self): data_np = np.arange(12, dtype=np.float32).reshape(3, 4) data = paddle.to_tensor(data_np) @@ -86,6 +91,17 @@ def run_cases(self): with self.assertRaises(IndexError): y = paddle.logcumsumexp(data, axis=2) + data_np = np.arange(10000, 10024, dtype=np.float32) + 
data = paddle.to_tensor(data_np) + y = paddle.logcumsumexp(data) + z = np_naive_logcumsumexp(data_np) + # check that naive algorithm overflows + self.assertTrue(all(z == np.inf)) + z = np_logcumsumexp(data_np) + # check that our algorithm doesn't overflow + self.assertTrue(all(z != np.inf)) + self.assertTrue(np.allclose(z, y.numpy())) + def run_static(self, use_gpu=False): with fluid.program_guard(fluid.Program()): data_np = np.random.random((100, 100)).astype(np.float32) @@ -120,7 +136,7 @@ def run_static(self, use_gpu=False): def test_cpu(self): paddle.disable_static(paddle.fluid.CPUPlace()) - self.run_cases() + self.run_imperative() paddle.enable_static() self.run_static() @@ -129,7 +145,7 @@ def test_gpu(self): if not fluid.core.is_compiled_with_cuda(): return paddle.disable_static(paddle.fluid.CUDAPlace(0)) - self.run_cases() + self.run_imperative() paddle.enable_static() self.run_static(use_gpu=True) @@ -154,5 +170,38 @@ def test_type_error(self): out = exe.run(feed={'X': data_np}, fetch_list=[y.name]) +class BaseOpTest(OpTest): + def setUp(self): + self.op_type = "logcumsumexp" + input, attrs = self.input_and_attrs() + self.inputs = {'X': input} + self.attrs = attrs + self.outputs = {'Out': np_logcumsumexp(input)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + def input_and_attrs(self): + raise NotImplementedError() + + +def TestLogcumsumexpOp1(BaseOpTest): + def input_and_attrs(self): + return np.random.randn(20, 6), {'axis': 0, 'flatten': True, 'reverse': True} + + +def TestLogcumsumexpOp2(BaseOpTest): + def input_and_attrs(self): + return np.random.randn(20, 6), {'axis': 1, 'flatten': False, 'reverse': True} + + +def TestLogcumsumexpOp3(BaseOpTest): + def input_and_attrs(self): + return np.random.randn(20, 6), {'axis': 1, 'flatten': False, 'reverse': False} + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 2419494c11df4..7caf91556ab93 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -2971,6 +2971,47 @@ def cumsum(x, axis=None, dtype=None, name=None): def logcumsumexp(x, axis=None, dtype=None, name=None): + """ + The the logarithm of the cumulative summation of the exponentiation of the elements along a given axis. + + **Note**: + The first element of the result is the same of the first element of the input. + + Args: + x (Tensor): The input tensor + axis (int, optional): The dimension to do the operation along. -1 means the last dimension. The default (None) is to compute the cumsum over the flattened array. + dtype (str, optional): The data type of the output tensor, can be float32, float64. If specified, the input tensor is casted to dtype before the operation is performed. This is useful for preventing data type overflows. The default value is None. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor, the result of logcumsumexp operator. + + Examples: + .. code-block:: python + + import paddle + + data = paddle.arange(12) + data = paddle.reshape(data, (3, 4)) + + y = paddle.logcumsumexp(data) + # [ 0. 1.3132617 2.4076061 3.4401898 4.4519143 5.4561934 + # 6.4577627 7.4583397 8.458551 9.45863 10.458658 11.458669 ] + + y = paddle.logcumsumexp(data, axis=0) + # [[ 0. 1. 2. 3. 
] + # [ 4.01815 5.01815 6.01815 7.01815 ] + # [ 8.018479 9.018479 10.018479 11.018479]] + + y = paddle.logcumsumexp(data, axis=-1) + # [[ 0. 1.3132617 2.4076061 3.4401898] + # [ 4. 5.3132615 6.407606 7.44019 ] + # [ 8. 9.313262 10.407606 11.440189 ]] + + y = paddle.logcumsumexp(data, dtype='float64') + print(y.dtype) + # paddle.float64 + """ if axis is None: flatten = True else: From a3e50da63cb44550944f79da850ccb0c17e4995f Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Mon, 9 May 2022 16:16:55 +0800 Subject: [PATCH 07/21] add OpTest --- paddle/fluid/operators/cum_op.cc | 10 ++-- paddle/phi/kernels/cpu/cum_kernel.cc | 2 +- .../kernels/cpu/logcumsumexp_grad_kernel.cc | 2 + .../kernels/gpu/logcumsumexp_grad_kernel.cu | 2 + .../phi/kernels/impl/logcumsumexp_grad_impl.h | 14 +++--- paddle/phi/kernels/logcumsumexp_grad_kernel.h | 33 +++++++++++++ paddle/phi/ops/compat/logcumsumexp_sig.cc | 38 +++++++++++++++ .../tests/unittests/test_logcumsumexp_op.py | 48 ++++++++++++------- 8 files changed, 119 insertions(+), 30 deletions(-) create mode 100644 paddle/phi/kernels/logcumsumexp_grad_kernel.h create mode 100644 paddle/phi/ops/compat/logcumsumexp_sig.cc diff --git a/paddle/fluid/operators/cum_op.cc b/paddle/fluid/operators/cum_op.cc index c4e906c25d837..1dce244ee925b 100644 --- a/paddle/fluid/operators/cum_op.cc +++ b/paddle/fluid/operators/cum_op.cc @@ -123,15 +123,17 @@ class LogcumsumexpGradMaker : public framework::SingleGradOpMaker { protected: void Apply(GradOpPtr grad_op) const override { grad_op->SetType("logcumsumexp_grad"); - grad_op->SetInput("X", this->OutputGrad("Out")); - grad_op->SetOutput("Out", this->InputGrad("X")); + grad_op->SetInput("X", this->Input("X")); + grad_op->SetInput("Out", this->Output("Out")); + grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); grad_op->SetAttr("axis", BOOST_GET_CONST(int, this->GetAttr("axis"))); grad_op->SetAttr("flatten", BOOST_GET_CONST(bool, this->GetAttr("flatten"))); - grad_op->SetAttr("reverse", - BOOST_GET_CONST(bool, this->GetAttr("reverse"))); grad_op->SetAttr("exclusive", BOOST_GET_CONST(bool, this->GetAttr("exclusive"))); + grad_op->SetAttr("reverse", + BOOST_GET_CONST(bool, this->GetAttr("reverse"))); } }; diff --git a/paddle/phi/kernels/cpu/cum_kernel.cc b/paddle/phi/kernels/cpu/cum_kernel.cc index 85a6ea5d8be1b..cd171cc8fc5fc 100644 --- a/paddle/phi/kernels/cpu/cum_kernel.cc +++ b/paddle/phi/kernels/cpu/cum_kernel.cc @@ -195,7 +195,7 @@ struct LogSumExpReducer { } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const { - return -Eigen::NumTraits::infinity(); + return Eigen::NumTraits::lowest(); } template diff --git a/paddle/phi/kernels/cpu/logcumsumexp_grad_kernel.cc b/paddle/phi/kernels/cpu/logcumsumexp_grad_kernel.cc index 7d18bec424030..17f28b411bcdd 100644 --- a/paddle/phi/kernels/cpu/logcumsumexp_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/logcumsumexp_grad_kernel.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include "paddle/phi/kernels/logcumsumexp_grad_kernel.h" + #include #include "paddle/phi/backends/cpu/cpu_context.h" diff --git a/paddle/phi/kernels/gpu/logcumsumexp_grad_kernel.cu b/paddle/phi/kernels/gpu/logcumsumexp_grad_kernel.cu index cee562696cb40..9f4633a1e021a 100644 --- a/paddle/phi/kernels/gpu/logcumsumexp_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/logcumsumexp_grad_kernel.cu @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/phi/kernels/logcumsumexp_grad_kernel.h" + #include #include "paddle/phi/backends/cpu/cpu_context.h" diff --git a/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h b/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h index 9f82aca739ea3..578a19bfbc950 100644 --- a/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h +++ b/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h @@ -52,23 +52,23 @@ void LogcumsumexpGradKernel(const Context& dev_ctx, reverse = !reverse; dev_ctx.template Alloc(d_x); - auto eigen_x = EigenMatrix::From(x); - auto eigen_out = EigenMatrix::From(out); - auto eigen_d_out = EigenMatrix::From(d_out); + auto eigen_x = EigenVector::Flatten(x); + auto eigen_out = EigenVector::Flatten(out); + auto eigen_d_out = EigenVector::Flatten(d_out); auto& place = *dev_ctx.eigen_device(); DenseTensor output_pos; output_pos.Resize(d_out.dims()); dev_ctx.template Alloc(&output_pos); - auto eigen_output_pos = EigenMatrix::From(output_pos); + auto eigen_output_pos = EigenVector::Flatten(output_pos); DenseTensor output_neg; output_neg.Resize(d_out.dims()); dev_ctx.template Alloc(&output_neg); - auto eigen_output_neg = EigenMatrix::From(output_neg); + auto eigen_output_neg = EigenVector::Flatten(output_neg); DenseTensor tmp; tmp.Resize(d_out.dims()); dev_ctx.template Alloc(&tmp); - auto eigen_tmp = EigenMatrix::From(tmp); + auto eigen_tmp = EigenVector::Flatten(tmp); eigen_tmp.device(place) = eigen_d_out.unaryExpr(LogGradPositiveFunctor()) - eigen_out; @@ -82,7 +82,7 @@ void LogcumsumexpGradKernel(const Context& dev_ctx, dev_ctx, tmp, axis, flatten, exclusive, reverse, &output_neg); eigen_output_neg.device(place) = (eigen_output_neg + eigen_x).exp(); - auto eigen_d_x = EigenMatrix::From(*d_x); + auto eigen_d_x = EigenVector::Flatten(*d_x); eigen_d_x.device(place) = eigen_output_pos - eigen_output_neg; } } // namespace phi diff --git a/paddle/phi/kernels/logcumsumexp_grad_kernel.h b/paddle/phi/kernels/logcumsumexp_grad_kernel.h new file mode 100644 index 0000000000000..212ca24c52215 --- /dev/null +++ b/paddle/phi/kernels/logcumsumexp_grad_kernel.h @@ -0,0 +1,33 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void LogcumsumexpGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& d_out, + int axis, + bool flatten, + bool exclusive, + bool reverse, + DenseTensor* d_x); + +} + diff --git a/paddle/phi/ops/compat/logcumsumexp_sig.cc b/paddle/phi/ops/compat/logcumsumexp_sig.cc new file mode 100644 index 0000000000000..baf635f450267 --- /dev/null +++ b/paddle/phi/ops/compat/logcumsumexp_sig.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature LogcumsumexpOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("logcumsumexp", + {"X"}, + {"axis", "flatten", "exclusive", "reverse"}, + {"Out"}); +} + +KernelSignature LogcumsumexpGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("logcumsumexp_grad", + {"X", "Out", "Out@GRAD"}, + {"axis", "flatten", "exclusive", "reverse"}, + {"X@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(logcumsumexp, phi::LogcumsumexpOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(logcumsumexp_grad, phi::LogcumsumexpGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py index 28313674007dd..9450fc3afa4f8 100644 --- a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py +++ b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py @@ -32,8 +32,14 @@ def np_naive_logcumsumexp(x: np.ndarray, axis: Optional[int]=None): def np_logcumsumexp(x: np.ndarray, axis: Optional[int]=None, + flatten: Optional[bool]=None, reverse: bool=False, exclusive: bool=False): + # `flatten` aligns with c++ op + if flatten: + assert axis in [0, None] + axis = None + x = np.copy(x) if axis is None: @@ -170,37 +176,43 @@ def test_type_error(self): out = exe.run(feed={'X': data_np}, fetch_list=[y.name]) -class BaseOpTest(OpTest): - def setUp(self): - self.op_type = "logcumsumexp" - input, attrs = self.input_and_attrs() - self.inputs = {'X': input} - self.attrs = attrs - self.outputs = {'Out': np_logcumsumexp(input)} +class BaseTestCases: + class BaseOpTest(OpTest): + def setUp(self): + self.op_type = "logcumsumexp" + input, attrs = self.input_and_attrs() + self.inputs = {'X': input} + self.attrs = attrs + self.outputs = {'Out': np_logcumsumexp(input, **attrs)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') - def test_check_output(self): - self.check_output() + def input_and_attrs(self): + raise NotImplementedError() - def test_check_grad(self): - self.check_grad(['X'], 'Out') +class TestLogcumsumexpOp1(BaseTestCases.BaseOpTest): def input_and_attrs(self): - raise NotImplementedError() + return np.arange(200, dtype=np.float64).reshape(20, 
10), {'axis': 0, 'flatten': True, 'reverse': True} -def TestLogcumsumexpOp1(BaseOpTest): +class TestLogcumsumexpOp2(BaseTestCases.BaseOpTest): def input_and_attrs(self): - return np.random.randn(20, 6), {'axis': 0, 'flatten': True, 'reverse': True} + return np.arange(200, dtype=np.float64).reshape(20, 10), {'axis': 1, 'reverse': True} -def TestLogcumsumexpOp2(BaseOpTest): +class TestLogcumsumexpOp3(BaseTestCases.BaseOpTest): def input_and_attrs(self): - return np.random.randn(20, 6), {'axis': 1, 'flatten': False, 'reverse': True} + return np.arange(200, dtype=np.float64).reshape(20, 10), {'axis': 1} -def TestLogcumsumexpOp3(BaseOpTest): +class TestLogcumsumexpOp4(BaseTestCases.BaseOpTest): def input_and_attrs(self): - return np.random.randn(20, 6), {'axis': 1, 'flatten': False, 'reverse': False} + return np.arange(200, dtype=np.float64).reshape(20, 10), {'axis': 0, 'flatten': True, 'reverse': True, 'exclusive': True} if __name__ == '__main__': From 0b4b8ca936c4113dc62245159d35e30f86fb7d10 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Fri, 13 May 2022 17:29:31 +0800 Subject: [PATCH 08/21] use user defined grad --- .../tests/unittests/test_logcumsumexp_op.py | 60 +++++++++++++++++-- 1 file changed, 56 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py index 9450fc3afa4f8..ad382425e956f 100644 --- a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py +++ b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py @@ -67,6 +67,37 @@ def np_logcumsumexp(x: np.ndarray, return x +def np_logcumsumexp_grad( + x: np.ndarray, + dout: np.ndarray, + axis: Optional[int]=None, + flatten: Optional[bool]=None, + reverse: bool=False, + exclusive: bool=False, ): + # Reference: + # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/math_grad.py + out = np_logcumsumexp(x, axis, flatten, reverse, exclusive) + log_grad_positive = np.where(dout > 0, np.log(dout), np.finfo(x.dtype).min) + log_grad_negative = np.where(dout < 0, np.log(-dout), np.finfo(x.dtype).min) + + output_pos = np.exp( + np_logcumsumexp( + log_grad_positive - out, + axis=axis, + flatten=flatten, + reverse=not reverse, + exclusive=exclusive).reshape(x.shape) + x) + output_neg = np.exp( + np_logcumsumexp( + log_grad_negative - out, + axis=axis, + flatten=flatten, + reverse=not reverse, + exclusive=exclusive).reshape(x.shape) + x) + + return output_pos - output_neg + + class TestLogcumsumexp(unittest.TestCase): def run_imperative(self): data_np = np.arange(12, dtype=np.float32).reshape(3, 4) @@ -189,7 +220,13 @@ def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad( + ['X'], + 'Out', + user_defined_grads=[ + np_logcumsumexp_grad(self.inputs['X'], 1 / + self.inputs['X'].size, **self.attrs) + ]) def input_and_attrs(self): raise NotImplementedError() @@ -197,12 +234,21 @@ def input_and_attrs(self): class TestLogcumsumexpOp1(BaseTestCases.BaseOpTest): def input_and_attrs(self): - return np.arange(200, dtype=np.float64).reshape(20, 10), {'axis': 0, 'flatten': True, 'reverse': True} + return np.arange( + 200, dtype=np.float64).reshape(20, 10), { + 'axis': 0, + 'flatten': True, + 'reverse': True + } class TestLogcumsumexpOp2(BaseTestCases.BaseOpTest): def input_and_attrs(self): - return np.arange(200, dtype=np.float64).reshape(20, 10), {'axis': 1, 'reverse': True} + return np.arange( + 200, 
dtype=np.float64).reshape(20, 10), { + 'axis': 1, + 'reverse': True + } class TestLogcumsumexpOp3(BaseTestCases.BaseOpTest): @@ -212,7 +258,13 @@ def input_and_attrs(self): class TestLogcumsumexpOp4(BaseTestCases.BaseOpTest): def input_and_attrs(self): - return np.arange(200, dtype=np.float64).reshape(20, 10), {'axis': 0, 'flatten': True, 'reverse': True, 'exclusive': True} + return np.arange( + 200, dtype=np.float64).reshape(20, 10), { + 'axis': 0, + 'flatten': True, + 'reverse': True, + 'exclusive': True + } if __name__ == '__main__': From 3bf4cfe31b88430cf52382ac317d4d7727039a6d Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Fri, 13 May 2022 17:43:35 +0800 Subject: [PATCH 09/21] add formula in docs, address reviews --- paddle/fluid/operators/cum_op.cc | 2 +- .../fluid/tests/unittests/test_logcumsumexp_op.py | 4 ++-- python/paddle/tensor/math.py | 10 ++++++++-- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/cum_op.cc b/paddle/fluid/operators/cum_op.cc index 1dce244ee925b..27a648872e7a5 100644 --- a/paddle/fluid/operators/cum_op.cc +++ b/paddle/fluid/operators/cum_op.cc @@ -97,7 +97,7 @@ class LogcumsumexpOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Returns the logarithm of the cumulative summation of the exponentiation of elements of input along the given axis. By default, the first element of the result is the same of the first element of -the input. If exclusive is true, the first element of the result is the the lowest finite value of the dtype of output tensor. +the input. If exclusive is true, the first element of the result is the lowest finite value of the dtype of output tensor. )DOC"); } }; diff --git a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py index ad382425e956f..35280303c9a38 100644 --- a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py +++ b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py @@ -115,8 +115,8 @@ def run_imperative(self): z = np_logcumsumexp(data_np, axis=-1) self.assertTrue(np.allclose(z, y.numpy())) - y = paddle.logcumsumexp(data, dtype='float64') - self.assertTrue(y.dtype == core.VarDesc.VarType.FP64) + y = paddle.logcumsumexp(data, dtype='float32') + self.assertTrue(y.dtype == core.VarDesc.VarType.FP32) y = paddle.logcumsumexp(data, axis=-2) z = np_logcumsumexp(data_np, axis=-2) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 7caf91556ab93..bc471af9b2b8c 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -2971,9 +2971,15 @@ def cumsum(x, axis=None, dtype=None, name=None): def logcumsumexp(x, axis=None, dtype=None, name=None): - """ - The the logarithm of the cumulative summation of the exponentiation of the elements along a given axis. + r""" + The logarithm of the cumulative summation of the exponentiation of the elements along a given axis. + + For summation index j given by dim and other indices i, the result is + .. math:: + + logcumsumexp(x)_{ij} = log \sum_{i=0}^{j}exp(x_{ij}) + **Note**: The first element of the result is the same of the first element of the input. 
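For reference, the backward pass wired up in the patches above (LogcumsumexpGradKernel and the np_logcumsumexp_grad helper used for the user-defined gradients in the test) follows from a single identity: with y = logcumsumexp(x), dy_j/dx_i = exp(x_i - y_j) for j >= i, so dL/dx_i = sum_{j>=i} dout_j * exp(x_i - y_j). The kernel evaluates that sum as a reversed logcumsumexp of log|dout| - y, split by the sign of dout so the logarithm stays defined. The sketch below is only a direct NumPy restatement of the identity for the 1-D, inclusive, forward case; the helper name is illustrative and it deliberately ignores the reverse/exclusive/flatten attributes.

```python
import numpy as np

def logcumsumexp_grad_ref(x, y, dout):
    # dL/dx_i = sum_{j >= i} dout_j * exp(x_i - y_j), with y = logcumsumexp(x).
    # Direct O(n^2) form for illustration; the kernel computes the same sum
    # through a reversed logcumsumexp of log|dout| - y, split into the
    # positive and negative parts of dout (output_pos - output_neg).
    grad = np.empty_like(x)
    for i in range(x.size):
        grad[i] = np.sum(dout[i:] * np.exp(x[i] - y[i:]))
    return grad

x = np.random.randn(8)
y = np.logaddexp.accumulate(x)        # 1-D logcumsumexp
dout = np.full_like(x, 1.0 / x.size)  # same dout as used in check_grad above
dx = logcumsumexp_grad_ref(x, y, dout)
```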
From 34f57f1d58455473e1e213ac4f71a77ae010bd47 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Sat, 14 May 2022 17:11:10 +0800 Subject: [PATCH 10/21] remove 'reference' comment --- python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py index 35280303c9a38..02bd46ad26526 100644 --- a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py +++ b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py @@ -74,8 +74,6 @@ def np_logcumsumexp_grad( flatten: Optional[bool]=None, reverse: bool=False, exclusive: bool=False, ): - # Reference: - # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/math_grad.py out = np_logcumsumexp(x, axis, flatten, reverse, exclusive) log_grad_positive = np.where(dout > 0, np.log(dout), np.finfo(x.dtype).min) log_grad_negative = np.where(dout < 0, np.log(-dout), np.finfo(x.dtype).min) From 661bff30b2526e2020f8248465ad143f1026658e Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Sat, 14 May 2022 19:30:00 +0800 Subject: [PATCH 11/21] Update logcumsumexp_grad_kernel.h --- paddle/phi/kernels/logcumsumexp_grad_kernel.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/paddle/phi/kernels/logcumsumexp_grad_kernel.h b/paddle/phi/kernels/logcumsumexp_grad_kernel.h index 212ca24c52215..e78a79550657e 100644 --- a/paddle/phi/kernels/logcumsumexp_grad_kernel.h +++ b/paddle/phi/kernels/logcumsumexp_grad_kernel.h @@ -28,6 +28,4 @@ void LogcumsumexpGradKernel(const Context& dev_ctx, bool exclusive, bool reverse, DenseTensor* d_x); - } - From 30241bbcb67381701852e4eaab584025b5ddb422 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Sat, 14 May 2022 19:30:28 +0800 Subject: [PATCH 12/21] Update logcumsumexp_sig.cc --- paddle/phi/ops/compat/logcumsumexp_sig.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/phi/ops/compat/logcumsumexp_sig.cc b/paddle/phi/ops/compat/logcumsumexp_sig.cc index baf635f450267..2c790903b6333 100644 --- a/paddle/phi/ops/compat/logcumsumexp_sig.cc +++ b/paddle/phi/ops/compat/logcumsumexp_sig.cc @@ -35,4 +35,5 @@ KernelSignature LogcumsumexpGradOpArgumentMapping( } // namespace phi PD_REGISTER_ARG_MAPPING_FN(logcumsumexp, phi::LogcumsumexpOpArgumentMapping); -PD_REGISTER_ARG_MAPPING_FN(logcumsumexp_grad, phi::LogcumsumexpGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(logcumsumexp_grad, + phi::LogcumsumexpGradOpArgumentMapping); From 2454012147a99af94a9167f99612cce7a9ef9b36 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Sat, 14 May 2022 19:32:46 +0800 Subject: [PATCH 13/21] Update logcumsumexp_grad_impl.h --- paddle/phi/kernels/impl/logcumsumexp_grad_impl.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h b/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h index 578a19bfbc950..602f2248902cc 100644 --- a/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h +++ b/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h @@ -47,8 +47,6 @@ void LogcumsumexpGradKernel(const Context& dev_ctx, bool exclusive, bool reverse, DenseTensor* d_x) { - // Reference: - // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/math_grad.py reverse = !reverse; dev_ctx.template Alloc(d_x); From 17344408d69f10e9fe5cf3200be1e381bc454694 Mon Sep 17 00:00:00 2001 From: 
tiancaishaonvjituizi <452565578@qq.com> Date: Mon, 16 May 2022 15:52:56 +0800 Subject: [PATCH 14/21] decrease input size, update python --- .../fluid/tests/unittests/test_logcumsumexp_op.py | 8 ++++---- python/paddle/tensor/math.py | 12 +++++------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py index 02bd46ad26526..615e7c216146a 100644 --- a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py +++ b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py @@ -233,7 +233,7 @@ def input_and_attrs(self): class TestLogcumsumexpOp1(BaseTestCases.BaseOpTest): def input_and_attrs(self): return np.arange( - 200, dtype=np.float64).reshape(20, 10), { + 100, dtype=np.float64).reshape(10, 10), { 'axis': 0, 'flatten': True, 'reverse': True @@ -243,7 +243,7 @@ def input_and_attrs(self): class TestLogcumsumexpOp2(BaseTestCases.BaseOpTest): def input_and_attrs(self): return np.arange( - 200, dtype=np.float64).reshape(20, 10), { + 100, dtype=np.float64).reshape(10, 10), { 'axis': 1, 'reverse': True } @@ -251,13 +251,13 @@ def input_and_attrs(self): class TestLogcumsumexpOp3(BaseTestCases.BaseOpTest): def input_and_attrs(self): - return np.arange(200, dtype=np.float64).reshape(20, 10), {'axis': 1} + return np.arange(100, dtype=np.float64).reshape(10, 10), {'axis': 1} class TestLogcumsumexpOp4(BaseTestCases.BaseOpTest): def input_and_attrs(self): return np.arange( - 200, dtype=np.float64).reshape(20, 10), { + 100, dtype=np.float64).reshape(10, 10), { 'axis': 0, 'flatten': True, 'reverse': True, diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index bc471af9b2b8c..cf5143d08caf0 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -3035,13 +3035,11 @@ def logcumsumexp(x, axis=None, dtype=None, name=None): return _C_ops.logcumsumexp(x, 'axis', axis, 'flatten', flatten) check_variable_and_dtype(x, 'x', ['float32', 'float64'], "logcumsumexp") - locals_var = locals().copy() - kwargs = dict() - for name, val in locals_var.items(): - if val is not None: - kwargs[name] = val - _logcumsumexp_ = generate_layer_fn('logcumsumexp') - return _logcumsumexp_(**kwargs) + + helper = LayerHelper('logcumsumexp', **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op(type='logcumsumexp', inputs={'X': x}, outputs={'Out': out}, attrs={'axis': axis, 'flatten': flatten}) + return out def cumprod(x, dim=None, dtype=None, name=None): From 3e4953a9d5589c1f62eb3a1ddf639852ce633fda Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Tue, 17 May 2022 09:27:51 +0800 Subject: [PATCH 15/21] shrink test data size --- python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py index 615e7c216146a..9c2668eb1a42f 100644 --- a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py +++ b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py @@ -139,8 +139,8 @@ def run_imperative(self): def run_static(self, use_gpu=False): with fluid.program_guard(fluid.Program()): - data_np = np.random.random((100, 100)).astype(np.float32) - x = paddle.static.data('X', [100, 100]) + data_np = np.random.random((5, 4)).astype(np.float32) + x = paddle.static.data('X', [5, 4]) y = paddle.logcumsumexp(x) y2 = 
paddle.logcumsumexp(x, axis=0) y3 = paddle.logcumsumexp(x, axis=-1) From 790e6165b0f37f7af12473b0b81e00668f83cfbb Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Tue, 17 May 2022 10:47:25 +0800 Subject: [PATCH 16/21] fix sample code --- python/paddle/tensor/math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index cf5143d08caf0..62315de57c180 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -2997,7 +2997,7 @@ def logcumsumexp(x, axis=None, dtype=None, name=None): import paddle - data = paddle.arange(12) + data = paddle.arange(12, dtype='float64') data = paddle.reshape(data, (3, 4)) y = paddle.logcumsumexp(data) From 9797d1006b3aa8e35ae8b589b2a0589f66a7d69e Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Wed, 18 May 2022 12:02:56 +0800 Subject: [PATCH 17/21] refine docs --- python/paddle/tensor/math.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 62315de57c180..b0fee7aaec02a 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -2980,11 +2980,11 @@ def logcumsumexp(x, axis=None, dtype=None, name=None): logcumsumexp(x)_{ij} = log \sum_{i=0}^{j}exp(x_{ij}) - **Note**: + Note: The first element of the result is the same of the first element of the input. Args: - x (Tensor): The input tensor + x (Tensor): The input tensor. axis (int, optional): The dimension to do the operation along. -1 means the last dimension. The default (None) is to compute the cumsum over the flattened array. dtype (str, optional): The data type of the output tensor, can be float32, float64. If specified, the input tensor is casted to dtype before the operation is performed. This is useful for preventing data type overflows. The default value is None. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. From 3b4b8fef180dc7b15584fe784c145c8078a8e256 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Wed, 25 May 2022 21:00:33 +0800 Subject: [PATCH 18/21] update docs --- python/paddle/tensor/math.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index b0fee7aaec02a..a60679a2d1ddb 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -2908,7 +2908,7 @@ def cumsum(x, axis=None, dtype=None, name=None): The cumulative sum of the elements along a given axis. **Note**: - The first element of the result is the same of the first element of the input. + The first element of the result is the same as the first element of the input. Args: x (Tensor): The input tensor needed to be cumsumed. @@ -2974,14 +2974,14 @@ def logcumsumexp(x, axis=None, dtype=None, name=None): r""" The logarithm of the cumulative summation of the exponentiation of the elements along a given axis. - For summation index j given by dim and other indices i, the result is + For summation index j given by `axis` and other indices i, the result is .. math:: logcumsumexp(x)_{ij} = log \sum_{i=0}^{j}exp(x_{ij}) - Note: - The first element of the result is the same of the first element of the input. + **Note**: + The first element of the result is the same as the first element of the input. Args: x (Tensor): The input tensor. 
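The docstring formula above is intentionally not evaluated the way it is written: np.log(np.cumsum(np.exp(x))) overflows float64 once any element exceeds roughly 709, which is exactly what the np.arange(10000, 10024) case in the test exercises. The kernels instead perform the scan with a log-sum-exp combine step (see the LogSumExpReducer in cum_kernel.cc), so every intermediate stays in log space. A minimal 1-D NumPy sketch of that idea, for illustration only (the helper name is made up and it handles neither axis nor the exclusive/reverse attributes):

```python
import numpy as np

def stable_logcumsumexp_1d(x):
    # out[j] = log(sum_{i <= j} exp(x[i])), accumulated as a pairwise
    # logaddexp so exp() never sees a large positive argument.
    out = np.empty_like(x)
    acc = -np.inf
    for j, v in enumerate(x):
        acc = np.logaddexp(acc, v)   # log(exp(acc) + exp(v)) without overflow
        out[j] = acc
    return out

x = np.arange(10000.0, 10024.0)
naive = np.log(np.cumsum(np.exp(x)))   # exp overflows: all inf
stable = stable_logcumsumexp_1d(x)     # finite: 10000.0, 10001.3132617, ...
```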
From 57bb7117fb7e0babcc05dd97ba2be4130b492b37 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Fri, 27 May 2022 17:03:44 +0800 Subject: [PATCH 19/21] fix docs;test=document_fix --- python/paddle/tensor/math.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index a60679a2d1ddb..d513b83d3e39e 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -2980,8 +2980,8 @@ def logcumsumexp(x, axis=None, dtype=None, name=None): logcumsumexp(x)_{ij} = log \sum_{i=0}^{j}exp(x_{ij}) - **Note**: - The first element of the result is the same as the first element of the input. + Note: + The first element of the result is the same as the first element of the input. Args: x (Tensor): The input tensor. From 6a3647c98c5de3d67f9bc7d53e2e2b9a4d53c194 Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Tue, 7 Jun 2022 14:55:22 +0800 Subject: [PATCH 20/21] set test timeout to 30s --- python/paddle/fluid/tests/unittests/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 214c68c250ea9..f5681a8669bb9 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -684,6 +684,7 @@ endif() foreach(TEST_OP ${TEST_OPS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP}) endforeach(TEST_OP) +set_tests_properties(test_logcumsumexp_op PROPERTIES TIMEOUT 30) py_test_modules(test_adam_op_multi_thread MODULES test_adam_op ENVS FLAGS_inner_op_parallelism=4) if(WITH_GPU From 13edc4fc422ef717f624b1ca11f75e0bd25aef2c Mon Sep 17 00:00:00 2001 From: tiancaishaonvjituizi <452565578@qq.com> Date: Thu, 9 Jun 2022 15:10:40 +0800 Subject: [PATCH 21/21] reformat --- .../kernels/gpu/logcumsumexp_grad_kernel.cu | 3 +- .../tests/unittests/test_logcumsumexp_op.py | 99 ++++++++++--------- 2 files changed, 52 insertions(+), 50 deletions(-) diff --git a/paddle/phi/kernels/gpu/logcumsumexp_grad_kernel.cu b/paddle/phi/kernels/gpu/logcumsumexp_grad_kernel.cu index 9f4633a1e021a..43744210e32b7 100644 --- a/paddle/phi/kernels/gpu/logcumsumexp_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/logcumsumexp_grad_kernel.cu @@ -12,13 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/phi/kernels/logcumsumexp_grad_kernel.h" - #include #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/logcumsumexp_grad_impl.h" +#include "paddle/phi/kernels/logcumsumexp_grad_kernel.h" PD_REGISTER_KERNEL(logcumsumexp_grad, GPU, diff --git a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py index 9c2668eb1a42f..ebc350d13c673 100644 --- a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py +++ b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py @@ -26,15 +26,15 @@ from op_test import OpTest -def np_naive_logcumsumexp(x: np.ndarray, axis: Optional[int]=None): +def np_naive_logcumsumexp(x: np.ndarray, axis: Optional[int] = None): return np.log(np.cumsum(np.exp(x), axis=axis)) def np_logcumsumexp(x: np.ndarray, - axis: Optional[int]=None, - flatten: Optional[bool]=None, - reverse: bool=False, - exclusive: bool=False): + axis: Optional[int] = None, + flatten: Optional[bool] = None, + reverse: bool = False, + exclusive: bool = False): # `flatten` aligns with c++ op if flatten: assert axis in [0, None] @@ -68,35 +68,35 @@ def np_logcumsumexp(x: np.ndarray, def np_logcumsumexp_grad( - x: np.ndarray, - dout: np.ndarray, - axis: Optional[int]=None, - flatten: Optional[bool]=None, - reverse: bool=False, - exclusive: bool=False, ): + x: np.ndarray, + dout: np.ndarray, + axis: Optional[int] = None, + flatten: Optional[bool] = None, + reverse: bool = False, + exclusive: bool = False, +): out = np_logcumsumexp(x, axis, flatten, reverse, exclusive) log_grad_positive = np.where(dout > 0, np.log(dout), np.finfo(x.dtype).min) log_grad_negative = np.where(dout < 0, np.log(-dout), np.finfo(x.dtype).min) output_pos = np.exp( - np_logcumsumexp( - log_grad_positive - out, - axis=axis, - flatten=flatten, - reverse=not reverse, - exclusive=exclusive).reshape(x.shape) + x) + np_logcumsumexp(log_grad_positive - out, + axis=axis, + flatten=flatten, + reverse=not reverse, + exclusive=exclusive).reshape(x.shape) + x) output_neg = np.exp( - np_logcumsumexp( - log_grad_negative - out, - axis=axis, - flatten=flatten, - reverse=not reverse, - exclusive=exclusive).reshape(x.shape) + x) + np_logcumsumexp(log_grad_negative - out, + axis=axis, + flatten=flatten, + reverse=not reverse, + exclusive=exclusive).reshape(x.shape) + x) return output_pos - output_neg class TestLogcumsumexp(unittest.TestCase): + def run_imperative(self): data_np = np.arange(12, dtype=np.float32).reshape(3, 4) data = paddle.to_tensor(data_np) @@ -206,7 +206,9 @@ def test_type_error(self): class BaseTestCases: + class BaseOpTest(OpTest): + def setUp(self): self.op_type = "logcumsumexp" input, attrs = self.input_and_attrs() @@ -218,51 +220,52 @@ def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad( - ['X'], - 'Out', - user_defined_grads=[ - np_logcumsumexp_grad(self.inputs['X'], 1 / - self.inputs['X'].size, **self.attrs) - ]) + self.check_grad(['X'], + 'Out', + user_defined_grads=[ + np_logcumsumexp_grad(self.inputs['X'], + 1 / self.inputs['X'].size, + **self.attrs) + ]) def input_and_attrs(self): raise NotImplementedError() class TestLogcumsumexpOp1(BaseTestCases.BaseOpTest): + def input_and_attrs(self): - return np.arange( - 100, dtype=np.float64).reshape(10, 10), { - 'axis': 0, - 'flatten': True, - 'reverse': True - } + return np.arange(100, dtype=np.float64).reshape(10, 10), { + 'axis': 0, + 'flatten': True, + 'reverse': True + } 
class TestLogcumsumexpOp2(BaseTestCases.BaseOpTest): + def input_and_attrs(self): - return np.arange( - 100, dtype=np.float64).reshape(10, 10), { - 'axis': 1, - 'reverse': True - } + return np.arange(100, dtype=np.float64).reshape(10, 10), { + 'axis': 1, + 'reverse': True + } class TestLogcumsumexpOp3(BaseTestCases.BaseOpTest): + def input_and_attrs(self): return np.arange(100, dtype=np.float64).reshape(10, 10), {'axis': 1} class TestLogcumsumexpOp4(BaseTestCases.BaseOpTest): + def input_and_attrs(self): - return np.arange( - 100, dtype=np.float64).reshape(10, 10), { - 'axis': 0, - 'flatten': True, - 'reverse': True, - 'exclusive': True - } + return np.arange(100, dtype=np.float64).reshape(10, 10), { + 'axis': 0, + 'flatten': True, + 'reverse': True, + 'exclusive': True + } if __name__ == '__main__':