diff --git a/paddle/fluid/operators/rrelu_op.cc b/paddle/fluid/operators/rrelu_op.cc new file mode 100644 index 0000000000000..c543a088e9d7f --- /dev/null +++ b/paddle/fluid/operators/rrelu_op.cc @@ -0,0 +1,126 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/infermeta/unary.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class RReluOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } +}; + +class RReluOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The input of RReLU op."); + AddOutput("Out", "The output of RReLU op."); + AddOutput("Noise", "The random sampled RReLU noise.") + .AsIntermediate() + .AsExtra(); + AddAttr("is_test", + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") + .SetDefault(false); + float default_lower = 1. / 8.; + AddAttr("lower", "Lower bound of the uniform distribution.") + .SetDefault(default_lower) + .AddCustomChecker([](const float& lower) { + PADDLE_ENFORCE_EQ(lower >= 0.0f && lower < 1.0f, true, + platform::errors::InvalidArgument( + "'RRelu_lower' must be between 0.0 and 1.0.")); + }); + float defalut_upper = 1. / 3.; + AddAttr("upper", "Upper bound of the uniform distribution.") + .SetDefault(defalut_upper) + .AddCustomChecker([](const float& upper) { + PADDLE_ENFORCE_EQ(upper > 0.0f && upper <= 1.0f, true, + platform::errors::InvalidArgument( + "'RRelu_upper' must be between 0.0 and 1.0.")); + }); + AddComment(R"DOC( +RReLU Operator. + +Applies the randomized leaky rectified liner unit function, element-wise, +as described in the paper: + +`Empirical Evaluation of Rectified Activations in Convolutional Network`_. + +The function is defined as: + +.. math:: + \text{RReLU}(x) = + \begin{cases} + x & \text{if } x \geq 0 \\ + ax & \text{ otherwise } + \end{cases} + +where :math:`a` is randomly sampled from uniform distribution +:math:`\mathcal{U}(\text{lower}, \text{upper})`. + + See: https://arxiv.org/pdf/1505.00853.pdf + +)DOC"); + } +}; + +class RReluGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; +}; + +template +class RReluGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("rrelu_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput("Noise", this->Output("Noise")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(rrelu, RReluInferShapeFunctor, + PD_INFER_META(phi::RReluInferMeta)); + +REGISTER_OPERATOR(rrelu, ops::RReluOp, ops::RReluOpMaker, + ops::RReluGradOpMaker, + ops::RReluGradOpMaker, + RReluInferShapeFunctor); + +DECLARE_INFER_SHAPE_FUNCTOR(rrelu_grad, RReluGradInferShapeFunctor, + PD_INFER_META(phi::RReluGradInferMeta)); +REGISTER_OPERATOR(rrelu_grad, ops::RReluGradOp, RReluGradInferShapeFunctor); diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 6c2956417a3a3..0c6db65168f8d 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1917,6 +1917,55 @@ void RollInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } +void RReluInferMeta(const MetaTensor& x, + float lower, + float upper, + bool is_test, + MetaTensor* out, + MetaTensor* noise) { + auto x_dims = x.dims(); + PADDLE_ENFORCE_GE(lower, + 0, + phi::errors::InvalidArgument( + "The lower value should be greater than or equal to 0. " + "But received lower value = %f.", + lower)); + PADDLE_ENFORCE_LE(upper, + 1, + phi::errors::InvalidArgument( + "The upper value should be less than or equal to 1. " + "But received upper value = %f.", + upper)); + PADDLE_ENFORCE_GE( + upper, + lower, + phi::errors::InvalidArgument( + "The upper value should be greater than or equal to lower value " + "But received upper value = %f, lower value = %f.", + upper, + lower)); + + out->set_dims(x_dims); + out->set_dtype(x.dtype()); + out->set_layout(x.layout()); + out->share_lod(x); + + if (noise != nullptr) { + noise->set_dims(x_dims); + noise->set_dtype(x.dtype()); + noise->set_layout(x.layout()); + } +} + +void RReluGradInferMeta(const MetaTensor& out_grad, + const MetaTensor& noise, + MetaTensor* x_grad) { + auto do_dims = out_grad.dims(); + x_grad->set_dims(do_dims); + x_grad->set_dtype(out_grad.dtype()); + x_grad->share_lod(out_grad); +} + void SetValueInferMeta(const MetaTensor& x, MetaTensor* out) { auto in_dims = x.dims(); PADDLE_ENFORCE_LT( diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 559857bd6ce9b..1ae63e28c9a51 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -273,6 +273,17 @@ void RollInferMeta(const MetaTensor& x, const std::vector& axis, MetaTensor* out); +void RReluInferMeta(const MetaTensor& x, + float lower, + float upper, + bool is_test, + MetaTensor* out, + MetaTensor* noise); + +void RReluGradInferMeta(const MetaTensor& out_grad, + const MetaTensor& noise, + MetaTensor* x_grad); + void SetValueInferMeta(const MetaTensor& x, MetaTensor* out); void ShapeInferMeta(const MetaTensor& input, MetaTensor* out); diff --git a/paddle/phi/kernels/cpu/rrelu_grad_kernel.cc b/paddle/phi/kernels/cpu/rrelu_grad_kernel.cc new file mode 100644 index 0000000000000..10b6c6b1a3ea8 --- /dev/null +++ b/paddle/phi/kernels/cpu/rrelu_grad_kernel.cc @@ -0,0 +1,44 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/rrelu_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void RReluGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& noise, + const DenseTensor& out_grad, + DenseTensor* x_grad) { + const T* n_ptr = noise.data(); + const T* x_ptr = x.data(); + const T* out_grad_ptr = out_grad.data(); + int numel = x.numel(); + if (!x_grad) return; + + int i = 0; + T* x_grad_ptr = dev_ctx.template Alloc(x_grad); + for (i = 0; i < numel; i++) { + x_grad_ptr[i] = x_ptr[i] > 0 ? out_grad_ptr[i] : n_ptr[i] * out_grad_ptr[i]; + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + rrelu_grad, CPU, ALL_LAYOUT, phi::RReluGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/rrelu_kernel.cc b/paddle/phi/kernels/cpu/rrelu_kernel.cc new file mode 100644 index 0000000000000..4c6e30beddfa3 --- /dev/null +++ b/paddle/phi/kernels/cpu/rrelu_kernel.cc @@ -0,0 +1,77 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/rrelu_kernel.h" + +#include "paddle/fluid/framework/generator.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void RReluKernel(const Context& dev_ctx, + const DenseTensor& x, + const float lower, + const float upper, + bool is_test, + DenseTensor* out, + DenseTensor* noise) { + const T* x_ptr = x.data(); + T* o_ptr = dev_ctx.template Alloc(out); + T* n_ptr = dev_ctx.template Alloc(noise); + T zero = static_cast(0); + int numel = x.numel(); + int i = 0; + + if (is_test) { + T mid_val = static_cast((lower + upper) / 2.0); + for (i = 0; i < numel; i++) { + if (x_ptr[i] < zero) { + o_ptr[i] = mid_val * x_ptr[i]; + n_ptr[i] = mid_val; + } else { + o_ptr[i] = x_ptr[i]; + n_ptr[i] = 1.0; + } + } + + return; + } + + auto engine = paddle::framework::GetCPURandomEngine(0); + + std::uniform_real_distribution dist(lower, upper); + + for (i = 0; i < numel; i++) { + if (x_ptr[i] < zero) { + T scale = static_cast(dist(*engine)); + o_ptr[i] = scale * x_ptr[i]; + n_ptr[i] = scale; + } else { + o_ptr[i] = x_ptr[i]; + n_ptr[i] = 1.0; + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(rrelu, + CPU, + ALL_LAYOUT, + phi::RReluKernel, + float, + phi::dtype::float16, + double) {} diff --git a/paddle/phi/kernels/gpu/rrelu_grad_kernel.cu b/paddle/phi/kernels/gpu/rrelu_grad_kernel.cu new file mode 100644 index 0000000000000..44dc31ed5d926 --- /dev/null +++ b/paddle/phi/kernels/gpu/rrelu_grad_kernel.cu @@ -0,0 +1,86 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/rrelu_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_meta.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/reduce_function.h" +#include "paddle/phi/kernels/gpu/prelu_funcs.h" +#include "paddle/phi/kernels/primitive/functor_primitives.h" + +namespace phi { + +template +__global__ void RReluOpGradKernel(const T* x_ptr, + const T* noise_ptr, + const T* out_grad_ptr, + T* x_grad_ptr, + int numel) { + CUDA_KERNEL_LOOP(index, numel) { + T scale = noise_ptr[index]; + T x = x_ptr[index]; + T out_grad = out_grad_ptr[index]; + T zero = static_cast(0); + x_grad_ptr[index] = (x < zero) ? scale * out_grad : out_grad; + } +} + +template +class RReluOpGradFunctor { + public: + void operator()(gpuStream_t stream, + const T* x, + const T* noise, + const T* out_grad, + T* x_grad, + int numel) { + RReluOpGradKernel< + T><<>>( + x, noise, out_grad, x_grad, numel); + } +}; + +template +void RReluGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& noise, + const DenseTensor& out_grad, + DenseTensor* x_grad) { + if (!x_grad) return; + dev_ctx.template Alloc(x_grad); + + const T* x_ptr = x.data(); + const T* n_ptr = noise.data(); + const T* out_grad_ptr = out_grad.data(); + T* x_grad_ptr = dev_ctx.template Alloc(x_grad); + + int numel = x.numel(); + auto stream = dev_ctx.stream(); + + RReluOpGradFunctor rrelu_grad; + rrelu_grad(stream, x_ptr, n_ptr, out_grad_ptr, x_grad_ptr, numel); +} + +} // namespace phi + +PD_REGISTER_KERNEL(rrelu_grad, + GPU, + ALL_LAYOUT, + phi::RReluGradKernel, + float, + phi::dtype::float16, + double) {} diff --git a/paddle/phi/kernels/gpu/rrelu_kernel.cu b/paddle/phi/kernels/gpu/rrelu_kernel.cu new file mode 100644 index 0000000000000..39582d5872a70 --- /dev/null +++ b/paddle/phi/kernels/gpu/rrelu_kernel.cu @@ -0,0 +1,112 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/distribution_helper.h" +#include "paddle/phi/kernels/funcs/for_range.h" +#include "paddle/phi/kernels/rrelu_kernel.h" + +namespace phi { + +template +struct RReluTrainCudaFunctor { + public: + RReluTrainCudaFunctor(const T* in, T* out, T* noise) + : in_(in), out_(out), noise_(noise) { + zero_ = static_cast(0); + } + + __device__ void operator()(int64_t idx) { + T x = in_[idx]; + if (x < zero_) { + out_[idx] = noise_[idx] * x; + } else { + out_[idx] = x; + noise_[idx] = 1.0; + } + } + + private: + const T* in_; + T* out_; + T* noise_; + T zero_; +}; + +template +struct RReluTestCudaFunctor { + public: + RReluTestCudaFunctor(const T* in, T* out, T* noise, T mid_val) + : in_(in), out_(out), noise_(noise), mid_val_(mid_val) { + zero_ = static_cast(0); + } + + __device__ void operator()(int64_t idx) { + T x = in_[idx]; + if (x < zero_) { + out_[idx] = mid_val_ * x; + noise_[idx] = mid_val_; + } else { + out_[idx] = x; + noise_[idx] = 1.0; + } + } + + private: + const T* in_; + T* out_; + T* noise_; + T zero_; + T mid_val_; +}; + +template +void RReluKernel(const Context& ctx, + const DenseTensor& x, + const float lower, + const float upper, + bool is_test, + DenseTensor* out, + DenseTensor* noise) { + const T* x_data = x.data(); + T* out_data = ctx.template Alloc(out); + T* noise_data = ctx.template Alloc(noise); + auto size = x.numel(); + if (size <= 0) return; + + phi::funcs::ForRange for_range(ctx, size); + if (is_test) { + T mid_val = static_cast((lower + upper) / 2.0); + RReluTestCudaFunctor functor(x_data, out_data, noise_data, mid_val); + for_range(functor); + } else { + using MT = typename kps::details::MPTypeTrait::Type; + funcs::uniform_distribution dist; + funcs::uniform_real_transform trans(lower, upper); + funcs::distribution_and_transform(ctx, noise, dist, trans); + RReluTrainCudaFunctor functor(x_data, out_data, noise_data); + for_range(functor); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(rrelu, + GPU, + ALL_LAYOUT, + phi::RReluKernel, + float, + phi::dtype::float16, + double) {} diff --git a/paddle/phi/kernels/rrelu_grad_kernel.h b/paddle/phi/kernels/rrelu_grad_kernel.h new file mode 100644 index 0000000000000..b6172fca10e53 --- /dev/null +++ b/paddle/phi/kernels/rrelu_grad_kernel.h @@ -0,0 +1,28 @@ + +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void RReluGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& noise, + const DenseTensor& out_grad, + DenseTensor* x_grad); +} // namespace phi diff --git a/paddle/phi/kernels/rrelu_kernel.h b/paddle/phi/kernels/rrelu_kernel.h new file mode 100644 index 0000000000000..8deb52daaae13 --- /dev/null +++ b/paddle/phi/kernels/rrelu_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void RReluKernel(const Context& dev_ctx, + const DenseTensor& x, + const float lower, + const float upper, + bool is_test, + DenseTensor* out, + DenseTensor* noise); +} // namespace phi diff --git a/paddle/phi/ops/compat/rrelu_sig.cc b/paddle/phi/ops/compat/rrelu_sig.cc new file mode 100644 index 0000000000000..00cd705a24076 --- /dev/null +++ b/paddle/phi/ops/compat/rrelu_sig.cc @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature RReluOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature( + "rrelu", {"X"}, {"lower", "upper", "is_test"}, {"Out", "Noise"}); +} + +KernelSignature RReluGradGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "rrelu_grad", {"X", "Noise", "Out@GRAD"}, {}, {"X@GRAD"}); +} +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(rrelu, phi::RReluOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(rrelu_grad, phi::RReluGradGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_rrelu_op.py b/python/paddle/fluid/tests/unittests/test_rrelu_op.py new file mode 100644 index 0000000000000..9d33ce085b7f7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_rrelu_op.py @@ -0,0 +1,326 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle.fluid as fluid +import paddle.fluid.core as core +from op_test import OpTest +import paddle +import paddle.nn.functional as F +from paddle.fluid import dygraph + +paddle.seed(102) +np.random.seed(102) + + +def ref_rrelu(x, lower, upper): + x_t = x.copy() + alpha = (lower + upper) / 2.0 + return np.where(x_t <= 0, alpha * x_t, x_t) + + +def ref_rrelu_nn(x, lower, upper): + return ref_rrelu(x, lower, upper) + + +def check_output(input, output, lower, upper): + lower_res = np.where(input <= 0, lower * input, input) + upper_res = np.where(input <= 0, upper * input, input) + return (output <= lower_res).all() and (output >= upper_res).all() + + +class TestFunctionalRReluAPI(unittest.TestCase): + def setUp(self): + self.x_np = np.random.uniform(-1., 1., [1, 2, 3, 4]).astype('float64') + self.lower_0 = 0.05 + self.lower_1 = 0.1 + self.upper_0 = 0.25 + self.upper_1 = 0.33 + + self.places = [ + fluid.CUDAPlace(0) + if core.is_compiled_with_cuda() else fluid.CPUPlace() + ] + + def check_static_result(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input = fluid.data( + name="input", shape=[2, 3, 4, 5], dtype="float32") + res1 = F.rrelu( + x=input, lower=self.lower_0, upper=self.upper_0, training=False) + res2 = F.rrelu( + x=input, lower=self.lower_1, upper=self.upper_1, training=False) + in_np = np.random.uniform(-1., 1., [2, 3, 4, 5]).astype("float32") + + res_np1 = ref_rrelu(in_np, self.lower_0, self.upper_0) + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"input": in_np}, + fetch_list=[res1]) + + self.assertTrue(np.allclose(fetches[0], res_np1)) + + res_np2 = ref_rrelu(in_np, self.lower_1, self.upper_1) + fetches = exe.run(fluid.default_main_program(), + feed={"input": in_np}, + fetch_list=[res2]) + self.assertTrue(np.allclose(fetches[0], res_np2)) + + def test_static(self): + for place in self.places: + self.check_static_result(place=place) + + def test_static_graph_functional(self): + '''test_static_graph_functional''' + + for place in self.places: + paddle.enable_static() + x_1 = paddle.fluid.data( + name="x", shape=self.x_np.shape, dtype="float64") + x_2 = paddle.fluid.data( + name="x2", shape=self.x_np.shape, dtype="float64") + out_1 = F.rrelu(x_1, self.lower_0, self.upper_0, training=False) + out_2 = F.rrelu(x_2, self.lower_1, self.upper_1, training=False) + out_3 = F.rrelu(x_2, self.lower_1, self.upper_1, training=True) + + exe = paddle.static.Executor(place=place) + res_1 = exe.run(fluid.default_main_program(), + feed={"x": self.x_np}, + fetch_list=out_1, + use_prune=True) + res_2 = exe.run(fluid.default_main_program(), + feed={"x2": self.x_np}, + fetch_list=out_2, + use_prune=True) + res_3 = exe.run(fluid.default_main_program(), + feed={"x2": self.x_np}, + fetch_list=out_3, + use_prune=True) + + out_ref_1 = ref_rrelu(self.x_np, self.lower_0, self.upper_0) + out_ref_2 = ref_rrelu(self.x_np, self.lower_1, self.upper_1) + self.assertEqual(np.allclose(out_ref_1, res_1), True) + self.assertEqual(np.allclose(out_ref_2, res_2), True) + self.assertTrue( + check_output(self.x_np, res_3[0], self.lower_1, self.upper_1)) + + def test_static_graph_layer(self): + '''test_static_graph_layer''' + + for place in self.places: + paddle.enable_static() + x_1 = paddle.fluid.data( + name="x", shape=self.x_np.shape, dtype="float64") + x_2 = paddle.fluid.data( + name="x2", shape=self.x_np.shape, dtype="float64") + # init instance + rrelu_1 = paddle.nn.RReLU(self.lower_0, self.upper_0) + rrelu_2 = paddle.nn.RReLU(self.lower_1, self.upper_1) + out_1 = rrelu_1(x_1) + out_2 = rrelu_2(x_2) + + exe = paddle.static.Executor(place=place) + res_1 = exe.run(fluid.default_main_program(), + feed={"x": self.x_np}, + fetch_list=out_1, + use_prune=True) + res_2 = exe.run(fluid.default_main_program(), + feed={"x2": self.x_np}, + fetch_list=out_2, + use_prune=True) + + self.assertTrue( + check_output(self.x_np, res_1[0], self.lower_0, self.upper_0)) + self.assertTrue( + check_output(self.x_np, res_2[0], self.lower_1, self.upper_1)) + + def dygraph_check(self, lower, upper): + for place in self.places: + paddle.disable_static(place) + x = paddle.to_tensor(self.x_np) + out = F.rrelu(x, lower, upper, training=False) + out_ref = ref_rrelu(self.x_np, lower, upper) + self.assertEqual(np.allclose(out_ref, out), True) + paddle.enable_static() + + def test_dygraph_functional(self): + '''test_dygraph_functional''' + + self.dygraph_check(self.lower_0, self.upper_0) + self.dygraph_check(self.lower_1, self.upper_1) + + def test_dygraph_layer(self): + '''test_dygraph_layer''' + + for place in self.places: + paddle.disable_static(place=place) + rrelu = paddle.nn.RReLU(self.lower_0, self.upper_0) + result = rrelu(paddle.to_tensor(self.x_np)) + self.assertTrue( + check_output(self.x_np, + result.numpy(), self.lower_0, self.upper_0)) + paddle.enable_static() + + def test_dygraph(self): + for place in self.places: + paddle.disable_static(place=place) + with dygraph.guard(): + rrelu = paddle.nn.RReLU(self.lower_0, self.upper_0) + out_np = rrelu(paddle.to_tensor(self.x_np)) + self.assertTrue( + check_output(self.x_np, + out_np.numpy(), self.lower_0, self.upper_0)) + paddle.enable_static() + + def test_error_functional(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + # The input type must be Variable. + self.assertRaises( + TypeError, F.rrelu, x=1, lower=self.lower_0, upper=self.upper_0) + # The input dtype must be float16, float32, float64. + x_int32 = paddle.fluid.data( + name='x_int32', shape=[2, 3], dtype='int32') + self.assertRaises( + TypeError, + F.rrelu, + x=x_int32, + lower=self.lower_0, + upper=self.upper_0) + x_bool = paddle.fluid.data( + name='x_bool', shape=[2, 3], dtype='int32') + self.assertRaises( + TypeError, + F.rrelu, + x=x_bool, + lower=self.lower_0, + upper=self.upper_0) + # lower and upper must be float + x_fp32 = paddle.fluid.data( + name='x_fp32', shape=[2, 3], dtype='float32') + self.assertRaises(TypeError, F.rrelu, x=x_fp32, lower=0, upper=0.5) + self.assertRaises(TypeError, F.rrelu, x=x_fp32, lower=0.5, upper=1) + # lower and upper must be in (0, 1) + self.assertRaises( + ValueError, F.rrelu, x=x_fp32, lower=-1., upper=0.5) + self.assertRaises( + ValueError, F.rrelu, x=x_fp32, lower=0.5, upper=2.) + # upper should not be less than lower + self.assertRaises( + ValueError, F.rrelu, x=x_fp32, lower=0.5, upper=0.2) + # support the input dtype is float16 + x_fp16 = paddle.fluid.data( + name='x_fp16', shape=[2, 3], dtype='float16') + F.rrelu(x=x_fp16, lower=self.lower_0, upper=self.upper_0) + + def test_error_layer(self): + def error_int_dtype(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 3]).astype("float64") + rrelu = paddle.nn.RReLU(2, 3) + rrelu(paddle.to_tensor(x)) + + def error_lower_dtype(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 3]).astype("float32") + rrelu = paddle.nn.RReLU(0, 0.5) + rrelu(paddle.to_tensor(x)) + + def error_upper_dtype(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 3]).astype("float32") + rrelu = paddle.nn.RReLU(0.5, 1) + rrelu(paddle.to_tensor(x)) + + def error_lower_range(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 3]).astype("float32") + rrelu = paddle.nn.RReLU(-1.0, 0.5) + rrelu(paddle.to_tensor(x)) + + def error_upper_range(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 3]).astype("float32") + rrelu = paddle.nn.RReLU(0.5, 2.0) + rrelu(paddle.to_tensor(x)) + + def error_lower_upper(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 3]).astype("float32") + rrelu = paddle.nn.RReLU(0.5, 0.2) + rrelu(paddle.to_tensor(x)) + + self.assertRaises(TypeError, error_int_dtype) + self.assertRaises(TypeError, error_lower_dtype) + self.assertRaises(TypeError, error_upper_dtype) + self.assertRaises(ValueError, error_lower_range) + self.assertRaises(ValueError, error_upper_range) + self.assertRaises(ValueError, error_lower_upper) + + +class RReluTest(OpTest): + def setUp(self): + self.op_type = "rrelu" + self.lower = 0.1 + self.upper = 0.3 + self.is_test = True + self.init_prams() + + def init_prams(self): + self.dtype = "float64" + self.x_shape = [2, 3, 4, 5] + + x_np = np.random.uniform(-1, 1, self.x_shape).astype(self.dtype) + out_np = ref_rrelu(x_np, self.lower, self.upper) + noise_np = np.ones(self.x_shape).astype(self.dtype) + noise_np[x_np < 0] = (self.lower + self.upper) / 2.0 + + self.inputs = {'X': x_np} + self.outputs = {'Out': out_np, 'Noise': noise_np} + self.attrs = { + 'lower': self.lower, + "upper": self.upper, + "is_test": self.is_test + } + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class RReluTrainingTest(OpTest): + def setUp(self): + self.op_type = "rrelu" + self.lower = 0.3 + self.upper = 0.3000009 + self.is_test = False + self.init_prams() + + +class RReluTrainingTest(OpTest): + def setUp(self): + self.op_type = "rrelu" + self.lower = 0.3 + self.upper = 0.3000009 + self.is_test = False + self.init_prams() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index bceee4b964a33..b4be291b0697f 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -51,6 +51,7 @@ from .layer.activation import ThresholdedReLU # noqa: F401 from .layer.activation import LogSoftmax # noqa: F401 from .layer.activation import Maxout # noqa: F401 +from .layer.activation import RReLU # noqa: F401 from .layer.common import Pad1D # noqa: F401 from .layer.common import Pad2D # noqa: F401 from .layer.common import ZeroPad2D # noqa: F401 @@ -313,4 +314,5 @@ def weight_norm(*args): 'MaxUnPool3D', 'HingeEmbeddingLoss', 'Identity', + 'RReLU', ] diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 68213d831c550..44acf32894588 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -47,6 +47,7 @@ from .activation import log_softmax # noqa: F401 from .activation import glu # noqa: F401 from .activation import gumbel_softmax # noqa: F401 +from .activation import rrelu # noqa: F401 from .common import dropout # noqa: F401 from .common import dropout2d # noqa: F401 from .common import dropout3d # noqa: F401 @@ -228,4 +229,5 @@ 'class_center_sample', 'sparse_attention', 'fold', + 'rrelu', ] diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index e64efda7b33bf..7af7ae8ddd7ab 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -548,6 +548,122 @@ def prelu(x, weight, data_format="NCHW", name=None): return out +def rrelu(x, lower=1. / 8., upper=1. / 3., training=True, name=None): + r""" + rrelu activation. + + Applies the randomized leaky rectified liner unit function to improve generalization performance, + as described in the paper: + `Empirical Evaluation of Rectified Activations in Convolutional Network `_ + + During training, randomly samples the negative slope for activation values as described below: + + .. math:: + + rrelu(x)= + \left\{ + \begin{array}{rcl} + x, & & if \ x >= 0 \\ + a * x, & & otherwise \\ + \end{array} + \right. + + where :math:`x` is the input tensor, + :math:`a` is randomly sampled from uniform distribution in range (:math:`lower`, :math:`upper`), + + In the test phase, the negative slope will take the average value of :math:`lower` and :math:`upper`: + + .. math:: + + rrelu(x)= + \left\{ + \begin{array}{rcl} + x, & & if \ x >= 0 \\ + (lower + upper) * 0.5 * x, & & otherwise \\ + \end{array} + \right. + + where :math:`x` is the input tensor, + :math:`lower` and :math:`upper` are the bounds of uniform distribution. + + Parameters: + x (Tensor): The input Tensor with data type float16, float32, float64. + lower (float, optional): The lower bound of uniform distribution. Default: 0.125. + upper (float, optional): The upper bound of uniform distribution. Default: 0.333. + training (bool, optional): Current mode is in training or others. Default is True. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Tensor with the same data type and shape as ``x`` . + + Examples: + .. code-block:: python + :name: rrelu-example + + import paddle + import paddle.nn.functional as F + + input_tensor = paddle.to_tensor([[[[-2.0, 3.0, -4.0, 5.0], + [ 3.0, -4.0, 5.0, -6.0], + [-7.0, -8.0, 8.0, 9.0]], + [[ 1.0, -2.0, -3.0, 4.0], + [-5.0, 6.0, 7.0, -8.0], + [ 6.0, 7.0, 8.0, 9.0]]]], dtype='float32') + + out = F.rrelu(input_tensor, 0.1, 0.3) + #[[[[-0.20000899 3. -0.8810822 5. ] + # [ 3. -0.55175185 5. -1.0776101 ] + # [-1.0680687 -1.9896201 8. 9. ]] + # [[ 1. -0.5238267 -0.65515125 4. ] + # [-1.3766339 6. 7. -2.3465784 ] + # [ 6. 7. 8. 9. ]]]] + """ + + if not in_dynamic_mode(): + check_variable_and_dtype(x, 'X', ['float16', 'float32', 'float64'], + 'rrelu') + + if not isinstance(lower, float) or not isinstance(upper, float): + raise TypeError( + "The lower and upper values must be float type. Received: lower {}, upper {}.". + format(lower, upper)) + + if lower < 0 or lower > 1: + raise ValueError( + "The lower value must be no less than zero or greater than one. Received: {}.". + format(lower)) + + if upper < lower: + raise ValueError( + "The upper value must be greater than lower value. Received: lower {}, upper {}.". + format(lower, upper)) + + if upper > 1: + raise ValueError( + "The upper value must be no greater than one. Received: {}.".format( + upper)) + + is_test = not training + + if _in_legacy_dygraph(): + out, noise = _C_ops.rrelu(x, 'lower', lower, 'upper', upper, 'is_test', + is_test) + return out + + helper = LayerHelper('rrelu', **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + noise = helper.create_variable_for_type_inference(dtype=x.dtype) + attrs = {'lower': lower, 'upper': upper, 'is_test': is_test} + helper.append_op( + type='rrelu', + inputs={"X": x}, + outputs={"Out": out, + "Noise": noise}, + attrs=attrs) + return out + + def relu(x, name=None): """ relu activation. diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 31364f0281c8a..cca8c37645df6 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -26,6 +26,7 @@ from .activation import Sigmoid # noqa: F401 from .activation import Softmax # noqa: F401 from .activation import LogSoftmax # noqa: F401 +from .activation import RReLU # noqa: F401 from .activation import Softmax2D # noqa: F401 from .common import Bilinear # noqa: F401 from .common import Pad1D # noqa: F401 diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index 7fd109843bede..1a3768e919042 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -436,6 +436,93 @@ def extra_repr(self): name_str) +class RReLU(Layer): + r""" + RReLU activation layer. + + Applies the randomized leaky rectified liner unit function to improve generalization performance, + as described in the paper: + `Empirical Evaluation of Rectified Activations in Convolutional Network `_ + + During training, randomly samples the negative slope for activation values as described below: + + .. math:: + + RReLU(x)= + \left\{ + \begin{array}{rcl} + x, & & if \ x >= 0 \\ + a * x, & & otherwise \\ + \end{array} + \right. + + where :math:`x` is the input tensor, + :math:`a` is randomly sampled from uniform distribution in range (:math:`lower`, :math:`upper`), + + In the test phase, the negative slope will take the average value of :math:`lower` and :math:`upper`: + + .. math:: + + RReLU(x)= + \left\{ + \begin{array}{rcl} + x, & & if \ x >= 0 \\ + (lower + upper) * 0.5 * x, & & otherwise \\ + \end{array} + \right. + + where :math:`x` is the input tensor, + :math:`lower` and :math:`upper` are the bounds of uniform distribution. + + Parameters: + lower (float, optional): The lower bound of uniform distribution. Default: 0.125. + upper (float, optional): The upper bound of uniform distribution. Default: 0.333. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - input: Tensor with any shape. Default dtype is float32. + - output: Tensor with the same shape as input. + + Examples: + .. code-block:: python + :name: RReLU-example + + import paddle + + input_tensor = paddle.to_tensor([[[[-2.0, 3.0, -4.0, 5.0], + [ 3.0, -4.0, 5.0, -6.0], + [-7.0, -8.0, 8.0, 9.0]], + [[ 1.0, -2.0, -3.0, 4.0], + [-5.0, 6.0, 7.0, -8.0], + [ 6.0, 7.0, 8.0, 9.0]]]], dtype='float32') + + rrelu_layer = paddle.nn.RReLU(0.1, 0.3) + output = rrelu_layer(input_tensor) + #[[[[-0.20000899 3. -0.88108218 5. ] + # [ 3. -0.55175185 5. -1.07761011] + # [-1.06806871 -1.98962009 8. 9. ]] + # [[ 1. -0.52382672 -0.65515128 4. ] + # [-1.37663394 6. 7. -2.34657836] + # [ 6. 7. 8. 9. ]]]] + """ + + def __init__(self, lower=1. / 8., upper=1. / 3., name=None): + super(RReLU, self).__init__() + self._lower = lower + self._upper = upper + self._name = name + + def forward(self, x): + return F.rrelu( + x, lower=self._lower, upper=self._upper, training=self.training) + + def extra_repr(self): + name_str = ', name={}'.format(self._name) if self._name else '' + return 'lower={}, upper={}, training={}, dtype={}{}'.format( + self._lower, self._upper, self.training, self._dtype, name_str) + + class ReLU(Layer): """ ReLU Activation. diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index 6067b40f0a7c1..95c5ecf713112 100755 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -399,6 +399,7 @@ 'test_positive_negative_pair_op', 'test_precision_recall_op', 'test_prelu_op', + 'test_rrelu_op', 'test_prelu_mkldnn_op', 'test_print_op', 'test_prior_box_op',