Commit
[PaddlePaddle Hackathon 2] No. 16: Add new API RRelu (#41823)
* rrelu logic implementation

* unregistered op kernel (unresolved)

* commit before merge

* enrich test cases

* fix bug in rrelu-sig

* fix tests in CPU environment

* fix spelling errors

* fix code format

* try to fix test case timeout issue

* optimize test cases

* remove seed, optimize random function

* update en doc for rrelu

* fix rrelu en docs, test=document_fix

* add paper link for en docs, test=document_fix

* update en doc

* add r,test=document_fix
thunder95 committed May 31, 2022
1 parent 6319dd8 commit 21e1d10
Showing 17 changed files with 1,129 additions and 0 deletions.
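The user-facing entry point added by this PR is the RRelu activation. Below is a minimal usage sketch, assuming the Python wrappers added elsewhere in this PR expose paddle.nn.functional.rrelu and a paddle.nn.RReLU layer with the lower/upper defaults declared in the OpMaker below (1/8 and 1/3); treat the exact signatures as illustrative.

import paddle
import paddle.nn.functional as F

x = paddle.to_tensor([[-1.0, 2.0], [0.5, -3.0]], dtype='float32')

# Functional form: in training, each negative entry is scaled by a slope
# drawn from U(lower, upper).
y = F.rrelu(x, lower=1.0 / 8.0, upper=1.0 / 3.0, training=True)

# Layer form: in eval mode the slope is the fixed mean (lower + upper) / 2.
layer = paddle.nn.RReLU(lower=1.0 / 8.0, upper=1.0 / 3.0)
layer.eval()
y_eval = layer(x)
print(y, y_eval)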
126 changes: 126 additions & 0 deletions paddle/fluid/operators/rrelu_op.cc
@@ -0,0 +1,126 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <memory>
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {

using framework::Tensor;

class RReluOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};

class RReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input of RReLU op.");
AddOutput("Out", "The output of RReLU op.");
AddOutput("Noise", "The random sampled RReLU noise.")
.AsIntermediate()
.AsExtra();
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
float default_lower = 1. / 8.;
AddAttr<float>("lower", "Lower bound of the uniform distribution.")
.SetDefault(default_lower)
.AddCustomChecker([](const float& lower) {
PADDLE_ENFORCE_EQ(lower >= 0.0f && lower < 1.0f, true,
platform::errors::InvalidArgument(
"'RRelu_lower' must be between 0.0 and 1.0."));
});
float default_upper = 1. / 3.;
AddAttr<float>("upper", "Upper bound of the uniform distribution.")
.SetDefault(default_upper)
.AddCustomChecker([](const float& upper) {
PADDLE_ENFORCE_EQ(upper > 0.0f && upper <= 1.0f, true,
platform::errors::InvalidArgument(
"'RRelu_upper' must be between 0.0 and 1.0."));
});
AddComment(R"DOC(
RReLU Operator.
Applies the randomized leaky rectified linear unit function, element-wise,
as described in the paper:
`Empirical Evaluation of Rectified Activations in Convolutional Network`_.
The function is defined as:
.. math::
\text{RReLU}(x) =
\begin{cases}
x & \text{if } x \geq 0 \\
ax & \text{ otherwise }
\end{cases}
where :math:`a` is randomly sampled from uniform distribution
:math:`\mathcal{U}(\text{lower}, \text{upper})`.
See: https://arxiv.org/pdf/1505.00853.pdf
)DOC");
}
};

class RReluGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};

template <typename T>
class RReluGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("rrelu_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Noise", this->Output("Noise"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(rrelu, RReluInferShapeFunctor,
PD_INFER_META(phi::RReluInferMeta));

REGISTER_OPERATOR(rrelu, ops::RReluOp, ops::RReluOpMaker,
ops::RReluGradOpMaker<paddle::framework::OpDesc>,
ops::RReluGradOpMaker<paddle::imperative::OpBase>,
RReluInferShapeFunctor);

DECLARE_INFER_SHAPE_FUNCTOR(rrelu_grad, RReluGradInferShapeFunctor,
PD_INFER_META(phi::RReluGradInferMeta));
REGISTER_OPERATOR(rrelu_grad, ops::RReluGradOp, RReluGradInferShapeFunctor);
49 changes: 49 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -1977,6 +1977,55 @@ void RollInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}

void RReluInferMeta(const MetaTensor& x,
float lower,
float upper,
bool is_test,
MetaTensor* out,
MetaTensor* noise) {
auto x_dims = x.dims();
PADDLE_ENFORCE_GE(lower,
0,
phi::errors::InvalidArgument(
"The lower value should be greater than or equal to 0. "
"But received lower value = %f.",
lower));
PADDLE_ENFORCE_LE(upper,
1,
phi::errors::InvalidArgument(
"The upper value should be less than or equal to 1. "
"But received upper value = %f.",
upper));
PADDLE_ENFORCE_GE(
upper,
lower,
phi::errors::InvalidArgument(
"The upper value should be greater than or equal to lower value "
"But received upper value = %f, lower value = %f.",
upper,
lower));

out->set_dims(x_dims);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
out->share_lod(x);

if (noise != nullptr) {
noise->set_dims(x_dims);
noise->set_dtype(x.dtype());
noise->set_layout(x.layout());
}
}

void RReluGradInferMeta(const MetaTensor& out_grad,
const MetaTensor& noise,
MetaTensor* x_grad) {
auto do_dims = out_grad.dims();
x_grad->set_dims(do_dims);
x_grad->set_dtype(out_grad.dtype());
x_grad->share_lod(out_grad);
}

void SetValueInferMeta(const MetaTensor& x, MetaTensor* out) {
auto in_dims = x.dims();
PADDLE_ENFORCE_LT(
11 changes: 11 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -281,6 +281,17 @@ void RollInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
MetaTensor* out);

void RReluInferMeta(const MetaTensor& x,
float lower,
float upper,
bool is_test,
MetaTensor* out,
MetaTensor* noise);

void RReluGradInferMeta(const MetaTensor& out_grad,
const MetaTensor& noise,
MetaTensor* x_grad);

void SetValueInferMeta(const MetaTensor& x, MetaTensor* out);

void ShapeInferMeta(const MetaTensor& input, MetaTensor* out);
44 changes: 44 additions & 0 deletions paddle/phi/kernels/cpu/rrelu_grad_kernel.cc
@@ -0,0 +1,44 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/rrelu_grad_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void RReluGradKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     const DenseTensor& noise,
                     const DenseTensor& out_grad,
                     DenseTensor* x_grad) {
  if (!x_grad) return;

  const T* n_ptr = noise.data<T>();
  const T* x_ptr = x.data<T>();
  const T* out_grad_ptr = out_grad.data<T>();
  int numel = x.numel();

  // The gradient passes through unchanged for positive inputs and is scaled
  // by the sampled slope stored in Noise for negative ones.
  T* x_grad_ptr = dev_ctx.template Alloc<T>(x_grad);
  for (int i = 0; i < numel; i++) {
    x_grad_ptr[i] = x_ptr[i] > 0 ? out_grad_ptr[i] : n_ptr[i] * out_grad_ptr[i];
  }
}

} // namespace phi

PD_REGISTER_KERNEL(
rrelu_grad, CPU, ALL_LAYOUT, phi::RReluGradKernel, float, double) {}
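For reference, the gradient rule implemented by RReluGradKernel above can be written as a short NumPy sketch (function and variable names here are illustrative only): the gradient passes through unchanged where x > 0 and is otherwise scaled by the per-element slope stored in Noise.

import numpy as np

def rrelu_grad_reference(x, noise, out_grad):
    # Positive inputs pass the incoming gradient through; negative inputs
    # scale it by the slope recorded in `noise` during the forward pass.
    return np.where(x > 0, out_grad, noise * out_grad)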
77 changes: 77 additions & 0 deletions paddle/phi/kernels/cpu/rrelu_kernel.cc
@@ -0,0 +1,77 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/rrelu_kernel.h"

#include "paddle/fluid/framework/generator.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void RReluKernel(const Context& dev_ctx,
const DenseTensor& x,
const float lower,
const float upper,
bool is_test,
DenseTensor* out,
DenseTensor* noise) {
const T* x_ptr = x.data<T>();
T* o_ptr = dev_ctx.template Alloc<T>(out);
T* n_ptr = dev_ctx.template Alloc<T>(noise);
T zero = static_cast<T>(0);
int numel = x.numel();
int i = 0;

if (is_test) {
T mid_val = static_cast<T>((lower + upper) / 2.0);
for (i = 0; i < numel; i++) {
if (x_ptr[i] < zero) {
o_ptr[i] = mid_val * x_ptr[i];
n_ptr[i] = mid_val;
} else {
o_ptr[i] = x_ptr[i];
n_ptr[i] = 1.0;
}
}

return;
}

auto engine = paddle::framework::GetCPURandomEngine(0);

std::uniform_real_distribution<float> dist(lower, upper);

for (i = 0; i < numel; i++) {
if (x_ptr[i] < zero) {
T scale = static_cast<T>(dist(*engine));
o_ptr[i] = scale * x_ptr[i];
n_ptr[i] = scale;
} else {
o_ptr[i] = x_ptr[i];
n_ptr[i] = 1.0;
}
}
}

} // namespace phi

PD_REGISTER_KERNEL(rrelu,
CPU,
ALL_LAYOUT,
phi::RReluKernel,
float,
phi::dtype::float16,
double) {}
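A NumPy sketch of the forward rule implemented by RReluKernel above (illustrative only): in test mode every negative entry uses the fixed slope (lower + upper) / 2, while in training each negative entry gets its own slope drawn from U(lower, upper); the noise output records the slope actually applied (1 for non-negative entries) so the backward kernel can reuse it.

import numpy as np

def rrelu_forward_reference(x, lower=1.0 / 8.0, upper=1.0 / 3.0, is_test=False):
    if is_test:
        # Deterministic inference mode: fixed mean slope for negative entries.
        slope = np.full_like(x, (lower + upper) / 2.0)
    else:
        # Training mode: per-element random slope in [lower, upper).
        slope = np.random.uniform(lower, upper, size=x.shape).astype(x.dtype)
    noise = np.where(x < 0, slope, np.ones_like(x))
    out = np.where(x < 0, slope * x, x)
    return out, noise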
