Optimize test cases
thunder95 committed Apr 30, 2022
1 parent 90102c0 commit 28cd511
Showing 7 changed files with 69 additions and 46 deletions.
27 changes: 5 additions & 22 deletions paddle/fluid/operators/rrelu_op.cc
@@ -1,4 +1,4 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -102,26 +102,6 @@ where :math:`a` is randomly sampled from uniform distribution
class RReluGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "rrelu");
OP_INOUT_CHECK(ctx->HasInput("Noise"), "Input", "Noise", "rrelu");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "rrelu");

auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
ctx->SetOutputDim(framework::GradVarName("X"), out_dims);
ctx->ShareLoD(framework::GradVarName("Out"),
/*->*/ framework::GradVarName("X"));
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};

template <typename T>
@@ -150,4 +130,7 @@ REGISTER_OPERATOR(rrelu, ops::RReluOp, ops::RReluOpMaker,
ops::RReluGradOpMaker<paddle::framework::OpDesc>,
ops::RReluGradOpMaker<paddle::imperative::OpBase>,
RReluInferShapeFunctor);
REGISTER_OPERATOR(rrelu_grad, ops::RReluGradOp);

DECLARE_INFER_SHAPE_FUNCTOR(rrelu_grad, RReluGradInferShapeFunctor,
PD_INFER_META(phi::RReluGradInferMeta));
REGISTER_OPERATOR(rrelu_grad, ops::RReluGradOp, RReluGradInferShapeFunctor);
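With this change, rrelu_grad no longer carries its own InferShape / GetExpectedKernelType; shape inference is delegated to phi::RReluGradInferMeta through the infer-shape functor registered above. A minimal dynamic-graph sketch of the property this preserves — the gradient w.r.t. X keeps the shape of X — could look like the following (a hypothetical check, assuming a Paddle build that contains this commit):

```python
import paddle
import paddle.nn.functional as F

# Hypothetical sanity check: X@GRAD produced by rrelu_grad should keep
# the shape of the forward input, as propagated by RReluGradInferMeta.
x = paddle.rand([2, 3, 4, 5])
x.stop_gradient = False
out = F.rrelu(x, lower=0.1, upper=0.3, training=True)
out.sum().backward()
assert x.grad.shape == x.shape
```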
9 changes: 9 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -1957,6 +1957,15 @@ void RReluInferMeta(const MetaTensor& x,
}
}

void RReluGradInferMeta(const MetaTensor& out_grad,
const MetaTensor& noise,
MetaTensor* x_grad) {
auto do_dims = out_grad.dims();
x_grad->set_dims(do_dims);
x_grad->set_dtype(out_grad.dtype());
x_grad->share_lod(out_grad);
}

void SetValueInferMeta(const MetaTensor& x, MetaTensor* out) {
auto in_dims = x.dims();
PADDLE_ENFORCE_LT(
4 changes: 4 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -282,6 +282,10 @@ void RReluInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaTensor* noise);

void RReluGradInferMeta(const MetaTensor& out_grad,
const MetaTensor& noise,
MetaTensor* x_grad);

void SetValueInferMeta(const MetaTensor& x, MetaTensor* out);

void ShapeInferMeta(const MetaTensor& input, MetaTensor* out);
10 changes: 8 additions & 2 deletions paddle/phi/kernels/cpu/rrelu_kernel.cc
@@ -38,8 +38,8 @@ void RReluKernel(const Context& dev_ctx,
int i = 0;

if (is_test) {
T mid_val = static_cast<T>((lower + upper) / 2.0);
for (i = 0; i < numel; i++) {
T mid_val = static_cast<T>((lower + upper) / 2.0);
if (x_ptr[i] < zero) {
o_ptr[i] = mid_val * x_ptr[i];
n_ptr[i] = mid_val;
@@ -71,4 +71,10 @@ void RReluKernel(const Context& dev_ctx,

} // namespace phi

PD_REGISTER_KERNEL(rrelu, CPU, ALL_LAYOUT, phi::RReluKernel, float, double) {}
PD_REGISTER_KERNEL(rrelu,
CPU,
ALL_LAYOUT,
phi::RReluKernel,
float,
phi::dtype::float16,
double) {}
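For reference, the is_test branch above uses a fixed negative slope equal to the midpoint of [lower, upper] (the mid_val hoisted out of the loop). A rough NumPy sketch of that inference-mode computation — a reference reimplementation, not taken from the kernel itself:

```python
import numpy as np

def rrelu_infer_reference(x, lower=0.1, upper=0.3):
    # Inference mode: negative inputs are scaled by the fixed midpoint
    # of [lower, upper]; non-negative inputs pass through unchanged.
    mid_val = (lower + upper) / 2.0
    return np.where(x < 0, mid_val * x, x)

x = np.random.uniform(-1, 1, (2, 3)).astype(np.float32)
print(rrelu_infer_reference(x))
```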
55 changes: 38 additions & 17 deletions python/paddle/fluid/tests/unittests/test_rrelu_op.py
@@ -95,6 +95,7 @@ def test_static_graph_functional(self):
name="x2", shape=self.x_np.shape, dtype="float64")
out_1 = F.rrelu(x_1, self.lower_0, self.upper_0, training=False)
out_2 = F.rrelu(x_2, self.lower_1, self.upper_1, training=False)
out_3 = F.rrelu(x_2, self.lower_1, self.upper_1, training=True)

exe = paddle.static.Executor(place=place)
res_1 = exe.run(fluid.default_main_program(),
@@ -105,11 +106,17 @@ def test_static_graph_functional(self):
feed={"x2": self.x_np},
fetch_list=out_2,
use_prune=True)
res_3 = exe.run(fluid.default_main_program(),
feed={"x2": self.x_np},
fetch_list=out_3,
use_prune=True)

out_ref_1 = ref_rrelu(self.x_np, self.lower_0, self.upper_0)
out_ref_2 = ref_rrelu(self.x_np, self.lower_1, self.upper_1)
self.assertEqual(np.allclose(out_ref_1, res_1), True)
self.assertEqual(np.allclose(out_ref_2, res_2), True)
self.assertTrue(
check_output(self.x_np, res_3[0], self.lower_1, self.upper_1))

def test_static_graph_layer(self):
'''test_static_graph_layer'''
@@ -267,10 +274,14 @@ def error_lower_upper():
class RReluTest(OpTest):
def setUp(self):
self.op_type = "rrelu"
self.init_boundary()
self.init_dtype()
self.init_input_shape()
self.init_attr()
self.lower = 0.1
self.upper = 0.3
self.is_test = True
self.init_prams()

def init_prams(self):
self.dtype = "float64"
self.x_shape = [2, 3, 4, 5]

x_np = np.random.uniform(-1, 1, self.x_shape).astype(self.dtype)
out_np = ref_rrelu(x_np, self.lower, self.upper)
@@ -279,19 +290,11 @@ def setUp(self):

self.inputs = {'X': x_np}
self.outputs = {'Out': out_np, 'Noise': noise_np}

def init_boundary(self):
self.lower = 0.1
self.upper = 0.3

def init_dtype(self):
self.dtype = "float64"

def init_input_shape(self):
self.x_shape = [2, 3, 4, 5]

def init_attr(self):
self.attrs = {'lower': self.lower, "upper": self.upper, "is_test": True}
self.attrs = {
'lower': self.lower,
"upper": self.upper,
"is_test": self.is_test
}

def test_check_output(self):
self.check_output()
@@ -300,5 +303,23 @@ def test_check_grad(self):
self.check_grad(['X'], 'Out')


class RReluTrainingTest(RReluTest):
    def setUp(self):
        self.op_type = "rrelu"
        self.lower = 0.3
        self.upper = 0.3000009
        self.is_test = False
        self.init_prams()


if __name__ == "__main__":
unittest.main()
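The helpers ref_rrelu and check_output used by these tests are defined earlier in test_rrelu_op.py and are not part of this diff. As a hedged sketch of what a training-mode check like check_output plausibly verifies — every negative input scaled by some factor inside [lower, upper], non-negative inputs left unchanged:

```python
import numpy as np

def check_output_sketch(x, y, lower, upper):
    # Hypothetical stand-in for the check_output helper referenced above.
    # Training mode samples a random slope per element, so the strongest
    # deterministic check is that each negative input was multiplied by a
    # factor within [lower, upper] and positives pass through untouched.
    if not np.allclose(y[x >= 0], x[x >= 0]):
        return False
    neg = x < 0
    ratios = y[neg] / x[neg]
    return bool(np.all((ratios >= lower) & (ratios <= upper)))
```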
4 changes: 2 additions & 2 deletions python/paddle/nn/functional/activation.py
@@ -590,8 +590,8 @@ def rrelu(x, lower=1. / 8., upper=1. / 3., training=True, name=None):
[[ 1.0, -2.0, -3.0, 4.0],
[-5.0, 6.0, 7.0, -8.0],
[ 6.0, 7.0, 8.0, 9.0]]]], 'float32')
x = paddle.to_tensor(data)
out = F.rrelu(x, 0.1, 0.3)
input_tensor = paddle.to_tensor(data)
out = F.rrelu(input_tensor, 0.1, 0.3)
#[[[[-0.20000899 3. -0.8810822 5. ]
# [ 3. -0.55175185 5. -1.0776101 ]
# [-1.0680687 -1.9896201 8. 9. ]]
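The renamed docstring example above runs in the default training mode, so its output is stochastic. A small companion snippet (hypothetical, not part of the docstring) showing the deterministic training=False path, where the slope is the fixed midpoint (0.1 + 0.3) / 2 = 0.2:

```python
import paddle
import paddle.nn.functional as F

x = paddle.to_tensor([[-1.0, 2.0], [-3.0, 4.0]], dtype='float32')
out = F.rrelu(x, lower=0.1, upper=0.3, training=False)
print(out)  # [[-0.2, 2.0], [-0.6, 4.0]]
```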
6 changes: 3 additions & 3 deletions python/paddle/nn/layer/activation.py
@@ -478,9 +478,9 @@ class RReLU(Layer):
[[ 1.0, -2.0, -3.0, 4.0],
[-5.0, 6.0, 7.0, -8.0],
[ 6.0, 7.0, 8.0, 9.0]]]], 'float64')
x = paddle.to_tensor(data)
m = paddle.nn.RReLU(0.1, 0.3)
out = m(x)
input_tensor = paddle.to_tensor(data)
rrelu_layer = paddle.nn.RReLU(0.1, 0.3)
output = rrelu_layer(input_tensor)
#[[[[-0.20000899 3. -0.88108218 5. ]
# [ 3. -0.55175185 5. -1.07761011]
# [-1.06806871 -1.98962009 8. 9. ]]
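Likewise for the layer API, a short usage note (an assumption about the layer's behavior, not shown in this diff): calling .eval() should switch RReLU to the deterministic midpoint slope, mirroring training=False in the functional form:

```python
import paddle

rrelu_layer = paddle.nn.RReLU(0.1, 0.3)
rrelu_layer.eval()  # assumed to behave like training=False
x = paddle.to_tensor([[-1.0, 2.0]], dtype='float64')
print(rrelu_layer(x))  # expected [[-0.2, 2.0]]
```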
