From a3e50da63cb44550944f79da850ccb0c17e4995f Mon Sep 17 00:00:00 2001
From: tiancaishaonvjituizi <452565578@qq.com>
Date: Mon, 9 May 2022 16:16:55 +0800
Subject: [PATCH] add OpTest

---
 paddle/fluid/operators/cum_op.cc              | 10 ++--
 paddle/phi/kernels/cpu/cum_kernel.cc          |  2 +-
 .../kernels/cpu/logcumsumexp_grad_kernel.cc   |  2 +
 .../kernels/gpu/logcumsumexp_grad_kernel.cu   |  2 +
 .../phi/kernels/impl/logcumsumexp_grad_impl.h | 14 +++---
 paddle/phi/kernels/logcumsumexp_grad_kernel.h | 33 +++++++++++++
 paddle/phi/ops/compat/logcumsumexp_sig.cc     | 38 +++++++++++++++
 .../tests/unittests/test_logcumsumexp_op.py   | 48 ++++++++++++-------
 8 files changed, 119 insertions(+), 30 deletions(-)
 create mode 100644 paddle/phi/kernels/logcumsumexp_grad_kernel.h
 create mode 100644 paddle/phi/ops/compat/logcumsumexp_sig.cc

diff --git a/paddle/fluid/operators/cum_op.cc b/paddle/fluid/operators/cum_op.cc
index c4e906c25d837..1dce244ee925b 100644
--- a/paddle/fluid/operators/cum_op.cc
+++ b/paddle/fluid/operators/cum_op.cc
@@ -123,15 +123,17 @@ class LogcumsumexpGradMaker : public framework::SingleGradOpMaker<T> {
  protected:
   void Apply(GradOpPtr<T> grad_op) const override {
     grad_op->SetType("logcumsumexp_grad");
-    grad_op->SetInput("X", this->OutputGrad("Out"));
-    grad_op->SetOutput("Out", this->InputGrad("X"));
+    grad_op->SetInput("X", this->Input("X"));
+    grad_op->SetInput("Out", this->Output("Out"));
+    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     grad_op->SetAttr("axis", BOOST_GET_CONST(int, this->GetAttr("axis")));
     grad_op->SetAttr("flatten",
                      BOOST_GET_CONST(bool, this->GetAttr("flatten")));
-    grad_op->SetAttr("reverse",
-                     BOOST_GET_CONST(bool, this->GetAttr("reverse")));
     grad_op->SetAttr("exclusive",
                      BOOST_GET_CONST(bool, this->GetAttr("exclusive")));
+    grad_op->SetAttr("reverse",
+                     BOOST_GET_CONST(bool, this->GetAttr("reverse")));
   }
 };

diff --git a/paddle/phi/kernels/cpu/cum_kernel.cc b/paddle/phi/kernels/cpu/cum_kernel.cc
index 85a6ea5d8be1b..cd171cc8fc5fc 100644
--- a/paddle/phi/kernels/cpu/cum_kernel.cc
+++ b/paddle/phi/kernels/cpu/cum_kernel.cc
@@ -195,7 +195,7 @@ struct LogSumExpReducer {
   }

   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
-    return -Eigen::NumTraits<T>::infinity();
+    return Eigen::NumTraits<T>::lowest();
   }

   template <typename Packet>
diff --git a/paddle/phi/kernels/cpu/logcumsumexp_grad_kernel.cc b/paddle/phi/kernels/cpu/logcumsumexp_grad_kernel.cc
index 7d18bec424030..17f28b411bcdd 100644
--- a/paddle/phi/kernels/cpu/logcumsumexp_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/logcumsumexp_grad_kernel.cc
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include "paddle/phi/kernels/logcumsumexp_grad_kernel.h"
+
 #include <limits>

 #include "paddle/phi/backends/cpu/cpu_context.h"
diff --git a/paddle/phi/kernels/gpu/logcumsumexp_grad_kernel.cu b/paddle/phi/kernels/gpu/logcumsumexp_grad_kernel.cu
index cee562696cb40..9f4633a1e021a 100644
--- a/paddle/phi/kernels/gpu/logcumsumexp_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/logcumsumexp_grad_kernel.cu
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include "paddle/phi/kernels/logcumsumexp_grad_kernel.h" + #include #include "paddle/phi/backends/cpu/cpu_context.h" diff --git a/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h b/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h index 9f82aca739ea3..578a19bfbc950 100644 --- a/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h +++ b/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h @@ -52,23 +52,23 @@ void LogcumsumexpGradKernel(const Context& dev_ctx, reverse = !reverse; dev_ctx.template Alloc(d_x); - auto eigen_x = EigenMatrix::From(x); - auto eigen_out = EigenMatrix::From(out); - auto eigen_d_out = EigenMatrix::From(d_out); + auto eigen_x = EigenVector::Flatten(x); + auto eigen_out = EigenVector::Flatten(out); + auto eigen_d_out = EigenVector::Flatten(d_out); auto& place = *dev_ctx.eigen_device(); DenseTensor output_pos; output_pos.Resize(d_out.dims()); dev_ctx.template Alloc(&output_pos); - auto eigen_output_pos = EigenMatrix::From(output_pos); + auto eigen_output_pos = EigenVector::Flatten(output_pos); DenseTensor output_neg; output_neg.Resize(d_out.dims()); dev_ctx.template Alloc(&output_neg); - auto eigen_output_neg = EigenMatrix::From(output_neg); + auto eigen_output_neg = EigenVector::Flatten(output_neg); DenseTensor tmp; tmp.Resize(d_out.dims()); dev_ctx.template Alloc(&tmp); - auto eigen_tmp = EigenMatrix::From(tmp); + auto eigen_tmp = EigenVector::Flatten(tmp); eigen_tmp.device(place) = eigen_d_out.unaryExpr(LogGradPositiveFunctor()) - eigen_out; @@ -82,7 +82,7 @@ void LogcumsumexpGradKernel(const Context& dev_ctx, dev_ctx, tmp, axis, flatten, exclusive, reverse, &output_neg); eigen_output_neg.device(place) = (eigen_output_neg + eigen_x).exp(); - auto eigen_d_x = EigenMatrix::From(*d_x); + auto eigen_d_x = EigenVector::Flatten(*d_x); eigen_d_x.device(place) = eigen_output_pos - eigen_output_neg; } } // namespace phi diff --git a/paddle/phi/kernels/logcumsumexp_grad_kernel.h b/paddle/phi/kernels/logcumsumexp_grad_kernel.h new file mode 100644 index 0000000000000..212ca24c52215 --- /dev/null +++ b/paddle/phi/kernels/logcumsumexp_grad_kernel.h @@ -0,0 +1,33 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void LogcumsumexpGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& d_out, + int axis, + bool flatten, + bool exclusive, + bool reverse, + DenseTensor* d_x); + +} + diff --git a/paddle/phi/ops/compat/logcumsumexp_sig.cc b/paddle/phi/ops/compat/logcumsumexp_sig.cc new file mode 100644 index 0000000000000..baf635f450267 --- /dev/null +++ b/paddle/phi/ops/compat/logcumsumexp_sig.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature LogcumsumexpOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("logcumsumexp",
+                         {"X"},
+                         {"axis", "flatten", "exclusive", "reverse"},
+                         {"Out"});
+}
+
+KernelSignature LogcumsumexpGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("logcumsumexp_grad",
+                         {"X", "Out", "Out@GRAD"},
+                         {"axis", "flatten", "exclusive", "reverse"},
+                         {"X@GRAD"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(logcumsumexp, phi::LogcumsumexpOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(logcumsumexp_grad, phi::LogcumsumexpGradOpArgumentMapping);
diff --git a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py
index 28313674007dd..9450fc3afa4f8 100644
--- a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py
+++ b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py
@@ -32,8 +32,14 @@ def np_naive_logcumsumexp(x: np.ndarray, axis: Optional[int]=None):

 def np_logcumsumexp(x: np.ndarray,
                     axis: Optional[int]=None,
+                    flatten: Optional[bool]=None,
                     reverse: bool=False,
                     exclusive: bool=False):
+    # `flatten` aligns with c++ op
+    if flatten:
+        assert axis in [0, None]
+        axis = None
+
     x = np.copy(x)

     if axis is None:
@@ -170,37 +176,43 @@ def test_type_error(self):
                 out = exe.run(feed={'X': data_np}, fetch_list=[y.name])


-class BaseOpTest(OpTest):
-    def setUp(self):
-        self.op_type = "logcumsumexp"
-        input, attrs = self.input_and_attrs()
-        self.inputs = {'X': input}
-        self.attrs = attrs
-        self.outputs = {'Out': np_logcumsumexp(input)}
+class BaseTestCases:
+    class BaseOpTest(OpTest):
+        def setUp(self):
+            self.op_type = "logcumsumexp"
+            input, attrs = self.input_and_attrs()
+            self.inputs = {'X': input}
+            self.attrs = attrs
+            self.outputs = {'Out': np_logcumsumexp(input, **attrs)}
+
+        def test_check_output(self):
+            self.check_output()
+
+        def test_check_grad(self):
+            self.check_grad(['X'], 'Out')

-    def test_check_output(self):
-        self.check_output()
+        def input_and_attrs(self):
+            raise NotImplementedError()

-    def test_check_grad(self):
-        self.check_grad(['X'], 'Out')

+class TestLogcumsumexpOp1(BaseTestCases.BaseOpTest):
     def input_and_attrs(self):
-        raise NotImplementedError()
+        return np.arange(200, dtype=np.float64).reshape(20, 10), {'axis': 0, 'flatten': True, 'reverse': True}


-def TestLogcumsumexpOp1(BaseOpTest):
+class TestLogcumsumexpOp2(BaseTestCases.BaseOpTest):
     def input_and_attrs(self):
-        return np.random.randn(20, 6), {'axis': 0, 'flatten': True, 'reverse': True}
+        return np.arange(200, dtype=np.float64).reshape(20, 10), {'axis': 1, 'reverse': True}


-def TestLogcumsumexpOp2(BaseOpTest):
+class TestLogcumsumexpOp3(BaseTestCases.BaseOpTest):
     def input_and_attrs(self):
-        return np.random.randn(20, 6), {'axis': 1, 'flatten': False, 'reverse': True}
+        return np.arange(200, dtype=np.float64).reshape(20, 10), {'axis': 1}


-def TestLogcumsumexpOp3(BaseOpTest):
+class TestLogcumsumexpOp4(BaseTestCases.BaseOpTest):
     def input_and_attrs(self):
-        return np.random.randn(20, 6), {'axis': 1, 'flatten': False, 'reverse': False}
+        return np.arange(200, dtype=np.float64).reshape(20, 10), {'axis': 0, 'flatten': True, 'reverse': True, 'exclusive': True}


 if __name__ == '__main__':
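
Note for reviewers (illustration only, not part of the patch): the short numpy sketch below spells out the identity that the reworked logcumsumexp_grad kernel evaluates through its output_pos / output_neg tensors, with the upstream gradient split by sign so that both cumulations stay in log space. The helper names logcumsumexp_np and logcumsumexp_grad_np are invented for this note; the tests themselves use the np_logcumsumexp reference defined in test_logcumsumexp_op.py.

import numpy as np


def logcumsumexp_np(x, axis=-1, reverse=False):
    # cumulative log-sum-exp along `axis`; suffix (reverse) cumulation if reverse=True
    if reverse:
        x = np.flip(x, axis)
    out = np.logaddexp.accumulate(x, axis=axis)
    return np.flip(out, axis) if reverse else out


def logcumsumexp_grad_np(x, d_out, axis=-1):
    # d_x = exp(rev_logcumsumexp(log(d_out^+) - out) + x)
    #     - exp(rev_logcumsumexp(log(d_out^-) - out) + x)
    out = logcumsumexp_np(x, axis)
    # split the upstream gradient by sign so each cumulation stays in log space
    log_pos = np.where(d_out > 0, np.log(np.where(d_out > 0, d_out, 1.0)), -np.inf)
    log_neg = np.where(d_out < 0, np.log(np.where(d_out < 0, -d_out, 1.0)), -np.inf)
    pos = np.exp(logcumsumexp_np(log_pos - out, axis, reverse=True) + x)
    neg = np.exp(logcumsumexp_np(log_neg - out, axis, reverse=True) + x)
    return pos - neg


if __name__ == "__main__":
    # brute-force directional-derivative check of the closed form above
    rng = np.random.default_rng(0)
    x, d_out = rng.standard_normal(6), rng.standard_normal(6)
    eps = 1e-6
    numeric = np.array([
        np.dot(logcumsumexp_np(x + eps * e) - logcumsumexp_np(x - eps * e), d_out) / (2 * eps)
        for e in np.eye(6)
    ])
    assert np.allclose(logcumsumexp_grad_np(x, d_out), numeric, atol=1e-5)

For the default forward direction the gradient cumulates from the tail of the sequence, which is why the kernel flips reverse before running its two cumulations.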