Skip to content

Commit

Permalink
add OpTest
Browse files Browse the repository at this point in the history
  • Loading branch information
tiancaishaonvjituizi committed May 9, 2022
1 parent 442bc00 commit a3e50da
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 30 deletions.
10 changes: 6 additions & 4 deletions paddle/fluid/operators/cum_op.cc
Expand Up @@ -123,15 +123,17 @@ class LogcumsumexpGradMaker : public framework::SingleGradOpMaker<T> {
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("logcumsumexp_grad");
grad_op->SetInput("X", this->OutputGrad("Out"));
grad_op->SetOutput("Out", this->InputGrad("X"));
grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput("Out", this->Output("Out"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttr("axis", BOOST_GET_CONST(int, this->GetAttr("axis")));
grad_op->SetAttr("flatten",
BOOST_GET_CONST(bool, this->GetAttr("flatten")));
grad_op->SetAttr("reverse",
BOOST_GET_CONST(bool, this->GetAttr("reverse")));
grad_op->SetAttr("exclusive",
BOOST_GET_CONST(bool, this->GetAttr("exclusive")));
grad_op->SetAttr("reverse",
BOOST_GET_CONST(bool, this->GetAttr("reverse")));
}
};

Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/cum_kernel.cc
Expand Up @@ -195,7 +195,7 @@ struct LogSumExpReducer {
}

EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
return -Eigen::NumTraits<T>::infinity();
return Eigen::NumTraits<T>::lowest();
}

template <typename Packet>
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/kernels/cpu/logcumsumexp_grad_kernel.cc
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/logcumsumexp_grad_kernel.h"

#include <limits>

#include "paddle/phi/backends/cpu/cpu_context.h"
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/logcumsumexp_grad_kernel.cu
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/logcumsumexp_grad_kernel.h"

#include <limits>

#include "paddle/phi/backends/cpu/cpu_context.h"
Expand Down
14 changes: 7 additions & 7 deletions paddle/phi/kernels/impl/logcumsumexp_grad_impl.h
Expand Up @@ -52,23 +52,23 @@ void LogcumsumexpGradKernel(const Context& dev_ctx,
reverse = !reverse;
dev_ctx.template Alloc<T>(d_x);

auto eigen_x = EigenMatrix<T>::From(x);
auto eigen_out = EigenMatrix<T>::From(out);
auto eigen_d_out = EigenMatrix<T>::From(d_out);
auto eigen_x = EigenVector<T>::Flatten(x);
auto eigen_out = EigenVector<T>::Flatten(out);
auto eigen_d_out = EigenVector<T>::Flatten(d_out);
auto& place = *dev_ctx.eigen_device();

DenseTensor output_pos;
output_pos.Resize(d_out.dims());
dev_ctx.template Alloc<T>(&output_pos);
auto eigen_output_pos = EigenMatrix<T>::From(output_pos);
auto eigen_output_pos = EigenVector<T>::Flatten(output_pos);
DenseTensor output_neg;
output_neg.Resize(d_out.dims());
dev_ctx.template Alloc<T>(&output_neg);
auto eigen_output_neg = EigenMatrix<T>::From(output_neg);
auto eigen_output_neg = EigenVector<T>::Flatten(output_neg);
DenseTensor tmp;
tmp.Resize(d_out.dims());
dev_ctx.template Alloc<T>(&tmp);
auto eigen_tmp = EigenMatrix<T>::From(tmp);
auto eigen_tmp = EigenVector<T>::Flatten(tmp);

eigen_tmp.device(place) =
eigen_d_out.unaryExpr(LogGradPositiveFunctor<T>()) - eigen_out;
Expand All @@ -82,7 +82,7 @@ void LogcumsumexpGradKernel(const Context& dev_ctx,
dev_ctx, tmp, axis, flatten, exclusive, reverse, &output_neg);
eigen_output_neg.device(place) = (eigen_output_neg + eigen_x).exp();

auto eigen_d_x = EigenMatrix<T>::From(*d_x);
auto eigen_d_x = EigenVector<T>::Flatten(*d_x);
eigen_d_x.device(place) = eigen_output_pos - eigen_output_neg;
}
} // namespace phi
33 changes: 33 additions & 0 deletions paddle/phi/kernels/logcumsumexp_grad_kernel.h
@@ -0,0 +1,33 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void LogcumsumexpGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& d_out,
int axis,
bool flatten,
bool exclusive,
bool reverse,
DenseTensor* d_x);

}

38 changes: 38 additions & 0 deletions paddle/phi/ops/compat/logcumsumexp_sig.cc
@@ -0,0 +1,38 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

KernelSignature LogcumsumexpOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("logcumsumexp",
{"X"},
{"axis", "flatten", "exclusive", "reverse"},
{"Out"});
}

KernelSignature LogcumsumexpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("logcumsumexp_grad",
{"X", "Out", "Out@GRAD"},
{"axis", "flatten", "exclusive", "reverse"},
{"X@GRAD"});
}

} // namespace phi

PD_REGISTER_ARG_MAPPING_FN(logcumsumexp, phi::LogcumsumexpOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(logcumsumexp_grad, phi::LogcumsumexpGradOpArgumentMapping);
48 changes: 30 additions & 18 deletions python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py
Expand Up @@ -32,8 +32,14 @@ def np_naive_logcumsumexp(x: np.ndarray, axis: Optional[int]=None):

def np_logcumsumexp(x: np.ndarray,
axis: Optional[int]=None,
flatten: Optional[bool]=None,
reverse: bool=False,
exclusive: bool=False):
# `flatten` aligns with c++ op
if flatten:
assert axis in [0, None]
axis = None

x = np.copy(x)

if axis is None:
Expand Down Expand Up @@ -170,37 +176,43 @@ def test_type_error(self):
out = exe.run(feed={'X': data_np}, fetch_list=[y.name])


class BaseOpTest(OpTest):
def setUp(self):
self.op_type = "logcumsumexp"
input, attrs = self.input_and_attrs()
self.inputs = {'X': input}
self.attrs = attrs
self.outputs = {'Out': np_logcumsumexp(input)}
class BaseTestCases:
class BaseOpTest(OpTest):
def setUp(self):
self.op_type = "logcumsumexp"
input, attrs = self.input_and_attrs()
self.inputs = {'X': input}
self.attrs = attrs
self.outputs = {'Out': np_logcumsumexp(input, **attrs)}

def test_check_output(self):
self.check_output()

def test_check_grad(self):
self.check_grad(['X'], 'Out')

def test_check_output(self):
self.check_output()
def input_and_attrs(self):
raise NotImplementedError()

def test_check_grad(self):
self.check_grad(['X'], 'Out')

class TestLogcumsumexpOp1(BaseTestCases.BaseOpTest):
def input_and_attrs(self):
raise NotImplementedError()
return np.arange(200, dtype=np.float64).reshape(20, 10), {'axis': 0, 'flatten': True, 'reverse': True}


def TestLogcumsumexpOp1(BaseOpTest):
class TestLogcumsumexpOp2(BaseTestCases.BaseOpTest):
def input_and_attrs(self):
return np.random.randn(20, 6), {'axis': 0, 'flatten': True, 'reverse': True}
return np.arange(200, dtype=np.float64).reshape(20, 10), {'axis': 1, 'reverse': True}


def TestLogcumsumexpOp2(BaseOpTest):
class TestLogcumsumexpOp3(BaseTestCases.BaseOpTest):
def input_and_attrs(self):
return np.random.randn(20, 6), {'axis': 1, 'flatten': False, 'reverse': True}
return np.arange(200, dtype=np.float64).reshape(20, 10), {'axis': 1}


def TestLogcumsumexpOp3(BaseOpTest):
class TestLogcumsumexpOp4(BaseTestCases.BaseOpTest):
def input_and_attrs(self):
return np.random.randn(20, 6), {'axis': 1, 'flatten': False, 'reverse': False}
return np.arange(200, dtype=np.float64).reshape(20, 10), {'axis': 0, 'flatten': True, 'reverse': True, 'exclusive': True}


if __name__ == '__main__':
Expand Down

0 comments on commit a3e50da

Please sign in to comment.