From b4022f35538de5083f87a2efb8f6d36a5db93036 Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Mon, 28 Mar 2022 16:56:00 +0000 Subject: [PATCH 01/16] =?UTF-8?q?rrelu=E9=80=BB=E8=BE=91=E9=83=A8=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/fluid/operators/rrelu_op.cc | 138 ++++++++++++++++++++ paddle/phi/infermeta/unary.cc | 34 +++++ paddle/phi/infermeta/unary.h | 6 + paddle/phi/kernels/cpu/rrelu_grad_kernel.cc | 43 ++++++ paddle/phi/kernels/cpu/rrelu_kernel.cc | 53 ++++++++ paddle/phi/kernels/gpu/rrelu_funcs.h | 87 ++++++++++++ paddle/phi/kernels/gpu/rrelu_grad_kernel.cu | 95 ++++++++++++++ paddle/phi/kernels/gpu/rrelu_kernel.cu | 47 +++++++ paddle/phi/kernels/rrelu_grad_kernel.h | 28 ++++ paddle/phi/kernels/rrelu_kernel.h | 28 ++++ paddle/phi/ops/compat/rrelu_sig.cc | 38 ++++++ python/paddle/nn/__init__.py | 2 + python/paddle/nn/functional/__init__.py | 2 + python/paddle/nn/functional/activation.py | 40 ++++++ python/paddle/nn/layer/activation.py | 18 +++ 15 files changed, 659 insertions(+) create mode 100644 paddle/fluid/operators/rrelu_op.cc create mode 100644 paddle/phi/kernels/cpu/rrelu_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/rrelu_kernel.cc create mode 100644 paddle/phi/kernels/gpu/rrelu_funcs.h create mode 100644 paddle/phi/kernels/gpu/rrelu_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/rrelu_kernel.cu create mode 100644 paddle/phi/kernels/rrelu_grad_kernel.h create mode 100644 paddle/phi/kernels/rrelu_kernel.h create mode 100644 paddle/phi/ops/compat/rrelu_sig.cc diff --git a/paddle/fluid/operators/rrelu_op.cc b/paddle/fluid/operators/rrelu_op.cc new file mode 100644 index 0000000000000..b393c1a02e542 --- /dev/null +++ b/paddle/fluid/operators/rrelu_op.cc @@ -0,0 +1,138 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/infermeta/unary.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class RReluOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } +}; + +class RReluOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The input of RRelu op."); + AddOutput("Out", "The output of RRelu op."); + AddOutput("Noise", "The random sampled RRelu noise.") + .AsIntermediate() + .AsExtra(); + + float default_lower = 1. 
/ 8.;
+    AddAttr<float>("lower", "Lower bound of the uniform distribution.")
+        .SetDefault(default_lower)
+        .AddCustomChecker([](const float& lower) {
+          PADDLE_ENFORCE_EQ(lower >= 0.0f && lower < 1.0f, true,
+                            platform::errors::InvalidArgument(
+                                "'RRelu_lower' must be between 0.0 and 1.0."));
+        });
+    float default_upper = 1. / 3.;
+    AddAttr<float>("upper", "Upper bound of the uniform distribution.")
+        .SetDefault(default_upper)
+        .AddCustomChecker([](const float& upper) {
+          PADDLE_ENFORCE_EQ(upper > 0.0f && upper <= 1.0f, true,
+                            platform::errors::InvalidArgument(
+                                "'RRelu_upper' must be between 0.0 and 1.0."));
+        });
+
+    AddComment(R"DOC(
+RRelu Operator.
+
+Applies the randomized leaky rectified linear unit function, element-wise,
+as described in the paper:
+
+`Empirical Evaluation of Rectified Activations in Convolutional Network`_.
+
+The function is defined as:
+
+.. math::
+    \text{RReLU}(x) =
+        \begin{cases}
+            x & \text{if } x \geq 0 \\
+            ax & \text{ otherwise }
+        \end{cases}
+
+where :math:`a` is randomly sampled from uniform distribution
+:math:`\mathcal{U}(\text{lower}, \text{upper})`.
+
+ See: https://arxiv.org/pdf/1505.00853.pdf
+
+)DOC");
+  }
+};
+
+class RReluOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("Noise"), "Input", "Noise", "rrelu_grad");
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   framework::GradVarName("Out"), "rrelu_grad");
+
+    auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
+    ctx->SetOutputDim(framework::GradVarName("X"), out_dims);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.GetPlace());
+  }
+};
+
+template <typename T>
+class RReluGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op) const override {
+    op->SetType("rrelu_grad");
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetInput("Noise", this->Output("Noise"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetAttrMap(this->Attrs());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+DECLARE_INFER_SHAPE_FUNCTOR(rrelu, RReluInferShapeFunctor,
+                            PD_INFER_META(phi::RReluInferMeta));
+
+REGISTER_OPERATOR(rrelu, ops::RReluOp, ops::RReluOpMaker,
+                  ops::RReluGradOpMaker<paddle::framework::OpDesc>,
+                  ops::RReluGradOpMaker<paddle::imperative::OpBase>,
+                  RReluInferShapeFunctor);
+REGISTER_OPERATOR(rrelu_grad, ops::RReluOpGrad);
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 80503dd243092..efc923bfc81c4 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -2031,6 +2031,40 @@ void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) {
   out->set_dtype(DataType::INT64);
 }
 
+void RReluInferMeta(const MetaTensor& x,
+                    float lower,
+                    float upper,
+                    MetaTensor* out,
+                    MetaTensor* noise) {
+  auto x_dims = x.dims();
+  PADDLE_ENFORCE_GE(
+      lower,
+      0,
+      phi::errors::InvalidArgument(
+          "The lower value should be greater than or equal to 0. "
+          "But received lower value = %f.",
+          lower));
+  PADDLE_ENFORCE_LE(
+      upper,
+      1,
+      phi::errors::InvalidArgument(
+          "The upper value should be less than or equal to 1. 
" + "But received upper value = %f.", + upper)); + PADDLE_ENFORCE_GT( + upper, + lower, + phi::errors::InvalidArgument( + "The upper value should be greater than lower value " + "But received upper value = %f, lower value = %f.", + upper, + lower)); + out->set_dims(x_dims); + out->set_dtype(x.dtype()); + noise->set_dims(x_dims); + noise->set_dtype(x.dtype()); +} + } // namespace phi PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 0322a18fc3153..6af02d8c48249 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -300,4 +300,10 @@ void OneHotInferMeta(const MetaTensor& x, const Scalar& depth, MetaTensor* out); void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out); +void RReluInferMeta(const MetaTensor& x, + float lower, + float upper, + MetaTensor* out, + MetaTensor* noise); + } // namespace phi diff --git a/paddle/phi/kernels/cpu/rrelu_grad_kernel.cc b/paddle/phi/kernels/cpu/rrelu_grad_kernel.cc new file mode 100644 index 0000000000000..c6c14510293fb --- /dev/null +++ b/paddle/phi/kernels/cpu/rrelu_grad_kernel.cc @@ -0,0 +1,43 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/rrelu_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void RReluGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& noise, + const DenseTensor& out_grad, + DenseTensor* x_grad) { + const T* n_ptr = noise.data(); + const T* x_ptr = x.data(); + const T* out_grad_ptr = out_grad.data(); + int numel = x.numel(); + int i = 0; + T* x_grad_ptr = dev_ctx.template Alloc(x_grad); + for (i = 0; i < numel; i++) { + x_grad_ptr[i] = x_ptr[i] > 0 ? out_grad_ptr[i] : n_ptr[i] * out_grad_ptr[i]; + } + +} + +} // namespace phi + +PD_REGISTER_KERNEL( + rrelu_grad, CPU, ALL_LAYOUT, phi::RReluGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/rrelu_kernel.cc b/paddle/phi/kernels/cpu/rrelu_kernel.cc new file mode 100644 index 0000000000000..2311d7b0112ca --- /dev/null +++ b/paddle/phi/kernels/cpu/rrelu_kernel.cc @@ -0,0 +1,53 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/rrelu_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void RReluKernel(const Context& dev_ctx, + const DenseTensor& x, + const float lower, + const float upper, + DenseTensor* out, + DenseTensor* noise) { + const T* x_ptr = x.data(); + T* o_ptr = dev_ctx.template Alloc(out); + T* n_ptr = dev_ctx.template Alloc(noise); + + std::uniform_real_distribution dist(lower, upper); + auto gen_ptr = dev_ctx.GetGenerator(); + auto engine = gen_ptr->GetCPUEngine(); + + int numel = x.numel(); + int i = 0; + for (i = 0; i < numel; i++) { + if (x_ptr[i] < 0) { + T scale = static_cast(dist(*engine)); + o_ptr[i] = scale * x_ptr[i]; + n_ptr[i] = scale; + } else { + o_ptr[i] = x_ptr[i]; + n_ptr[i] = 1.0; + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(rrelu, CPU, ALL_LAYOUT, phi::RReluKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/rrelu_funcs.h b/paddle/phi/kernels/gpu/rrelu_funcs.h new file mode 100644 index 0000000000000..2aab88e58b48c --- /dev/null +++ b/paddle/phi/kernels/gpu/rrelu_funcs.h @@ -0,0 +1,87 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/phi/kernels/funcs/math_function.h" +#include +#include + +namespace phi { + +#define CUDA_NUM_THREADS 1024 + +inline static int PADDLE_GET_BLOCKS(const int N) { + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} + +template +__global__ void RReluElementWiseKernel(const T *input, + T *output, + T *noise, + const float& lower, + const float& upper, + size_t numel) { + CUDA_KERNEL_LOOP(index, numel) { + T x = input[index]; + T zero = static_cast(0); + + if (x < zero) { + thrust::minstd_rand rng; + rng.seed(0); + thrust::uniform_real_distribution dist(lower, upper); + rng.discard(index); + T scale = dist(rng); + output[index] = scale * x; + noise[index] = scale; + } else { + output[index] = x; + noise[index] = 1.0; + } + } +} + + +template +class RReluElementWiseDirectCUDAFunctor { + public: + void operator()(gpuStream_t stream, + const T *input, + T *output, + T *noise, + const float& lower, + const float& upper, + size_t numel); +}; + +template +void RReluElementWiseDirectCUDAFunctor::operator()(gpuStream_t stream, + const T *input, + T *output, + T *noise, + const float& lower, + const float& upper, + size_t numel) { + RReluElementWiseKernel<<>>( + input, output, noise, lower, upper, numel); +} + +template class RReluElementWiseDirectCUDAFunctor; +template class RReluElementWiseDirectCUDAFunctor; +} // namespace phi diff --git a/paddle/phi/kernels/gpu/rrelu_grad_kernel.cu b/paddle/phi/kernels/gpu/rrelu_grad_kernel.cu new file mode 100644 index 0000000000000..8c1a46f152ca5 --- /dev/null +++ b/paddle/phi/kernels/gpu/rrelu_grad_kernel.cu @@ -0,0 +1,95 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/rrelu_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_meta.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/reduce_function.h" +#include "paddle/phi/kernels/gpu/prelu_funcs.h" +#include "paddle/phi/kernels/primitive/functor_primitives.h" + +namespace phi { + +template +__global__ void PReluOpGradKernel(const T* x_ptr, + const T* noise_ptr, + const T* out_grad_ptr, + T* x_grad_ptr, + int numel) { + CUDA_KERNEL_LOOP(index, numel) { + T scale = noise_ptr[index]; + T x = x_ptr[index]; + T out_grad = out_grad_ptr[index]; + T zero = static_cast(0); + x_grad_ptr[index] = (x < zero) ? scale * out_grad : out_grad; + } +} + +template +class RReluOpGradFunctor { + public: + void operator()(gpuStream_t stream, + const T* x, + const T* noise, + const T* out_grad, + T* x_grad, + int numel) { + PReluOpGradKernel< + T><<>>( + x, + noise, + out_grad, + x_grad, + numel); + } +}; + +template +void RReluGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& noise, + const DenseTensor& out_grad, + DenseTensor* x_grad) { + if (!x_grad) return; + dev_ctx.template Alloc(x_grad); + + const T* x_ptr = x.data(); + const T* n_ptr = noise.data(); + const T* out_grad_ptr = out_grad.data(); + if (!x_grad) return; + T* x_grad_ptr = dev_ctx.template Alloc(x_grad); + + int numel = x.numel(); + auto stream = dev_ctx.stream(); + + RReluOpGradFunctor rrelu_grad; + rrelu_grad(stream, + x_ptr, + n_ptr, + out_grad_ptr, + x_grad_ptr, + numel); +} + +} // namespace phi + +PD_REGISTER_KERNEL(rrelu_grad, + GPU, + ALL_LAYOUT, + phi::RReluGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/rrelu_kernel.cu b/paddle/phi/kernels/gpu/rrelu_kernel.cu new file mode 100644 index 0000000000000..8536c1d4e4a6a --- /dev/null +++ b/paddle/phi/kernels/gpu/rrelu_kernel.cu @@ -0,0 +1,47 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/rrelu_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/gpu/rrelu_funcs.h" + +namespace phi { + +template +void RReluKernel(const Context& dev_ctx, + const DenseTensor& x, + const float lower, + const float upper, + DenseTensor* out, + DenseTensor* noise) { + const T* x_ptr = x.data(); + T* o_ptr = dev_ctx.template Alloc(out); + T* n_ptr = dev_ctx.template Alloc(noise); + + int numel = x.numel(); + auto dim = x.dims(); + RReluElementWiseDirectCUDAFunctor rrelu_element_wise; + rrelu_element_wise(dev_ctx.stream(), x_ptr, o_ptr, n_ptr, lower, upper, numel); +} + +} // namespace phi + +PD_REGISTER_KERNEL(rrelu, + GPU, + ALL_LAYOUT, + phi::RReluKernel, + float, + double) {} diff --git a/paddle/phi/kernels/rrelu_grad_kernel.h b/paddle/phi/kernels/rrelu_grad_kernel.h new file mode 100644 index 0000000000000..dbb8da874e27e --- /dev/null +++ b/paddle/phi/kernels/rrelu_grad_kernel.h @@ -0,0 +1,28 @@ + +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void RReluGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& noise, + const DenseTensor& out_grad, + DenseTensor* x_grad); +} // namespace phi \ No newline at end of file diff --git a/paddle/phi/kernels/rrelu_kernel.h b/paddle/phi/kernels/rrelu_kernel.h new file mode 100644 index 0000000000000..92a61ed15b6f7 --- /dev/null +++ b/paddle/phi/kernels/rrelu_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void RReluKernel(const Context& dev_ctx, + const DenseTensor& x, + const float lower, + const float upper, + DenseTensor* out, + DenseTensor* noise); +} // namespace phi diff --git a/paddle/phi/ops/compat/rrelu_sig.cc b/paddle/phi/ops/compat/rrelu_sig.cc new file mode 100644 index 0000000000000..63043e499d78d --- /dev/null +++ b/paddle/phi/ops/compat/rrelu_sig.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature RReluOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "rrelu", {"X"}, {"lower", "upper"}, {"Out", "Noise"}); +} + +KernelSignature RReluGradGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("rrelu_grad", + {GradVarName("Out"), "Noise"}, + {}, + {GradVarName("X")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(rrelu, + phi::RReluOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(rrelu_grad, + phi::RReluGradGradOpArgumentMapping); \ No newline at end of file diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index c0820e140268b..76e966e7cb6f9 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -50,6 +50,7 @@ from .layer.activation import ThresholdedReLU # noqa: F401 from .layer.activation import LogSoftmax # noqa: F401 from .layer.activation import Maxout # noqa: F401 +from .layer.activation import RReLU # noqa: F401 from .layer.common import Pad1D # noqa: F401 from .layer.common import Pad2D # noqa: F401 from .layer.common import ZeroPad2D # noqa: F401 @@ -306,4 +307,5 @@ def weight_norm(*args): 'MaxUnPool2D', 'MaxUnPool3D', 'HingeEmbeddingLoss', + 'RRelu', ] diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index a24afc45a5995..1ba9b99f05254 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -47,6 +47,7 @@ from .activation import log_softmax # noqa: F401 from .activation import glu # noqa: F401 from .activation import gumbel_softmax # noqa: F401 +from .activation import rrelu # noqa: F401 from .common import dropout # noqa: F401 from .common import dropout2d # noqa: F401 from .common import dropout3d # noqa: F401 @@ -224,4 +225,5 @@ 'class_center_sample', 'sparse_attention', 'fold', + 'rrelu', ] diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 11d2ad6fa8826..50e8477cc47f4 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -532,6 +532,46 @@ def prelu(x, weight, data_format="NCHW", name=None): "data_format": data_format}) return out +def rrelu(x, lower, upper, training=False, name=None): + + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'rrelu') + + if not isinstance(lower, float) or not isinstance(upper, float): + raise TypeError( + "The lower and upper values must be float type. Received: lower {}, upper {}.".format( + lower, upper)) + + if lower < 0 or lower > 1: + raise ValueError( + "The lower value must be no less than zero or greater than one. Received: {}.".format( + lower)) + + if upper < lower: + raise ValueError( + "The upper value must be greater than lower value. Received: lower {}, upper {}.".format( + lower, upper)) + + if upper > 1: + raise ValueError( + "The upper value must be no greater than one. 
Received: {}.".format( + upper)) + + if training: + negative_slope = (lower + upper) / 2.0 + return leaky_relu(x, negative_slope, name) + + if in_dynamic_mode(): + return _C_ops.rrelu(x, 'lower', lower, 'upper', upper) + + helper = LayerHelper('rrelu', **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op( + type="rrelu", + inputs={"X": x}, + outputs={"Out": out}, + attrs={"lower": lower, + "upper": upper}) + return out def relu(x, name=None): """ diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index 400585c431830..590d38ea34075 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -435,6 +435,24 @@ def extra_repr(self): self._num_parameters, self._data_format, self._init, self._dtype, name_str) +class RReLU(Layer): + def __init__(self, + lower=1./8., + upper=1./3., + name=None): + super(PReLU, self).__init__() + self._lower = lower + self._upper = upper + self._name = name + + def forward(self, x): + return F.rrelu(x, lower=self._lower, upper=self._upper, training=self.training) + + def extra_repr(self): + name_str = ', name={}'.format(self._name) if self._name else '' + return 'lower={}, upper={}, training={}, dtype={}{}'.format( + self._lower, self._upper, self.training, self._dtype, + name_str) class ReLU(Layer): """ From b03c8d124c5383355a9a3043ebbf3acef0731749 Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Thu, 14 Apr 2022 13:03:01 +0000 Subject: [PATCH 02/16] unregistered op kernel (unresolved) --- paddle/fluid/operators/rrelu_op.cc | 6 +- paddle/phi/infermeta/unary.cc | 41 ++- paddle/phi/infermeta/unary.h | 8 +- paddle/phi/kernels/cpu/rrelu_kernel.cc | 15 +- paddle/phi/kernels/gpu/rrelu_funcs.h | 39 ++- paddle/phi/kernels/gpu/rrelu_grad_kernel.cu | 25 +- paddle/phi/kernels/gpu/rrelu_kernel.cu | 10 +- paddle/phi/ops/compat/rrelu_sig.cc | 18 +- .../fluid/tests/unittests/test_rrelu_op.py | 262 ++++++++++++++++++ python/paddle/nn/functional/activation.py | 71 ++++- python/paddle/nn/layer/__init__.py | 1 + python/paddle/nn/layer/activation.py | 66 ++++- tools/static_mode_white_list.py | 1 + 13 files changed, 451 insertions(+), 112 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_rrelu_op.py diff --git a/paddle/fluid/operators/rrelu_op.cc b/paddle/fluid/operators/rrelu_op.cc index b393c1a02e542..a87ae1b49101b 100644 --- a/paddle/fluid/operators/rrelu_op.cc +++ b/paddle/fluid/operators/rrelu_op.cc @@ -87,7 +87,7 @@ where :math:`a` is randomly sampled from uniform distribution } }; -class RReluOpGrad : public framework::OperatorWithKernel { +class RReluGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -120,7 +120,7 @@ class RReluGradOpMaker : public framework::SingleGradOpMaker { op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetInput("Noise", this->Output("Noise")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - op->SetAttrMap(this->Attrs()); +// op->SetAttrMap(this->Attrs()); } }; @@ -135,4 +135,4 @@ REGISTER_OPERATOR(rrelu, ops::RReluOp, ops::RReluOpMaker, ops::RReluGradOpMaker, ops::RReluGradOpMaker, RReluInferShapeFunctor); -REGISTER_OPERATOR(rrelu_grad, ops::RReluOpGrad); +REGISTER_OPERATOR(rrelu_grad, ops::RReluGradOp); diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index efc923bfc81c4..effd2a3836f16 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ 
-2037,28 +2037,25 @@ void RReluInferMeta(const MetaTensor& x, MetaTensor* out, MetaTensor* noise) { auto x_dims = x.dims(); - PADDLE_ENFORCE_GE( - lower, - 0, - phi::errors::InvalidArgument( - "The lower value should be greater than or equal to 0. " - "But received lower value = %f.", - lower)); - PADDLE_ENFORCE_LE( - upper, - 1, - phi::errors::InvalidArgument( - "The upper value should be less than or equal to 1. " - "But received upper value = %f.", - upper)); - PADDLE_ENFORCE_GT( - upper, - lower, - phi::errors::InvalidArgument( - "The upper value should be greater than lower value " - "But received upper value = %f, lower value = %f.", - upper, - lower)); + PADDLE_ENFORCE_GE(lower, + 0, + phi::errors::InvalidArgument( + "The lower value should be greater than or equal to 0. " + "But received lower value = %f.", + lower)); + PADDLE_ENFORCE_LE(upper, + 1, + phi::errors::InvalidArgument( + "The upper value should be less than or equal to 1. " + "But received upper value = %f.", + upper)); + PADDLE_ENFORCE_GE(upper, + lower, + phi::errors::InvalidArgument( + "The upper value should be greater than or equal to lower value " + "But received upper value = %f, lower value = %f.", + upper, + lower)); out->set_dims(x_dims); out->set_dtype(x.dtype()); noise->set_dims(x_dims); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 6af02d8c48249..58e26d27ee081 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -301,9 +301,9 @@ void OneHotInferMeta(const MetaTensor& x, const Scalar& depth, MetaTensor* out); void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out); void RReluInferMeta(const MetaTensor& x, - float lower, - float upper, - MetaTensor* out, - MetaTensor* noise); + float lower, + float upper, + MetaTensor* out, + MetaTensor* noise); } // namespace phi diff --git a/paddle/phi/kernels/cpu/rrelu_kernel.cc b/paddle/phi/kernels/cpu/rrelu_kernel.cc index 2311d7b0112ca..6aade18337d24 100644 --- a/paddle/phi/kernels/cpu/rrelu_kernel.cc +++ b/paddle/phi/kernels/cpu/rrelu_kernel.cc @@ -29,21 +29,22 @@ void RReluKernel(const Context& dev_ctx, const T* x_ptr = x.data(); T* o_ptr = dev_ctx.template Alloc(out); T* n_ptr = dev_ctx.template Alloc(noise); + T zero = static_cast(0); - std::uniform_real_distribution dist(lower, upper); + std::uniform_real_distribution dist(lower, upper); auto gen_ptr = dev_ctx.GetGenerator(); auto engine = gen_ptr->GetCPUEngine(); int numel = x.numel(); int i = 0; for (i = 0; i < numel; i++) { - if (x_ptr[i] < 0) { - T scale = static_cast(dist(*engine)); - o_ptr[i] = scale * x_ptr[i]; - n_ptr[i] = scale; + if (x_ptr[i] < zero) { + T scale = static_cast(dist(*engine)); + o_ptr[i] = scale * x_ptr[i]; + n_ptr[i] = scale; } else { - o_ptr[i] = x_ptr[i]; - n_ptr[i] = 1.0; + o_ptr[i] = x_ptr[i]; + n_ptr[i] = 1.0; } } } diff --git a/paddle/phi/kernels/gpu/rrelu_funcs.h b/paddle/phi/kernels/gpu/rrelu_funcs.h index 2aab88e58b48c..40cf9cf1c71b4 100644 --- a/paddle/phi/kernels/gpu/rrelu_funcs.h +++ b/paddle/phi/kernels/gpu/rrelu_funcs.h @@ -14,11 +14,11 @@ limitations under the License. 
*/ #pragma once +#include +#include #include #include #include "paddle/phi/kernels/funcs/math_function.h" -#include -#include namespace phi { @@ -32,29 +32,28 @@ template __global__ void RReluElementWiseKernel(const T *input, T *output, T *noise, - const float& lower, - const float& upper, + const float &lower, + const float &upper, size_t numel) { CUDA_KERNEL_LOOP(index, numel) { T x = input[index]; T zero = static_cast(0); if (x < zero) { - thrust::minstd_rand rng; - rng.seed(0); - thrust::uniform_real_distribution dist(lower, upper); - rng.discard(index); - T scale = dist(rng); - output[index] = scale * x; - noise[index] = scale; + thrust::minstd_rand rng; + rng.seed(0); + thrust::uniform_real_distribution dist(lower, upper); + rng.discard(index); + T scale = static_cast(dist(rng)); + output[index] = scale * x; + noise[index] = scale; } else { - output[index] = x; - noise[index] = 1.0; + output[index] = x; + noise[index] = 1.0; } } } - template class RReluElementWiseDirectCUDAFunctor { public: @@ -62,8 +61,8 @@ class RReluElementWiseDirectCUDAFunctor { const T *input, T *output, T *noise, - const float& lower, - const float& upper, + const float &lower, + const float &upper, size_t numel); }; @@ -72,16 +71,16 @@ void RReluElementWiseDirectCUDAFunctor::operator()(gpuStream_t stream, const T *input, T *output, T *noise, - const float& lower, - const float& upper, + const float &lower, + const float &upper, size_t numel) { RReluElementWiseKernel<<>>( - input, output, noise, lower, upper, numel); + stream>>>(input, output, noise, lower, upper, numel); } template class RReluElementWiseDirectCUDAFunctor; template class RReluElementWiseDirectCUDAFunctor; +template class RReluElementWiseDirectCUDAFunctor; } // namespace phi diff --git a/paddle/phi/kernels/gpu/rrelu_grad_kernel.cu b/paddle/phi/kernels/gpu/rrelu_grad_kernel.cu index 8c1a46f152ca5..3d60d630bb2b3 100644 --- a/paddle/phi/kernels/gpu/rrelu_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/rrelu_grad_kernel.cu @@ -25,7 +25,7 @@ namespace phi { template -__global__ void PReluOpGradKernel(const T* x_ptr, +__global__ void RReluOpGradKernel(const T* x_ptr, const T* noise_ptr, const T* out_grad_ptr, T* x_grad_ptr, @@ -48,13 +48,9 @@ class RReluOpGradFunctor { const T* out_grad, T* x_grad, int numel) { - PReluOpGradKernel< + RReluOpGradKernel< T><<>>( - x, - noise, - out_grad, - x_grad, - numel); + x, noise, out_grad, x_grad, numel); } }; @@ -77,19 +73,10 @@ void RReluGradKernel(const Context& dev_ctx, auto stream = dev_ctx.stream(); RReluOpGradFunctor rrelu_grad; - rrelu_grad(stream, - x_ptr, - n_ptr, - out_grad_ptr, - x_grad_ptr, - numel); + rrelu_grad(stream, x_ptr, n_ptr, out_grad_ptr, x_grad_ptr, numel); } } // namespace phi -PD_REGISTER_KERNEL(rrelu_grad, - GPU, - ALL_LAYOUT, - phi::RReluGradKernel, - float, - double) {} +PD_REGISTER_KERNEL( + rrelu_grad, GPU, ALL_LAYOUT, phi::RReluGradKernel, float, phi::dtype::float16, double) {} diff --git a/paddle/phi/kernels/gpu/rrelu_kernel.cu b/paddle/phi/kernels/gpu/rrelu_kernel.cu index 8536c1d4e4a6a..de3c678e579c4 100644 --- a/paddle/phi/kernels/gpu/rrelu_kernel.cu +++ b/paddle/phi/kernels/gpu/rrelu_kernel.cu @@ -34,14 +34,10 @@ void RReluKernel(const Context& dev_ctx, int numel = x.numel(); auto dim = x.dims(); RReluElementWiseDirectCUDAFunctor rrelu_element_wise; - rrelu_element_wise(dev_ctx.stream(), x_ptr, o_ptr, n_ptr, lower, upper, numel); + rrelu_element_wise( + dev_ctx.stream(), x_ptr, o_ptr, n_ptr, lower, upper, numel); } } // namespace phi -PD_REGISTER_KERNEL(rrelu, - GPU, - ALL_LAYOUT, 
- phi::RReluKernel, - float, - double) {} +PD_REGISTER_KERNEL(rrelu, GPU, ALL_LAYOUT, phi::RReluKernel, float, phi::dtype::float16, double) {} diff --git a/paddle/phi/ops/compat/rrelu_sig.cc b/paddle/phi/ops/compat/rrelu_sig.cc index 63043e499d78d..5f9232412b11d 100644 --- a/paddle/phi/ops/compat/rrelu_sig.cc +++ b/paddle/phi/ops/compat/rrelu_sig.cc @@ -16,23 +16,17 @@ namespace phi { -KernelSignature RReluOpArgumentMapping( - const ArgumentMappingContext& ctx) { - return KernelSignature( - "rrelu", {"X"}, {"lower", "upper"}, {"Out", "Noise"}); +KernelSignature RReluOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("rrelu", {"X"}, {"lower", "upper"}, {"Out", "Noise"}); } KernelSignature RReluGradGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("rrelu_grad", - {GradVarName("Out"), "Noise"}, - {}, - {GradVarName("X")}); + return KernelSignature( + "rrelu_grad", {GradVarName("Out"), "Noise"}, {}, {GradVarName("X")}); } } // namespace phi -PD_REGISTER_ARG_MAPPING_FN(rrelu, - phi::RReluOpArgumentMapping); -PD_REGISTER_ARG_MAPPING_FN(rrelu_grad, - phi::RReluGradGradOpArgumentMapping); \ No newline at end of file +PD_REGISTER_ARG_MAPPING_FN(rrelu, phi::RReluOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(rrelu_grad, phi::RReluGradGradOpArgumentMapping); \ No newline at end of file diff --git a/python/paddle/fluid/tests/unittests/test_rrelu_op.py b/python/paddle/fluid/tests/unittests/test_rrelu_op.py new file mode 100644 index 0000000000000..bb24ce45e45f5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_rrelu_op.py @@ -0,0 +1,262 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import numpy as np +import paddle.fluid as fluid +import six +import paddle.fluid.core as core +from paddle.fluid import Program, program_guard +from op_test import OpTest +import paddle +import paddle.nn.functional as F + +import time +def debug_log(msg,is_clear=False): + fp = open('/tmp/data.txt', 'w' if is_clear else "a") + fp.write(str(time.time()) + " => " + msg + "\n") + fp.close() + +debug_log("=======> 111", True) + +xx= paddle.rand((2, 3)) +rrelu1 = paddle.nn.RReLU() +print(rrelu1(xx)) +print(F.rrelu(xx, 0.1, 0.4, training = True)) + +def ref_rrelu(x, lower, upper): + x_t = x.copy() + alpha = (lower + upper) / 2.0 + return np.where(x_t <= 0, alpha * x_t, x_t) + +def ref_rrelu_nn(x, lower, upper): + return ref_rrelu(x, lower, upper) + +def check_output(input, output, lower, upper): + lower_res = np.where(input <= 0, lower * input, input) + upper_res = np.where(input <= 0, upper * input, input) + return (output >= lower_res).all() and (output <= upper_res).all() + +class TestFunctionalRReluAPI(unittest.TestCase): + def setUp(self): + # self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda( + # ) else paddle.CPUPlace() + self.place = paddle.CPUPlace() + self.x_np = np.random.uniform(-1., 1., [1, 2, 3, 4]).astype('float32') + self.lower_0 = 0.05 + self.lower_1 = 0.1 + self.upper_0 = 0.25 + self.upper_1 = 0.33 + debug_log("=======> 222") + + def static_check(self, lower, upper): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data('X', self.x_np.shape, 'float32') + out = F.rrelu(x, lower, upper) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, + fetch_list=[out]) + out_ref = ref_rrelu(self.x_np, lower, upper) + self.assertEqual(np.allclose(out_ref, res[0]), True) + + def dygraph_check(self, lower, upper): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x_np) + out = F.rrelu(x, lower, upper) + out_ref = ref_rrelu(self.x_np, lower, upper) + self.assertEqual(np.allclose(out_ref, out.numpy()), True) + paddle.enable_static() + + def test_static_api(self): + self.static_check(self.lower_0, self.upper_0) + self.static_check(self.lower_1, self.upper_1) + + # def test_dygraph_api(self): + # self.dygraph_check(self.lower_0, self.upper_0) + # self.dygraph_check(self.lower_1, self.upper_1) + + def test_error_functional(self): + with paddle.static.program_guard(paddle.static.Program()): + # The input type must be Variable. + self.assertRaises(TypeError, F.rrelu, x=1, lower=self.lower_0, upper=self.upper_0) + # The input dtype must be float16, float32, float64. + x_int32 = paddle.fluid.data( + name='x_int32', shape=[2, 3], dtype='int32') + self.assertRaises(TypeError, F.rrelu, x=x_int32, lower=self.lower_0, upper=self.upper_0) + x_bool = paddle.fluid.data( + name='x_bool', shape=[2, 3], dtype='int32') + self.assertRaises(TypeError, F.rrelu, x=x_bool, lower=self.lower_0, upper=self.upper_0) + # lower and upper must be float + x_fp32 = paddle.fluid.data( + name='x_fp32', shape=[2, 3], dtype='float32') + self.assertRaises(TypeError, F.rrelu, x=x_fp32, lower=0, upper=0.5) + self.assertRaises(TypeError, F.rrelu, x=x_fp32, lower=0.5, upper=1) + # lower and upper must be in (0, 1) + self.assertRaises(ValueError, F.rrelu, x=x_fp32, lower=-1., upper=0.5) + self.assertRaises(ValueError, F.rrelu, x=x_fp32, lower=0.5, upper=2.) 
+ # upper should not be less than lower + self.assertRaises(ValueError, F.rrelu, x=x_fp32, lower=0.5, upper=0.2) + # support the input dtype is float16 + x_fp16 = paddle.fluid.data( + name='x_fp16', shape=[2, 3], dtype='float16') + F.rrelu(x=x_fp16, lower=self.lower_0, upper=self.upper_0) + + def test_error_layer(self): + def error_variable(): + # The input type must be Variable. + with paddle.fluid.dygraph.guard(): + x = 6 + rrelu = paddle.nn.RReLU(self.lower_0, self.upper_0) + rrelu(paddle.to_tensor(x)) + + def error_int_dtype(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 3]).astype("int32") + rrelu = paddle.nn.RReLU(self.lower_0, self.upper_0) + rrelu(paddle.to_tensor(x)) + + def error_lower_dtype(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 3]).astype("float32") + rrelu = paddle.nn.RReLU(0, 0.5) + rrelu(paddle.to_tensor(x)) + + def error_upper_dtype(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 3]).astype("float32") + rrelu = paddle.nn.RReLU(0.5, 1) + rrelu(paddle.to_tensor(x)) + + def error_lower_range(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 3]).astype("float32") + rrelu = paddle.nn.RReLU(-1.0, 0.5) + rrelu(paddle.to_tensor(x)) + + def error_upper_range(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 3]).astype("float32") + rrelu = paddle.nn.RReLU(0.5, 2.0) + rrelu(paddle.to_tensor(x)) + + def error_lower_upper(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 3]).astype("float32") + rrelu = paddle.nn.RReLU(0.5, 0.2) + rrelu(paddle.to_tensor(x)) + + self.assertRaises(TypeError, error_variable) + # self.assertRaises(TypeError, error_int_dtype) + # self.assertRaises(TypeError, error_lower_dtype) + # self.assertRaises(TypeError, error_upper_dtype) + # self.assertRaises(ValueError, error_lower_range) + # self.assertRaises(ValueError, error_upper_range) + # self.assertRaises(ValueError, error_lower_upper) + + +# class TestRReluAPI(unittest.TestCase): +# def setUp(self): +# self.shape = [2, 3] +# self.x_1_np = np.random.random(self.shape).astype("float64") +# self.lower = 0.1 +# self.upper = 0.25 +# debug_log("=======> 333") +# +# def test_static_graph_functional(self): +# for use_cuda in ([False, True] +# if core.is_compiled_with_cuda() else [False]): +# place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() +# paddle.enable_static() +# x_1 = paddle.fluid.data( +# name="X", shape=self.shape, dtype="float64") +# out_1 = F.rrelu(x_1, self.lower, self.upper) +# exe = paddle.static.Executor(place=place) +# res_1 = exe.run(fluid.default_main_program(), +# feed={"X": self.x_1_np}, +# fetch_list=out_1, +# use_prune=True) +# self.assertTrue(check_output(self.x_1_np, res_1, self.lower, self.upper)) +# +# # same test between layer and functional in this op. 
+# def test_static_graph_layer(self): +# for use_cuda in ([False, True] +# if core.is_compiled_with_cuda() else [False]): +# place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() +# +# paddle.enable_static() +# x_1 = paddle.fluid.data( +# name="X", shape=self.shape, dtype="float64") +# +# # init instance +# ps_1 = paddle.nn.RReLU(self.lower, self.upper) +# out_1 = ps_1(x_1) +# exe = paddle.static.Executor(place=place) +# res_1 = exe.run(fluid.default_main_program(), +# feed={"X": self.x_1_np}, +# fetch_list=out_1, +# use_prune=True) +# self.assertTrue(check_output(self.x_1_np, res_1, self.lower, self.upper)) +# +# def test_dygraph(self): +# for use_cuda in ([False, True] +# if core.is_compiled_with_cuda() else [False]): +# place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() +# +# paddle.disable_static(place=place) +# +# rrelu = paddle.nn.RReLU(self.lower, self.upper) +# result = rrelu(paddle.to_tensor(self.x_1_np)) +# self.assertTrue(check_output(self.x_1_np, result.numpy(), self.lower, self.upper)) +# result_functional = F.rrelu( +# paddle.to_tensor(self.x_1_np), self.lower, self.upper) +# self.assertTrue(check_output(self.x_1_np, result_functional.numpy(), self.lower, self.upper)) + + +# class RReluTest(OpTest): +# def setUp(self): +# self.init_alpha() +# self.init_dtype() +# self.init_input_shape() +# self.init_attr() +# self.op_type = "rrelu" +# +# x_np = np.random.uniform(-1, 1, self.x_shape).astype(self.dtype) +# x_np[np.abs(x_np) < 0.005] = 0.02 +# out_np = ref_rrelu(x_np, self.alpha, self.alpha) +# self.inputs = {'X': x_np} +# self.outputs = {'Out': out_np} +# debug_log("=======> 444") +# +# def init_alpha(self): +# self.alpha = 0.5 +# +# def init_dtype(self): +# self.dtype = np.float64 +# +# def init_input_shape(self): +# self.x_shape = [2, 3] +# +# def init_attr(self): +# self.attrs = {'lower': self.alpha, "upper": self.alpha} +# +# def test_check_output(self): +# self.check_output() +# +# def test_check_grad(self): +# self.check_grad(['X'], 'Out') + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 50e8477cc47f4..cbc533f09fc4a 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -532,31 +532,80 @@ def prelu(x, weight, data_format="NCHW", name=None): "data_format": data_format}) return out -def rrelu(x, lower, upper, training=False, name=None): - check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'rrelu') +def rrelu(x, lower=1. / 8., upper=1. / 3., training=False, name=None): + """ + rrelu activation. + + .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`: + https://arxiv.org/abs/1505.00853 + + .. math:: + \text{RReLU}(x) = + \begin{cases} + x & \text{if } x \geq 0 \\ + ax & \text{ otherwise } + \end{cases} + + where :math:`a` is randomly sampled from uniform distribution + :math:`\mathcal{U}(\text{lower}, \text{upper})`. + + Parameters: + x (Tensor): The input Tensor with data type float 16 float32, float64. + lower (float, optional): The lower bound of uniform distribution. Default: :math:`\frac{1}{8}`. + upper (float, optional): The upper bound of uniform distribution. Default: :math:`\frac{1}{3}`. + training (bool, optional): Current is training mode or others. Default is False. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. 
+ + Returns: + A Tensor with the same data type and shape as ``x`` . + + Examples: + .. code-block:: python + + import paddle + import paddle.nn.functional as F + import numpy as np + + data = np.array([[[[-2.0, 3.0, -4.0, 5.0], + [ 3.0, -4.0, 5.0, -6.0], + [-7.0, -8.0, 8.0, 9.0]], + [[ 1.0, -2.0, -3.0, 4.0], + [-5.0, 6.0, 7.0, -8.0], + [ 6.0, 7.0, 8.0, 9.0]]]], 'float32') + x = paddle.to_tensor(data) + out = F.rrelu(x, 0.1, 0.3) + # [[[[-0.5 , 3. , -1. , 5. ], + # [ 3. , -1. , 5. , -1.5 ], + # [-1.75, -2. , 8. , 9. ]], + # [[ 1. , -0.5 , -0.75, 4. ], + # [-1.25, 6. , 7. , -2. ], + # [ 6. , 7. , 8. , 9. ]]]] + """ + check_variable_and_dtype(x, 'X', ['float16', 'float32', 'float64'], 'rrelu') if not isinstance(lower, float) or not isinstance(upper, float): raise TypeError( - "The lower and upper values must be float type. Received: lower {}, upper {}.".format( - lower, upper)) + "The lower and upper values must be float type. Received: lower {}, upper {}.". + format(lower, upper)) if lower < 0 or lower > 1: raise ValueError( - "The lower value must be no less than zero or greater than one. Received: {}.".format( - lower)) + "The lower value must be no less than zero or greater than one. Received: {}.". + format(lower)) if upper < lower: raise ValueError( - "The upper value must be greater than lower value. Received: lower {}, upper {}.".format( - lower, upper)) + "The upper value must be greater than lower value. Received: lower {}, upper {}.". + format(lower, upper)) if upper > 1: raise ValueError( "The upper value must be no greater than one. Received: {}.".format( upper)) - if training: + if not training: negative_slope = (lower + upper) / 2.0 return leaky_relu(x, negative_slope, name) @@ -565,14 +614,16 @@ def rrelu(x, lower, upper, training=False, name=None): helper = LayerHelper('rrelu', **locals()) out = helper.create_variable_for_type_inference(x.dtype) + noise = helper.create_variable_for_type_inference(dtype=x.dtype) helper.append_op( type="rrelu", inputs={"X": x}, - outputs={"Out": out}, + outputs={"Out": out, "Noise": noise}, attrs={"lower": lower, "upper": upper}) return out + def relu(x, name=None): """ relu activation. diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 2b50508065605..d4b271e987490 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -26,6 +26,7 @@ from .activation import Sigmoid # noqa: F401 from .activation import Softmax # noqa: F401 from .activation import LogSoftmax # noqa: F401 +from .activation import RReLU # noqa: F401 from .common import Bilinear # noqa: F401 from .common import Pad1D # noqa: F401 from .common import Pad2D # noqa: F401 diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index 590d38ea34075..03cbca8b646e5 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -435,24 +435,74 @@ def extra_repr(self): self._num_parameters, self._data_format, self._init, self._dtype, name_str) + class RReLU(Layer): - def __init__(self, - lower=1./8., - upper=1./3., - name=None): - super(PReLU, self).__init__() + """ + rrelu activation. + + `Empirical Evaluation of Rectified Activations in Convolutional Network`: + https://arxiv.org/abs/1505.00853 + + .. math:: + \text{RReLU}(x) = + \begin{cases} + x & \text{if } x \geq 0 \\ + ax & \text{ otherwise } + \end{cases} + + where :math:`a` is randomly sampled from uniform distribution + :math:`\mathcal{U}(\text{lower}, \text{upper})`. 
+ + Parameters: + lower (float, optional): The lower bound of uniform distribution. Default: :math:`\frac{1}{8}`. + upper (float, optional): The upper bound of uniform distribution. Default: :math:`\frac{1}{3}`. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - input: Tensor with any shape. Default dtype is float32. + - output: Tensor with the same shape as input. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + paddle.set_default_dtype("float64") + + data = np.array([[[[-2.0, 3.0, -4.0, 5.0], + [ 3.0, -4.0, 5.0, -6.0], + [-7.0, -8.0, 8.0, 9.0]], + [[ 1.0, -2.0, -3.0, 4.0], + [-5.0, 6.0, 7.0, -8.0], + [ 6.0, 7.0, 8.0, 9.0]]]], 'float64') + x = paddle.to_tensor(data) + m = paddle.nn.RReLU(0.1, 0.3) + out = m(x) + # [[[[-0.5 , 3. , -1. , 5. ], + # [ 3. , -1. , 5. , -1.5 ], + # [-1.75, -2. , 8. , 9. ]], + # [[ 1. , -0.5 , -0.75, 4. ], + # [-1.25, 6. , 7. , -2. ], + # [ 6. , 7. , 8. , 9. ]]]] + """ + + def __init__(self, lower=1. / 8., upper=1. / 3., name=None): + super(RReLU, self).__init__() self._lower = lower self._upper = upper self._name = name def forward(self, x): - return F.rrelu(x, lower=self._lower, upper=self._upper, training=self.training) + return F.rrelu( + x, lower=self._lower, upper=self._upper, training=self.training) def extra_repr(self): name_str = ', name={}'.format(self._name) if self._name else '' return 'lower={}, upper={}, training={}, dtype={}{}'.format( - self._lower, self._upper, self.training, self._dtype, - name_str) + self._lower, self._upper, self.training, self._dtype, name_str) + class ReLU(Layer): """ diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index 365047f7e8382..25360977daee2 100755 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -394,6 +394,7 @@ 'test_positive_negative_pair_op', 'test_precision_recall_op', 'test_prelu_op', + 'test_rrelu_op', 'test_prelu_mkldnn_op', 'test_print_op', 'test_prior_box_op', From a83e2bbc71cc6029b383b1bc1e016abe046638a7 Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Sat, 16 Apr 2022 02:29:05 +0000 Subject: [PATCH 03/16] commit before merge --- paddle/fluid/operators/rrelu_op.cc | 2 +- paddle/phi/infermeta/unary.cc | 23 +- paddle/phi/kernels/gpu/rrelu_grad_kernel.cu | 9 +- paddle/phi/kernels/gpu/rrelu_kernel.cu | 8 +- .../fluid/tests/unittests/test_rrelu_op.py | 221 ++++++++++-------- python/paddle/nn/functional/activation.py | 3 +- 6 files changed, 149 insertions(+), 117 deletions(-) diff --git a/paddle/fluid/operators/rrelu_op.cc b/paddle/fluid/operators/rrelu_op.cc index a87ae1b49101b..d244b646c712e 100644 --- a/paddle/fluid/operators/rrelu_op.cc +++ b/paddle/fluid/operators/rrelu_op.cc @@ -120,7 +120,7 @@ class RReluGradOpMaker : public framework::SingleGradOpMaker { op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetInput("Noise", this->Output("Noise")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); -// op->SetAttrMap(this->Attrs()); + // op->SetAttrMap(this->Attrs()); } }; diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index effd2a3836f16..5381dafcb102d 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -2032,10 +2032,10 @@ void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) { } void RReluInferMeta(const MetaTensor& x, - float lower, - float upper, - MetaTensor* out, - MetaTensor* 
noise) { + float lower, + float upper, + MetaTensor* out, + MetaTensor* noise) { auto x_dims = x.dims(); PADDLE_ENFORCE_GE(lower, 0, @@ -2049,13 +2049,14 @@ void RReluInferMeta(const MetaTensor& x, "The upper value should be less than or equal to 1. " "But received upper value = %f.", upper)); - PADDLE_ENFORCE_GE(upper, - lower, - phi::errors::InvalidArgument( - "The upper value should be greater than or equal to lower value " - "But received upper value = %f, lower value = %f.", - upper, - lower)); + PADDLE_ENFORCE_GE( + upper, + lower, + phi::errors::InvalidArgument( + "The upper value should be greater than or equal to lower value " + "But received upper value = %f, lower value = %f.", + upper, + lower)); out->set_dims(x_dims); out->set_dtype(x.dtype()); noise->set_dims(x_dims); diff --git a/paddle/phi/kernels/gpu/rrelu_grad_kernel.cu b/paddle/phi/kernels/gpu/rrelu_grad_kernel.cu index 3d60d630bb2b3..5256b8f13624b 100644 --- a/paddle/phi/kernels/gpu/rrelu_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/rrelu_grad_kernel.cu @@ -78,5 +78,10 @@ void RReluGradKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL( - rrelu_grad, GPU, ALL_LAYOUT, phi::RReluGradKernel, float, phi::dtype::float16, double) {} +PD_REGISTER_KERNEL(rrelu_grad, + GPU, + ALL_LAYOUT, + phi::RReluGradKernel, + float, + phi::dtype::float16, + double) {} diff --git a/paddle/phi/kernels/gpu/rrelu_kernel.cu b/paddle/phi/kernels/gpu/rrelu_kernel.cu index de3c678e579c4..877d1acf7e3d1 100644 --- a/paddle/phi/kernels/gpu/rrelu_kernel.cu +++ b/paddle/phi/kernels/gpu/rrelu_kernel.cu @@ -40,4 +40,10 @@ void RReluKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(rrelu, GPU, ALL_LAYOUT, phi::RReluKernel, float, phi::dtype::float16, double) {} +PD_REGISTER_KERNEL(rrelu, + GPU, + ALL_LAYOUT, + phi::RReluKernel, + float, + phi::dtype::float16, + double) {} diff --git a/python/paddle/fluid/tests/unittests/test_rrelu_op.py b/python/paddle/fluid/tests/unittests/test_rrelu_op.py index bb24ce45e45f5..4ca0d3a4d814f 100644 --- a/python/paddle/fluid/tests/unittests/test_rrelu_op.py +++ b/python/paddle/fluid/tests/unittests/test_rrelu_op.py @@ -25,31 +25,38 @@ import paddle.nn.functional as F import time -def debug_log(msg,is_clear=False): + + +def debug_log(msg, is_clear=False): fp = open('/tmp/data.txt', 'w' if is_clear else "a") fp.write(str(time.time()) + " => " + msg + "\n") fp.close() + debug_log("=======> 111", True) -xx= paddle.rand((2, 3)) +xx = paddle.rand((2, 3)) rrelu1 = paddle.nn.RReLU() print(rrelu1(xx)) -print(F.rrelu(xx, 0.1, 0.4, training = True)) +print(F.rrelu(xx, 0.1, 0.4, training=True)) + def ref_rrelu(x, lower, upper): x_t = x.copy() alpha = (lower + upper) / 2.0 return np.where(x_t <= 0, alpha * x_t, x_t) + def ref_rrelu_nn(x, lower, upper): return ref_rrelu(x, lower, upper) + def check_output(input, output, lower, upper): lower_res = np.where(input <= 0, lower * input, input) upper_res = np.where(input <= 0, upper * input, input) return (output >= lower_res).all() and (output <= upper_res).all() + class TestFunctionalRReluAPI(unittest.TestCase): def setUp(self): # self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda( @@ -67,8 +74,7 @@ def static_check(self, lower, upper): x = paddle.fluid.data('X', self.x_np.shape, 'float32') out = F.rrelu(x, lower, upper) exe = paddle.static.Executor(self.place) - res = exe.run(feed={'X': self.x_np}, - fetch_list=[out]) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out]) out_ref = ref_rrelu(self.x_np, lower, upper) 
self.assertEqual(np.allclose(out_ref, res[0]), True) @@ -91,24 +97,38 @@ def test_static_api(self): def test_error_functional(self): with paddle.static.program_guard(paddle.static.Program()): # The input type must be Variable. - self.assertRaises(TypeError, F.rrelu, x=1, lower=self.lower_0, upper=self.upper_0) + self.assertRaises( + TypeError, F.rrelu, x=1, lower=self.lower_0, upper=self.upper_0) # The input dtype must be float16, float32, float64. x_int32 = paddle.fluid.data( name='x_int32', shape=[2, 3], dtype='int32') - self.assertRaises(TypeError, F.rrelu, x=x_int32, lower=self.lower_0, upper=self.upper_0) + self.assertRaises( + TypeError, + F.rrelu, + x=x_int32, + lower=self.lower_0, + upper=self.upper_0) x_bool = paddle.fluid.data( name='x_bool', shape=[2, 3], dtype='int32') - self.assertRaises(TypeError, F.rrelu, x=x_bool, lower=self.lower_0, upper=self.upper_0) + self.assertRaises( + TypeError, + F.rrelu, + x=x_bool, + lower=self.lower_0, + upper=self.upper_0) # lower and upper must be float x_fp32 = paddle.fluid.data( name='x_fp32', shape=[2, 3], dtype='float32') self.assertRaises(TypeError, F.rrelu, x=x_fp32, lower=0, upper=0.5) self.assertRaises(TypeError, F.rrelu, x=x_fp32, lower=0.5, upper=1) # lower and upper must be in (0, 1) - self.assertRaises(ValueError, F.rrelu, x=x_fp32, lower=-1., upper=0.5) - self.assertRaises(ValueError, F.rrelu, x=x_fp32, lower=0.5, upper=2.) + self.assertRaises( + ValueError, F.rrelu, x=x_fp32, lower=-1., upper=0.5) + self.assertRaises( + ValueError, F.rrelu, x=x_fp32, lower=0.5, upper=2.) # upper should not be less than lower - self.assertRaises(ValueError, F.rrelu, x=x_fp32, lower=0.5, upper=0.2) + self.assertRaises( + ValueError, F.rrelu, x=x_fp32, lower=0.5, upper=0.2) # support the input dtype is float16 x_fp16 = paddle.fluid.data( name='x_fp16', shape=[2, 3], dtype='float16') @@ -166,97 +186,96 @@ def error_lower_upper(): # self.assertRaises(ValueError, error_upper_range) # self.assertRaises(ValueError, error_lower_upper) + # class TestRReluAPI(unittest.TestCase): + # def setUp(self): + # self.shape = [2, 3] + # self.x_1_np = np.random.random(self.shape).astype("float64") + # self.lower = 0.1 + # self.upper = 0.25 + # debug_log("=======> 333") + # + # def test_static_graph_functional(self): + # for use_cuda in ([False, True] + # if core.is_compiled_with_cuda() else [False]): + # place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + # paddle.enable_static() + # x_1 = paddle.fluid.data( + # name="X", shape=self.shape, dtype="float64") + # out_1 = F.rrelu(x_1, self.lower, self.upper) + # exe = paddle.static.Executor(place=place) + # res_1 = exe.run(fluid.default_main_program(), + # feed={"X": self.x_1_np}, + # fetch_list=out_1, + # use_prune=True) + # self.assertTrue(check_output(self.x_1_np, res_1, self.lower, self.upper)) + # + # # same test between layer and functional in this op. 
+ # def test_static_graph_layer(self): + # for use_cuda in ([False, True] + # if core.is_compiled_with_cuda() else [False]): + # place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + # + # paddle.enable_static() + # x_1 = paddle.fluid.data( + # name="X", shape=self.shape, dtype="float64") + # + # # init instance + # ps_1 = paddle.nn.RReLU(self.lower, self.upper) + # out_1 = ps_1(x_1) + # exe = paddle.static.Executor(place=place) + # res_1 = exe.run(fluid.default_main_program(), + # feed={"X": self.x_1_np}, + # fetch_list=out_1, + # use_prune=True) + # self.assertTrue(check_output(self.x_1_np, res_1, self.lower, self.upper)) + # + # def test_dygraph(self): + # for use_cuda in ([False, True] + # if core.is_compiled_with_cuda() else [False]): + # place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + # + # paddle.disable_static(place=place) + # + # rrelu = paddle.nn.RReLU(self.lower, self.upper) + # result = rrelu(paddle.to_tensor(self.x_1_np)) + # self.assertTrue(check_output(self.x_1_np, result.numpy(), self.lower, self.upper)) + # result_functional = F.rrelu( + # paddle.to_tensor(self.x_1_np), self.lower, self.upper) + # self.assertTrue(check_output(self.x_1_np, result_functional.numpy(), self.lower, self.upper)) + + # class RReluTest(OpTest): + # def setUp(self): + # self.init_alpha() + # self.init_dtype() + # self.init_input_shape() + # self.init_attr() + # self.op_type = "rrelu" + # + # x_np = np.random.uniform(-1, 1, self.x_shape).astype(self.dtype) + # x_np[np.abs(x_np) < 0.005] = 0.02 + # out_np = ref_rrelu(x_np, self.alpha, self.alpha) + # self.inputs = {'X': x_np} + # self.outputs = {'Out': out_np} + # debug_log("=======> 444") + # + # def init_alpha(self): + # self.alpha = 0.5 + # + # def init_dtype(self): + # self.dtype = np.float64 + # + # def init_input_shape(self): + # self.x_shape = [2, 3] + # + # def init_attr(self): + # self.attrs = {'lower': self.alpha, "upper": self.alpha} + # + # def test_check_output(self): + # self.check_output() + # + # def test_check_grad(self): + # self.check_grad(['X'], 'Out') -# class TestRReluAPI(unittest.TestCase): -# def setUp(self): -# self.shape = [2, 3] -# self.x_1_np = np.random.random(self.shape).astype("float64") -# self.lower = 0.1 -# self.upper = 0.25 -# debug_log("=======> 333") -# -# def test_static_graph_functional(self): -# for use_cuda in ([False, True] -# if core.is_compiled_with_cuda() else [False]): -# place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() -# paddle.enable_static() -# x_1 = paddle.fluid.data( -# name="X", shape=self.shape, dtype="float64") -# out_1 = F.rrelu(x_1, self.lower, self.upper) -# exe = paddle.static.Executor(place=place) -# res_1 = exe.run(fluid.default_main_program(), -# feed={"X": self.x_1_np}, -# fetch_list=out_1, -# use_prune=True) -# self.assertTrue(check_output(self.x_1_np, res_1, self.lower, self.upper)) -# -# # same test between layer and functional in this op. 
-# def test_static_graph_layer(self): -# for use_cuda in ([False, True] -# if core.is_compiled_with_cuda() else [False]): -# place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() -# -# paddle.enable_static() -# x_1 = paddle.fluid.data( -# name="X", shape=self.shape, dtype="float64") -# -# # init instance -# ps_1 = paddle.nn.RReLU(self.lower, self.upper) -# out_1 = ps_1(x_1) -# exe = paddle.static.Executor(place=place) -# res_1 = exe.run(fluid.default_main_program(), -# feed={"X": self.x_1_np}, -# fetch_list=out_1, -# use_prune=True) -# self.assertTrue(check_output(self.x_1_np, res_1, self.lower, self.upper)) -# -# def test_dygraph(self): -# for use_cuda in ([False, True] -# if core.is_compiled_with_cuda() else [False]): -# place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() -# -# paddle.disable_static(place=place) -# -# rrelu = paddle.nn.RReLU(self.lower, self.upper) -# result = rrelu(paddle.to_tensor(self.x_1_np)) -# self.assertTrue(check_output(self.x_1_np, result.numpy(), self.lower, self.upper)) -# result_functional = F.rrelu( -# paddle.to_tensor(self.x_1_np), self.lower, self.upper) -# self.assertTrue(check_output(self.x_1_np, result_functional.numpy(), self.lower, self.upper)) - - -# class RReluTest(OpTest): -# def setUp(self): -# self.init_alpha() -# self.init_dtype() -# self.init_input_shape() -# self.init_attr() -# self.op_type = "rrelu" -# -# x_np = np.random.uniform(-1, 1, self.x_shape).astype(self.dtype) -# x_np[np.abs(x_np) < 0.005] = 0.02 -# out_np = ref_rrelu(x_np, self.alpha, self.alpha) -# self.inputs = {'X': x_np} -# self.outputs = {'Out': out_np} -# debug_log("=======> 444") -# -# def init_alpha(self): -# self.alpha = 0.5 -# -# def init_dtype(self): -# self.dtype = np.float64 -# -# def init_input_shape(self): -# self.x_shape = [2, 3] -# -# def init_attr(self): -# self.attrs = {'lower': self.alpha, "upper": self.alpha} -# -# def test_check_output(self): -# self.check_output() -# -# def test_check_grad(self): -# self.check_grad(['X'], 'Out') if __name__ == "__main__": unittest.main() diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index cbc533f09fc4a..a8da98b97f03d 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -618,7 +618,8 @@ def rrelu(x, lower=1. / 8., upper=1. 
/ 3., training=False, name=None): helper.append_op( type="rrelu", inputs={"X": x}, - outputs={"Out": out, "Noise": noise}, + outputs={"Out": out, + "Noise": noise}, attrs={"lower": lower, "upper": upper}) return out From e5f3910dc141e866f1f356a4fa70b79b1c5ab3eb Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Tue, 26 Apr 2022 02:43:40 +0000 Subject: [PATCH 04/16] =?UTF-8?q?=E4=B8=B0=E5=AF=8C=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/fluid/operators/rrelu_op.cc | 25 +- paddle/phi/infermeta/unary.cc | 74 ++-- paddle/phi/infermeta/unary.h | 16 +- paddle/phi/kernels/cpu/rrelu_grad_kernel.cc | 3 +- paddle/phi/kernels/cpu/rrelu_kernel.cc | 28 +- paddle/phi/kernels/gpu/rrelu_funcs.h | 54 ++- paddle/phi/kernels/gpu/rrelu_grad_kernel.cu | 1 - paddle/phi/kernels/gpu/rrelu_kernel.cu | 16 +- paddle/phi/kernels/rrelu_kernel.h | 3 + paddle/phi/ops/compat/rrelu_sig.cc | 10 +- .../fluid/tests/unittests/test_rrelu_op.py | 322 ++++++++++-------- python/paddle/nn/functional/activation.py | 56 ++- python/paddle/nn/layer/activation.py | 12 +- 13 files changed, 374 insertions(+), 246 deletions(-) diff --git a/paddle/fluid/operators/rrelu_op.cc b/paddle/fluid/operators/rrelu_op.cc index d244b646c712e..f22a07452d4a6 100644 --- a/paddle/fluid/operators/rrelu_op.cc +++ b/paddle/fluid/operators/rrelu_op.cc @@ -43,6 +43,19 @@ class RReluOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Noise", "The random sampled RRelu noise.") .AsIntermediate() .AsExtra(); + AddAttr("is_test", + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") + .SetDefault(false); + AddAttr("fix_seed", + "A flag indicating whether to use a fixed seed to generate " + "random mask. NOTE: DO NOT set this flag to true in " + "training. Setting this flag to true is only useful in " + "unittest or for debug that always the same output units " + "will be dropped.") + .SetDefault(false) + .AsExtra(); + AddAttr("seed", "Rrelu random seed.").SetDefault(0).AsExtra(); float default_lower = 1. / 8.; AddAttr("lower", "Lower bound of the uniform distribution.") @@ -60,7 +73,6 @@ class RReluOpMaker : public framework::OpProtoAndCheckerMaker { platform::errors::InvalidArgument( "'RRelu_upper' must be between 0.0 and 1.0.")); }); - AddComment(R"DOC( RRelu Operator. 
@@ -92,12 +104,15 @@ class RReluGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Noise"), "Input", "Noise", "rrelu_grad"); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "rrelu"); + OP_INOUT_CHECK(ctx->HasInput("Noise"), "Input", "Noise", "rrelu"); OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", - framework::GradVarName("Out"), "rrelu_grad"); + framework::GradVarName("Out"), "rrelu"); auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); ctx->SetOutputDim(framework::GradVarName("X"), out_dims); + ctx->ShareLoD(framework::GradVarName("Out"), + /*->*/ framework::GradVarName("X")); } protected: @@ -117,10 +132,10 @@ class RReluGradOpMaker : public framework::SingleGradOpMaker { protected: void Apply(GradOpPtr op) const override { op->SetType("rrelu_grad"); - op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetInput("X", this->Input("X")); op->SetInput("Noise", this->Output("Noise")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - // op->SetAttrMap(this->Attrs()); } }; diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index d01a6f0aea164..313d6a85cf9b3 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1816,6 +1816,48 @@ void RollInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } +void RReluInferMeta(const MetaTensor& x, + float lower, + float upper, + bool is_test, + bool fix_seed, + int seed, + MetaTensor* out, + MetaTensor* noise) { + auto x_dims = x.dims(); + PADDLE_ENFORCE_GE(lower, + 0, + phi::errors::InvalidArgument( + "The lower value should be greater than or equal to 0. " + "But received lower value = %f.", + lower)); + PADDLE_ENFORCE_LE(upper, + 1, + phi::errors::InvalidArgument( + "The upper value should be less than or equal to 1. " + "But received upper value = %f.", + upper)); + PADDLE_ENFORCE_GE( + upper, + lower, + phi::errors::InvalidArgument( + "The upper value should be greater than or equal to lower value " + "But received upper value = %f, lower value = %f.", + upper, + lower)); + + out->set_dims(x_dims); + out->set_dtype(x.dtype()); + out->set_layout(x.layout()); + out->share_lod(x); + + if (noise != nullptr) { + noise->set_dims(x_dims); + noise->set_dtype(x.dtype()); + noise->set_layout(x.layout()); + } +} + void SetValueInferMeta(const MetaTensor& x, MetaTensor* out) { auto in_dims = x.dims(); PADDLE_ENFORCE_LT( @@ -3000,38 +3042,6 @@ void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) { out->set_dtype(DataType::INT64); } -void RReluInferMeta(const MetaTensor& x, - float lower, - float upper, - MetaTensor* out, - MetaTensor* noise) { - auto x_dims = x.dims(); - PADDLE_ENFORCE_GE(lower, - 0, - phi::errors::InvalidArgument( - "The lower value should be greater than or equal to 0. " - "But received lower value = %f.", - lower)); - PADDLE_ENFORCE_LE(upper, - 1, - phi::errors::InvalidArgument( - "The upper value should be less than or equal to 1. 
" - "But received upper value = %f.", - upper)); - PADDLE_ENFORCE_GE( - upper, - lower, - phi::errors::InvalidArgument( - "The upper value should be greater than or equal to lower value " - "But received upper value = %f, lower value = %f.", - upper, - lower)); - out->set_dims(x_dims); - out->set_dtype(x.dtype()); - noise->set_dims(x_dims); - noise->set_dtype(x.dtype()); -} - } // namespace phi PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 760eadb0966ba..9d096108b1ebb 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -269,6 +269,15 @@ void RollInferMeta(const MetaTensor& x, const std::vector& axis, MetaTensor* out); +void RReluInferMeta(const MetaTensor& x, + float lower, + float upper, + bool is_test, + bool fix_seed, + int seed, + MetaTensor* out, + MetaTensor* noise); + void SetValueInferMeta(const MetaTensor& x, MetaTensor* out); void ShapeInferMeta(const MetaTensor& input, MetaTensor* out); @@ -439,11 +448,4 @@ void OneHotRawInferMeta(const MetaTensor& x, void OneHotInferMeta(const MetaTensor& x, const Scalar& depth, MetaTensor* out); void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out); - -void RReluInferMeta(const MetaTensor& x, - float lower, - float upper, - MetaTensor* out, - MetaTensor* noise); - } // namespace phi diff --git a/paddle/phi/kernels/cpu/rrelu_grad_kernel.cc b/paddle/phi/kernels/cpu/rrelu_grad_kernel.cc index c6c14510293fb..10b6c6b1a3ea8 100644 --- a/paddle/phi/kernels/cpu/rrelu_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/rrelu_grad_kernel.cc @@ -29,12 +29,13 @@ void RReluGradKernel(const Context& dev_ctx, const T* x_ptr = x.data(); const T* out_grad_ptr = out_grad.data(); int numel = x.numel(); + if (!x_grad) return; + int i = 0; T* x_grad_ptr = dev_ctx.template Alloc(x_grad); for (i = 0; i < numel; i++) { x_grad_ptr[i] = x_ptr[i] > 0 ? out_grad_ptr[i] : n_ptr[i] * out_grad_ptr[i]; } - } } // namespace phi diff --git a/paddle/phi/kernels/cpu/rrelu_kernel.cc b/paddle/phi/kernels/cpu/rrelu_kernel.cc index 6aade18337d24..0d58141e2cbec 100644 --- a/paddle/phi/kernels/cpu/rrelu_kernel.cc +++ b/paddle/phi/kernels/cpu/rrelu_kernel.cc @@ -14,6 +14,7 @@ #include "paddle/phi/kernels/rrelu_kernel.h" +#include "paddle/fluid/framework/generator.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" @@ -24,19 +25,38 @@ void RReluKernel(const Context& dev_ctx, const DenseTensor& x, const float lower, const float upper, + bool is_test, + bool fix_seed, + int seed, DenseTensor* out, DenseTensor* noise) { const T* x_ptr = x.data(); T* o_ptr = dev_ctx.template Alloc(out); T* n_ptr = dev_ctx.template Alloc(noise); T zero = static_cast(0); + int numel = x.numel(); + int i = 0; + + if (is_test) { + for (i = 0; i < numel; i++) { + T mid_val = static_cast((lower + upper) / 2.0); + if (x_ptr[i] < zero) { + o_ptr[i] = mid_val * x_ptr[i]; + n_ptr[i] = mid_val; + } else { + o_ptr[i] = x_ptr[i]; + n_ptr[i] = 1.0; + } + } + + return; + } + + int seed_data = fix_seed ? 
seed : 0; + auto engine = paddle::framework::GetCPURandomEngine(seed_data); std::uniform_real_distribution dist(lower, upper); - auto gen_ptr = dev_ctx.GetGenerator(); - auto engine = gen_ptr->GetCPUEngine(); - int numel = x.numel(); - int i = 0; for (i = 0; i < numel; i++) { if (x_ptr[i] < zero) { T scale = static_cast(dist(*engine)); diff --git a/paddle/phi/kernels/gpu/rrelu_funcs.h b/paddle/phi/kernels/gpu/rrelu_funcs.h index 40cf9cf1c71b4..f768a56225506 100644 --- a/paddle/phi/kernels/gpu/rrelu_funcs.h +++ b/paddle/phi/kernels/gpu/rrelu_funcs.h @@ -32,24 +32,38 @@ template __global__ void RReluElementWiseKernel(const T *input, T *output, T *noise, - const float &lower, - const float &upper, + const T mid_val, + const float lower, + const float upper, + const bool is_test, + const int seed_data, size_t numel) { CUDA_KERNEL_LOOP(index, numel) { T x = input[index]; T zero = static_cast(0); - if (x < zero) { - thrust::minstd_rand rng; - rng.seed(0); - thrust::uniform_real_distribution dist(lower, upper); - rng.discard(index); - T scale = static_cast(dist(rng)); - output[index] = scale * x; - noise[index] = scale; + if (is_test) { + if (x < zero) { + output[index] = mid_val * x; + noise[index] = mid_val; + } else { + output[index] = x; + noise[index] = 1.0; + } + } else { - output[index] = x; - noise[index] = 1.0; + if (x < zero) { + thrust::minstd_rand rng; + rng.seed(seed_data); + thrust::uniform_real_distribution dist(lower, upper); + rng.discard(index); + T scale = static_cast(dist(rng)); + output[index] = scale * x; + noise[index] = scale; + } else { + output[index] = x; + noise[index] = 1.0; + } } } } @@ -61,8 +75,10 @@ class RReluElementWiseDirectCUDAFunctor { const T *input, T *output, T *noise, - const float &lower, - const float &upper, + const float lower, + const float upper, + const bool is_test, + const int seed_data, size_t numel); }; @@ -71,13 +87,17 @@ void RReluElementWiseDirectCUDAFunctor::operator()(gpuStream_t stream, const T *input, T *output, T *noise, - const float &lower, - const float &upper, + const float lower, + const float upper, + const bool is_test, + const int seed_data, size_t numel) { + T mid_val = static_cast((lower + upper) / 2.0); RReluElementWiseKernel<<>>(input, output, noise, lower, upper, numel); + stream>>>( + input, output, noise, mid_val, lower, upper, is_test, seed_data, numel); } template class RReluElementWiseDirectCUDAFunctor; diff --git a/paddle/phi/kernels/gpu/rrelu_grad_kernel.cu b/paddle/phi/kernels/gpu/rrelu_grad_kernel.cu index 5256b8f13624b..44dc31ed5d926 100644 --- a/paddle/phi/kernels/gpu/rrelu_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/rrelu_grad_kernel.cu @@ -66,7 +66,6 @@ void RReluGradKernel(const Context& dev_ctx, const T* x_ptr = x.data(); const T* n_ptr = noise.data(); const T* out_grad_ptr = out_grad.data(); - if (!x_grad) return; T* x_grad_ptr = dev_ctx.template Alloc(x_grad); int numel = x.numel(); diff --git a/paddle/phi/kernels/gpu/rrelu_kernel.cu b/paddle/phi/kernels/gpu/rrelu_kernel.cu index 877d1acf7e3d1..d41d55aab7153 100644 --- a/paddle/phi/kernels/gpu/rrelu_kernel.cu +++ b/paddle/phi/kernels/gpu/rrelu_kernel.cu @@ -25,6 +25,9 @@ void RReluKernel(const Context& dev_ctx, const DenseTensor& x, const float lower, const float upper, + bool is_test, + bool fix_seed, + int seed, DenseTensor* out, DenseTensor* noise) { const T* x_ptr = x.data(); @@ -34,8 +37,17 @@ void RReluKernel(const Context& dev_ctx, int numel = x.numel(); auto dim = x.dims(); RReluElementWiseDirectCUDAFunctor rrelu_element_wise; - rrelu_element_wise( 
- dev_ctx.stream(), x_ptr, o_ptr, n_ptr, lower, upper, numel); + + int seed_data = fix_seed ? seed : 0; + rrelu_element_wise(dev_ctx.stream(), + x_ptr, + o_ptr, + n_ptr, + lower, + upper, + is_test, + seed_data, + numel); } } // namespace phi diff --git a/paddle/phi/kernels/rrelu_kernel.h b/paddle/phi/kernels/rrelu_kernel.h index 92a61ed15b6f7..50807ffc93f6b 100644 --- a/paddle/phi/kernels/rrelu_kernel.h +++ b/paddle/phi/kernels/rrelu_kernel.h @@ -23,6 +23,9 @@ void RReluKernel(const Context& dev_ctx, const DenseTensor& x, const float lower, const float upper, + bool is_test, + bool fix_seed, + int seed, DenseTensor* out, DenseTensor* noise); } // namespace phi diff --git a/paddle/phi/ops/compat/rrelu_sig.cc b/paddle/phi/ops/compat/rrelu_sig.cc index 5f9232412b11d..c3ff52c1715bb 100644 --- a/paddle/phi/ops/compat/rrelu_sig.cc +++ b/paddle/phi/ops/compat/rrelu_sig.cc @@ -17,16 +17,18 @@ namespace phi { KernelSignature RReluOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("rrelu", {"X"}, {"lower", "upper"}, {"Out", "Noise"}); + return KernelSignature("rrelu", + {"X"}, + {"lower", "upper", "is_test", "fix_seed", "seed"}, + {"Out", "Noise"}); } KernelSignature RReluGradGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature( - "rrelu_grad", {GradVarName("Out"), "Noise"}, {}, {GradVarName("X")}); + "rrelu_grad", {"X", "Noise", GradVarName("Out")}, {}, {GradVarName("X")}); } - } // namespace phi PD_REGISTER_ARG_MAPPING_FN(rrelu, phi::RReluOpArgumentMapping); -PD_REGISTER_ARG_MAPPING_FN(rrelu_grad, phi::RReluGradGradOpArgumentMapping); \ No newline at end of file +PD_REGISTER_ARG_MAPPING_FN(rrelu_grad, phi::RReluGradGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_rrelu_op.py b/python/paddle/fluid/tests/unittests/test_rrelu_op.py index 4ca0d3a4d814f..d8cd29e25c6ae 100644 --- a/python/paddle/fluid/tests/unittests/test_rrelu_op.py +++ b/python/paddle/fluid/tests/unittests/test_rrelu_op.py @@ -23,22 +23,10 @@ from op_test import OpTest import paddle import paddle.nn.functional as F +from paddle.fluid import dygraph -import time - - -def debug_log(msg, is_clear=False): - fp = open('/tmp/data.txt', 'w' if is_clear else "a") - fp.write(str(time.time()) + " => " + msg + "\n") - fp.close() - - -debug_log("=======> 111", True) - -xx = paddle.rand((2, 3)) -rrelu1 = paddle.nn.RReLU() -print(rrelu1(xx)) -print(F.rrelu(xx, 0.1, 0.4, training=True)) +paddle.seed(102) +np.random.seed(102) def ref_rrelu(x, lower, upper): @@ -54,45 +42,143 @@ def ref_rrelu_nn(x, lower, upper): def check_output(input, output, lower, upper): lower_res = np.where(input <= 0, lower * input, input) upper_res = np.where(input <= 0, upper * input, input) - return (output >= lower_res).all() and (output <= upper_res).all() + return (output <= lower_res).all() and (output >= upper_res).all() class TestFunctionalRReluAPI(unittest.TestCase): def setUp(self): - # self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda( - # ) else paddle.CPUPlace() - self.place = paddle.CPUPlace() - self.x_np = np.random.uniform(-1., 1., [1, 2, 3, 4]).astype('float32') + self.x_np = np.random.uniform(-1., 1., [1, 2, 3, 4]).astype('float64') self.lower_0 = 0.05 self.lower_1 = 0.1 self.upper_0 = 0.25 self.upper_1 = 0.33 - debug_log("=======> 222") - def static_check(self, lower, upper): - with paddle.static.program_guard(paddle.static.Program()): - x = paddle.fluid.data('X', self.x_np.shape, 'float32') - out = F.rrelu(x, lower, upper) - exe = 
paddle.static.Executor(self.place) - res = exe.run(feed={'X': self.x_np}, fetch_list=[out]) - out_ref = ref_rrelu(self.x_np, lower, upper) - self.assertEqual(np.allclose(out_ref, res[0]), True) + self.places = [fluid.CUDAPlace(0)] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def check_static_result(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input = fluid.data( + name="input", shape=[2, 3, 4, 5], dtype="float32") + res1 = F.rrelu( + x=input, lower=self.lower_0, upper=self.upper_0, training=False) + res2 = F.rrelu( + x=input, lower=self.lower_1, upper=self.upper_1, training=False) + in_np = np.random.uniform(-1., 1., [2, 3, 4, 5]).astype("float32") + + res_np1 = ref_rrelu(in_np, self.lower_0, self.upper_0) + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"input": in_np}, + fetch_list=[res1]) + + self.assertTrue(np.allclose(fetches[0], res_np1)) + + res_np2 = ref_rrelu(in_np, self.lower_1, self.upper_1) + fetches = exe.run(fluid.default_main_program(), + feed={"input": in_np}, + fetch_list=[res2]) + self.assertTrue(np.allclose(fetches[0], res_np2)) + + def test_static(self): + for place in self.places: + self.check_static_result(place=place) + + def test_static_graph_functional(self): + '''test_static_graph_functional''' + + for place in self.places: + paddle.enable_static() + x_1 = paddle.fluid.data( + name="x", shape=self.x_np.shape, dtype="float64") + x_2 = paddle.fluid.data( + name="x2", shape=self.x_np.shape, dtype="float64") + out_1 = F.rrelu(x_1, self.lower_0, self.upper_0, training=False) + out_2 = F.rrelu(x_2, self.lower_1, self.upper_1, training=False) + + exe = paddle.static.Executor(place=place) + res_1 = exe.run(fluid.default_main_program(), + feed={"x": self.x_np}, + fetch_list=out_1, + use_prune=True) + res_2 = exe.run(fluid.default_main_program(), + feed={"x2": self.x_np}, + fetch_list=out_2, + use_prune=True) + + out_ref_1 = ref_rrelu(self.x_np, self.lower_0, self.upper_0) + out_ref_2 = ref_rrelu(self.x_np, self.lower_1, self.upper_1) + self.assertEqual(np.allclose(out_ref_1, res_1), True) + self.assertEqual(np.allclose(out_ref_2, res_2), True) + + def test_static_graph_layer(self): + '''test_static_graph_layer''' + + for place in self.places: + paddle.enable_static() + x_1 = paddle.fluid.data( + name="x", shape=self.x_np.shape, dtype="float64") + x_2 = paddle.fluid.data( + name="x2", shape=self.x_np.shape, dtype="float64") + # init instance + rrelu_1 = paddle.nn.RReLU(self.lower_0, self.upper_0) + rrelu_2 = paddle.nn.RReLU(self.lower_1, self.upper_1) + out_1 = rrelu_1(x_1) + out_2 = rrelu_2(x_2) + + exe = paddle.static.Executor(place=place) + res_1 = exe.run(fluid.default_main_program(), + feed={"x": self.x_np}, + fetch_list=out_1, + use_prune=True) + res_2 = exe.run(fluid.default_main_program(), + feed={"x2": self.x_np}, + fetch_list=out_2, + use_prune=True) + + self.assertTrue( + check_output(self.x_np, res_1[0], self.lower_0, self.upper_0)) + self.assertTrue( + check_output(self.x_np, res_2[0], self.lower_1, self.upper_1)) def dygraph_check(self, lower, upper): - paddle.disable_static(self.place) - x = paddle.to_tensor(self.x_np) - out = F.rrelu(x, lower, upper) - out_ref = ref_rrelu(self.x_np, lower, upper) - self.assertEqual(np.allclose(out_ref, out.numpy()), True) - paddle.enable_static() - - def test_static_api(self): - self.static_check(self.lower_0, self.upper_0) - self.static_check(self.lower_1, self.upper_1) - - # def test_dygraph_api(self): - # 
self.dygraph_check(self.lower_0, self.upper_0) - # self.dygraph_check(self.lower_1, self.upper_1) + for place in self.places: + paddle.disable_static(place) + x = paddle.to_tensor(self.x_np) + out = F.rrelu(x, lower, upper, training=False) + out_ref = ref_rrelu(self.x_np, lower, upper) + self.assertEqual(np.allclose(out_ref, out), True) + paddle.enable_static() + + def test_dygraph_functional(self): + '''test_dygraph_functional''' + + self.dygraph_check(self.lower_0, self.upper_0) + self.dygraph_check(self.lower_1, self.upper_1) + + def test_dygraph_layer(self): + '''test_dygraph_layer''' + + for place in self.places: + paddle.disable_static(place=place) + rrelu = paddle.nn.RReLU(self.lower_0, self.upper_0) + result = rrelu(paddle.to_tensor(self.x_np)) + self.assertTrue( + check_output(self.x_np, + result.numpy(), self.lower_0, self.upper_0)) + paddle.enable_static() + + def test_dygraph(self): + for place in self.places: + paddle.disable_static(place=place) + with dygraph.guard(): + rrelu = paddle.nn.RReLU(self.lower_0, self.upper_0) + out_np = rrelu(paddle.to_tensor(self.x_np)) + self.assertTrue( + check_output(self.x_np, + out_np.numpy(), self.lower_0, self.upper_0)) + paddle.enable_static() def test_error_functional(self): with paddle.static.program_guard(paddle.static.Program()): @@ -135,17 +221,10 @@ def test_error_functional(self): F.rrelu(x=x_fp16, lower=self.lower_0, upper=self.upper_0) def test_error_layer(self): - def error_variable(): - # The input type must be Variable. - with paddle.fluid.dygraph.guard(): - x = 6 - rrelu = paddle.nn.RReLU(self.lower_0, self.upper_0) - rrelu(paddle.to_tensor(x)) - def error_int_dtype(): with paddle.fluid.dygraph.guard(): - x = np.random.random([2, 3]).astype("int32") - rrelu = paddle.nn.RReLU(self.lower_0, self.upper_0) + x = np.random.random([2, 3]).astype("float64") + rrelu = paddle.nn.RReLU(2, 3) rrelu(paddle.to_tensor(x)) def error_lower_dtype(): @@ -178,103 +257,48 @@ def error_lower_upper(): rrelu = paddle.nn.RReLU(0.5, 0.2) rrelu(paddle.to_tensor(x)) - self.assertRaises(TypeError, error_variable) - # self.assertRaises(TypeError, error_int_dtype) - # self.assertRaises(TypeError, error_lower_dtype) - # self.assertRaises(TypeError, error_upper_dtype) - # self.assertRaises(ValueError, error_lower_range) - # self.assertRaises(ValueError, error_upper_range) - # self.assertRaises(ValueError, error_lower_upper) - - # class TestRReluAPI(unittest.TestCase): - # def setUp(self): - # self.shape = [2, 3] - # self.x_1_np = np.random.random(self.shape).astype("float64") - # self.lower = 0.1 - # self.upper = 0.25 - # debug_log("=======> 333") - # - # def test_static_graph_functional(self): - # for use_cuda in ([False, True] - # if core.is_compiled_with_cuda() else [False]): - # place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() - # paddle.enable_static() - # x_1 = paddle.fluid.data( - # name="X", shape=self.shape, dtype="float64") - # out_1 = F.rrelu(x_1, self.lower, self.upper) - # exe = paddle.static.Executor(place=place) - # res_1 = exe.run(fluid.default_main_program(), - # feed={"X": self.x_1_np}, - # fetch_list=out_1, - # use_prune=True) - # self.assertTrue(check_output(self.x_1_np, res_1, self.lower, self.upper)) - # - # # same test between layer and functional in this op. 
- # def test_static_graph_layer(self): - # for use_cuda in ([False, True] - # if core.is_compiled_with_cuda() else [False]): - # place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() - # - # paddle.enable_static() - # x_1 = paddle.fluid.data( - # name="X", shape=self.shape, dtype="float64") - # - # # init instance - # ps_1 = paddle.nn.RReLU(self.lower, self.upper) - # out_1 = ps_1(x_1) - # exe = paddle.static.Executor(place=place) - # res_1 = exe.run(fluid.default_main_program(), - # feed={"X": self.x_1_np}, - # fetch_list=out_1, - # use_prune=True) - # self.assertTrue(check_output(self.x_1_np, res_1, self.lower, self.upper)) - # - # def test_dygraph(self): - # for use_cuda in ([False, True] - # if core.is_compiled_with_cuda() else [False]): - # place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() - # - # paddle.disable_static(place=place) - # - # rrelu = paddle.nn.RReLU(self.lower, self.upper) - # result = rrelu(paddle.to_tensor(self.x_1_np)) - # self.assertTrue(check_output(self.x_1_np, result.numpy(), self.lower, self.upper)) - # result_functional = F.rrelu( - # paddle.to_tensor(self.x_1_np), self.lower, self.upper) - # self.assertTrue(check_output(self.x_1_np, result_functional.numpy(), self.lower, self.upper)) - - # class RReluTest(OpTest): - # def setUp(self): - # self.init_alpha() - # self.init_dtype() - # self.init_input_shape() - # self.init_attr() - # self.op_type = "rrelu" - # - # x_np = np.random.uniform(-1, 1, self.x_shape).astype(self.dtype) - # x_np[np.abs(x_np) < 0.005] = 0.02 - # out_np = ref_rrelu(x_np, self.alpha, self.alpha) - # self.inputs = {'X': x_np} - # self.outputs = {'Out': out_np} - # debug_log("=======> 444") - # - # def init_alpha(self): - # self.alpha = 0.5 - # - # def init_dtype(self): - # self.dtype = np.float64 - # - # def init_input_shape(self): - # self.x_shape = [2, 3] - # - # def init_attr(self): - # self.attrs = {'lower': self.alpha, "upper": self.alpha} - # - # def test_check_output(self): - # self.check_output() - # - # def test_check_grad(self): - # self.check_grad(['X'], 'Out') + self.assertRaises(TypeError, error_int_dtype) + self.assertRaises(TypeError, error_lower_dtype) + self.assertRaises(TypeError, error_upper_dtype) + self.assertRaises(ValueError, error_lower_range) + self.assertRaises(ValueError, error_upper_range) + self.assertRaises(ValueError, error_lower_upper) + + +class RReluTest(OpTest): + def setUp(self): + self.op_type = "rrelu" + self.init_boundary() + self.init_dtype() + self.init_input_shape() + self.init_attr() + + x_np = np.random.uniform(-1, 1, self.x_shape).astype(self.dtype) + out_np = ref_rrelu(x_np, self.lower, self.upper) + noise_np = np.ones(self.x_shape).astype(self.dtype) + noise_np[x_np < 0] = (self.lower + self.upper) / 2.0 + + self.inputs = {'X': x_np} + self.outputs = {'Out': out_np, 'Noise': noise_np} + + def init_boundary(self): + self.lower = 0.1 + self.upper = 0.3 + + def init_dtype(self): + self.dtype = "float64" + + def init_input_shape(self): + self.x_shape = [2, 3, 4, 5] + + def init_attr(self): + self.attrs = {'lower': self.lower, "upper": self.upper, "is_test": True} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') if __name__ == "__main__": diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 63974d3bd524b..2fd6a7e537df4 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -22,7 +22,7 @@ import warnings from 
...fluid.layer_helper import LayerHelper -from ...fluid.framework import convert_np_dtype_to_dtype_ +from ...fluid.framework import convert_np_dtype_to_dtype_, default_main_program from ...fluid.framework import _in_legacy_dygraph, in_dygraph_mode, _non_static_mode from ...fluid.data_feeder import check_variable_and_dtype, check_dtype import paddle @@ -545,7 +545,7 @@ def prelu(x, weight, data_format="NCHW", name=None): return out -def rrelu(x, lower=1. / 8., upper=1. / 3., training=False, name=None): +def rrelu(x, lower=1. / 8., upper=1. / 3., training=True, name=None): """ rrelu activation. @@ -566,7 +566,7 @@ def rrelu(x, lower=1. / 8., upper=1. / 3., training=False, name=None): x (Tensor): The input Tensor with data type float 16 float32, float64. lower (float, optional): The lower bound of uniform distribution. Default: :math:`\frac{1}{8}`. upper (float, optional): The upper bound of uniform distribution. Default: :math:`\frac{1}{3}`. - training (bool, optional): Current is training mode or others. Default is False. + training (bool, optional): Current is training mode or others. Default is True. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. @@ -588,14 +588,17 @@ def rrelu(x, lower=1. / 8., upper=1. / 3., training=False, name=None): [ 6.0, 7.0, 8.0, 9.0]]]], 'float32') x = paddle.to_tensor(data) out = F.rrelu(x, 0.1, 0.3) - # [[[[-0.5 , 3. , -1. , 5. ], - # [ 3. , -1. , 5. , -1.5 ], - # [-1.75, -2. , 8. , 9. ]], - # [[ 1. , -0.5 , -0.75, 4. ], - # [-1.25, 6. , 7. , -2. ], - # [ 6. , 7. , 8. , 9. ]]]] + #[[[[-0.20000899 3. -0.8810822 5. ] + # [ 3. -0.55175185 5. -1.0776101 ] + # [-1.0680687 -1.9896201 8. 9. ]] + # [[ 1. -0.5238267 -0.65515125 4. ] + # [-1.3766339 6. 7. -2.3465784 ] + # [ 6. 7. 8. 9. ]]]] """ - check_variable_and_dtype(x, 'X', ['float16', 'float32', 'float64'], 'rrelu') + + if not in_dynamic_mode(): + check_variable_and_dtype(x, 'X', ['float16', 'float32', 'float64'], + 'rrelu') if not isinstance(lower, float) or not isinstance(upper, float): raise TypeError( @@ -617,23 +620,40 @@ def rrelu(x, lower=1. / 8., upper=1. / 3., training=False, name=None): "The upper value must be no greater than one. 
Received: {}.".format( upper)) - if not training: - negative_slope = (lower + upper) / 2.0 - return leaky_relu(x, negative_slope, name) + is_test = not training + seed = None - if in_dynamic_mode(): - return _C_ops.rrelu(x, 'lower', lower, 'upper', upper) + if _in_legacy_dygraph(): + if default_main_program().random_seed != 0: + seed = default_main_program().random_seed + out, noise = _C_ops.rrelu(x, 'lower', lower, 'upper', upper, 'is_test', + is_test, 'fix_seed', seed is not None, 'seed', + seed if seed is not None else 0) + return out + + def get_attrs(prog, lower, upper, is_test, seed): + if (seed is None or seed == 0) and prog.random_seed != 0: + seed = prog.random_seed + attrs = { + 'lower': lower, + 'upper': upper, + 'is_test': is_test, + 'fix_seed': seed is not None, + 'seed': seed if seed is not None else 0, + } + return attrs helper = LayerHelper('rrelu', **locals()) out = helper.create_variable_for_type_inference(x.dtype) noise = helper.create_variable_for_type_inference(dtype=x.dtype) + attrs = get_attrs(helper.main_program, lower, upper, is_test, seed) + helper.append_op( - type="rrelu", + type='rrelu', inputs={"X": x}, outputs={"Out": out, "Noise": noise}, - attrs={"lower": lower, - "upper": upper}) + attrs=attrs) return out diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index 03cbca8b646e5..fb03fcb9915ce 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -480,12 +480,12 @@ class RReLU(Layer): x = paddle.to_tensor(data) m = paddle.nn.RReLU(0.1, 0.3) out = m(x) - # [[[[-0.5 , 3. , -1. , 5. ], - # [ 3. , -1. , 5. , -1.5 ], - # [-1.75, -2. , 8. , 9. ]], - # [[ 1. , -0.5 , -0.75, 4. ], - # [-1.25, 6. , 7. , -2. ], - # [ 6. , 7. , 8. , 9. ]]]] + #[[[[-0.20000899 3. -0.88108218 5. ] + # [ 3. -0.55175185 5. -1.07761011] + # [-1.06806871 -1.98962009 8. 9. ]] + # [[ 1. -0.52382672 -0.65515128 4. ] + # [-1.37663394 6. 7. -2.34657836] + # [ 6. 7. 8. 9. ]]]] """ def __init__(self, lower=1. / 8., upper=1. 
/ 3., name=None): From f74eb8a4c333fec23a84a379baa8e2c388bcca1c Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Tue, 26 Apr 2022 07:16:23 +0000 Subject: [PATCH 05/16] =?UTF-8?q?=E4=BF=AE=E5=A4=8Drrelu-sig=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/ops/compat/rrelu_sig.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/ops/compat/rrelu_sig.cc b/paddle/phi/ops/compat/rrelu_sig.cc index c3ff52c1715bb..5f669e6210bf6 100644 --- a/paddle/phi/ops/compat/rrelu_sig.cc +++ b/paddle/phi/ops/compat/rrelu_sig.cc @@ -26,7 +26,7 @@ KernelSignature RReluOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature RReluGradGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature( - "rrelu_grad", {"X", "Noise", GradVarName("Out")}, {}, {GradVarName("X")}); + "rrelu_grad", {"X", "Noise", "Out@GRAD"}, {}, {"X@GRAD"}); } } // namespace phi From 71fdbab0ebd484d43f6cb4d108d8557c8c6c3e55 Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Tue, 26 Apr 2022 16:05:52 +0000 Subject: [PATCH 06/16] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dcpu=E7=8E=AF=E5=A2=83?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/paddle/fluid/tests/unittests/test_rrelu_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_rrelu_op.py b/python/paddle/fluid/tests/unittests/test_rrelu_op.py index d8cd29e25c6ae..ce41764c108be 100644 --- a/python/paddle/fluid/tests/unittests/test_rrelu_op.py +++ b/python/paddle/fluid/tests/unittests/test_rrelu_op.py @@ -53,7 +53,7 @@ def setUp(self): self.upper_0 = 0.25 self.upper_1 = 0.33 - self.places = [fluid.CUDAPlace(0)] + self.places = [fluid.CPUPlace()] if core.is_compiled_with_cuda(): self.places.append(fluid.CUDAPlace(0)) From f7cf53bd51406ac169b154a45d47a85d8e9f6144 Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Wed, 27 Apr 2022 08:57:40 +0800 Subject: [PATCH 07/16] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=8B=BC=E5=86=99?= =?UTF-8?q?=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/fluid/operators/rrelu_op.cc | 8 ++++---- python/paddle/nn/__init__.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/rrelu_op.cc b/paddle/fluid/operators/rrelu_op.cc index f22a07452d4a6..26f1ce84da28a 100644 --- a/paddle/fluid/operators/rrelu_op.cc +++ b/paddle/fluid/operators/rrelu_op.cc @@ -38,9 +38,9 @@ class RReluOp : public framework::OperatorWithKernel { class RReluOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("X", "The input of RRelu op."); - AddOutput("Out", "The output of RRelu op."); - AddOutput("Noise", "The random sampled RRelu noise.") + AddInput("X", "The input of RReLU op."); + AddOutput("Out", "The output of RReLU op."); + AddOutput("Noise", "The random sampled RReLU noise.") .AsIntermediate() .AsExtra(); AddAttr("is_test", @@ -74,7 +74,7 @@ class RReluOpMaker : public framework::OpProtoAndCheckerMaker { "'RRelu_upper' must be between 0.0 and 1.0.")); }); AddComment(R"DOC( -RRelu Operator. +RReLU Operator. 
Applies the randomized leaky rectified liner unit function, element-wise, as described in the paper: diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 45b9e4711cca0..b4be291b0697f 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -314,5 +314,5 @@ def weight_norm(*args): 'MaxUnPool3D', 'HingeEmbeddingLoss', 'Identity', - 'RRelu', + 'RReLU', ] From a0cd8222275d0f3adf13cbd91c622336c00dfa5a Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Wed, 27 Apr 2022 10:51:25 +0800 Subject: [PATCH 08/16] =?UTF-8?q?=E4=BF=AE=E6=94=B9code=20format?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/kernels/rrelu_grad_kernel.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/kernels/rrelu_grad_kernel.h b/paddle/phi/kernels/rrelu_grad_kernel.h index dbb8da874e27e..b6172fca10e53 100644 --- a/paddle/phi/kernels/rrelu_grad_kernel.h +++ b/paddle/phi/kernels/rrelu_grad_kernel.h @@ -25,4 +25,4 @@ void RReluGradKernel(const Context& dev_ctx, const DenseTensor& noise, const DenseTensor& out_grad, DenseTensor* x_grad); -} // namespace phi \ No newline at end of file +} // namespace phi From e5cf5478ec2e0849d9cfc6a821cce189197840c5 Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Fri, 29 Apr 2022 07:22:32 +0000 Subject: [PATCH 09/16] =?UTF-8?q?=E5=B0=9D=E8=AF=95=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8Btimeout=E7=9A=84=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/paddle/fluid/tests/unittests/test_rrelu_op.py | 9 ++++----- python/paddle/nn/functional/activation.py | 1 + python/paddle/nn/layer/activation.py | 1 + 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_rrelu_op.py b/python/paddle/fluid/tests/unittests/test_rrelu_op.py index d8cd29e25c6ae..4f4d0d7cc3aa2 100644 --- a/python/paddle/fluid/tests/unittests/test_rrelu_op.py +++ b/python/paddle/fluid/tests/unittests/test_rrelu_op.py @@ -17,9 +17,7 @@ import unittest import numpy as np import paddle.fluid as fluid -import six import paddle.fluid.core as core -from paddle.fluid import Program, program_guard from op_test import OpTest import paddle import paddle.nn.functional as F @@ -53,9 +51,10 @@ def setUp(self): self.upper_0 = 0.25 self.upper_1 = 0.33 - self.places = [fluid.CUDAPlace(0)] - if core.is_compiled_with_cuda(): - self.places.append(fluid.CUDAPlace(0)) + self.places = [ + fluid.CUDAPlace(0) + if core.is_compiled_with_cuda() else fluid.CPUPlace() + ] def check_static_result(self, place): with fluid.program_guard(fluid.Program(), fluid.Program()): diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 2fd6a7e537df4..928fd47c20967 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -575,6 +575,7 @@ def rrelu(x, lower=1. / 8., upper=1. / 3., training=True, name=None): Examples: .. code-block:: python + :name: rrelu-example import paddle import paddle.nn.functional as F diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index 946e2d8445d0e..d32c403a010c1 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -465,6 +465,7 @@ class RReLU(Layer): Examples: .. 
code-block:: python + :name: RReLU-example import paddle import numpy as np From 28cd511d0e2f47953dc44008296e152cb8265d0a Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Sat, 30 Apr 2022 06:53:45 +0000 Subject: [PATCH 10/16] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/fluid/operators/rrelu_op.cc | 27 ++------- paddle/phi/infermeta/unary.cc | 9 +++ paddle/phi/infermeta/unary.h | 4 ++ paddle/phi/kernels/cpu/rrelu_kernel.cc | 10 +++- .../fluid/tests/unittests/test_rrelu_op.py | 55 +++++++++++++------ python/paddle/nn/functional/activation.py | 4 +- python/paddle/nn/layer/activation.py | 6 +- 7 files changed, 69 insertions(+), 46 deletions(-) diff --git a/paddle/fluid/operators/rrelu_op.cc b/paddle/fluid/operators/rrelu_op.cc index 26f1ce84da28a..a7b4f752ea734 100644 --- a/paddle/fluid/operators/rrelu_op.cc +++ b/paddle/fluid/operators/rrelu_op.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -102,26 +102,6 @@ where :math:`a` is randomly sampled from uniform distribution class RReluGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "rrelu"); - OP_INOUT_CHECK(ctx->HasInput("Noise"), "Input", "Noise", "rrelu"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", - framework::GradVarName("Out"), "rrelu"); - - auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); - ctx->SetOutputDim(framework::GradVarName("X"), out_dims); - ctx->ShareLoD(framework::GradVarName("Out"), - /*->*/ framework::GradVarName("X")); - } - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); - } }; template @@ -150,4 +130,7 @@ REGISTER_OPERATOR(rrelu, ops::RReluOp, ops::RReluOpMaker, ops::RReluGradOpMaker, ops::RReluGradOpMaker, RReluInferShapeFunctor); -REGISTER_OPERATOR(rrelu_grad, ops::RReluGradOp); + +DECLARE_INFER_SHAPE_FUNCTOR(rrelu_grad, RReluGradInferShapeFunctor, + PD_INFER_META(phi::RReluGradInferMeta)); +REGISTER_OPERATOR(rrelu_grad, ops::RReluGradOp, RReluGradInferShapeFunctor); diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 22711c876b1e2..504c709229806 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1957,6 +1957,15 @@ void RReluInferMeta(const MetaTensor& x, } } +void RReluGradInferMeta(const MetaTensor& out_grad, + const MetaTensor& noise, + MetaTensor* x_grad) { + auto do_dims = out_grad.dims(); + x_grad->set_dims(do_dims); + x_grad->set_dtype(out_grad.dtype()); + x_grad->share_lod(out_grad); +} + void SetValueInferMeta(const MetaTensor& x, MetaTensor* out) { auto in_dims = x.dims(); PADDLE_ENFORCE_LT( diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 852286f06c100..5d9f2ce06444d 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -282,6 +282,10 @@ void RReluInferMeta(const MetaTensor& x, 
MetaTensor* out, MetaTensor* noise); +void RReluGradInferMeta(const MetaTensor& out_grad, + const MetaTensor& noise, + MetaTensor* x_grad); + void SetValueInferMeta(const MetaTensor& x, MetaTensor* out); void ShapeInferMeta(const MetaTensor& input, MetaTensor* out); diff --git a/paddle/phi/kernels/cpu/rrelu_kernel.cc b/paddle/phi/kernels/cpu/rrelu_kernel.cc index 0d58141e2cbec..f1fdcd40da53c 100644 --- a/paddle/phi/kernels/cpu/rrelu_kernel.cc +++ b/paddle/phi/kernels/cpu/rrelu_kernel.cc @@ -38,8 +38,8 @@ void RReluKernel(const Context& dev_ctx, int i = 0; if (is_test) { + T mid_val = static_cast((lower + upper) / 2.0); for (i = 0; i < numel; i++) { - T mid_val = static_cast((lower + upper) / 2.0); if (x_ptr[i] < zero) { o_ptr[i] = mid_val * x_ptr[i]; n_ptr[i] = mid_val; @@ -71,4 +71,10 @@ void RReluKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(rrelu, CPU, ALL_LAYOUT, phi::RReluKernel, float, double) {} +PD_REGISTER_KERNEL(rrelu, + CPU, + ALL_LAYOUT, + phi::RReluKernel, + float, + phi::dtype::float16, + double) {} diff --git a/python/paddle/fluid/tests/unittests/test_rrelu_op.py b/python/paddle/fluid/tests/unittests/test_rrelu_op.py index 4f4d0d7cc3aa2..cf78f4e5e249e 100644 --- a/python/paddle/fluid/tests/unittests/test_rrelu_op.py +++ b/python/paddle/fluid/tests/unittests/test_rrelu_op.py @@ -95,6 +95,7 @@ def test_static_graph_functional(self): name="x2", shape=self.x_np.shape, dtype="float64") out_1 = F.rrelu(x_1, self.lower_0, self.upper_0, training=False) out_2 = F.rrelu(x_2, self.lower_1, self.upper_1, training=False) + out_3 = F.rrelu(x_2, self.lower_1, self.upper_1, training=True) exe = paddle.static.Executor(place=place) res_1 = exe.run(fluid.default_main_program(), @@ -105,11 +106,17 @@ def test_static_graph_functional(self): feed={"x2": self.x_np}, fetch_list=out_2, use_prune=True) + res_3 = exe.run(fluid.default_main_program(), + feed={"x2": self.x_np}, + fetch_list=out_3, + use_prune=True) out_ref_1 = ref_rrelu(self.x_np, self.lower_0, self.upper_0) out_ref_2 = ref_rrelu(self.x_np, self.lower_1, self.upper_1) self.assertEqual(np.allclose(out_ref_1, res_1), True) self.assertEqual(np.allclose(out_ref_2, res_2), True) + self.assertTrue( + check_output(self.x_np, res_3[0], self.lower_1, self.upper_1)) def test_static_graph_layer(self): '''test_static_graph_layer''' @@ -267,10 +274,14 @@ def error_lower_upper(): class RReluTest(OpTest): def setUp(self): self.op_type = "rrelu" - self.init_boundary() - self.init_dtype() - self.init_input_shape() - self.init_attr() + self.lower = 0.1 + self.upper = 0.3 + self.is_test = True + self.init_prams() + + def init_prams(self): + self.dtype = "float64" + self.x_shape = [2, 3, 4, 5] x_np = np.random.uniform(-1, 1, self.x_shape).astype(self.dtype) out_np = ref_rrelu(x_np, self.lower, self.upper) @@ -279,19 +290,11 @@ def setUp(self): self.inputs = {'X': x_np} self.outputs = {'Out': out_np, 'Noise': noise_np} - - def init_boundary(self): - self.lower = 0.1 - self.upper = 0.3 - - def init_dtype(self): - self.dtype = "float64" - - def init_input_shape(self): - self.x_shape = [2, 3, 4, 5] - - def init_attr(self): - self.attrs = {'lower': self.lower, "upper": self.upper, "is_test": True} + self.attrs = { + 'lower': self.lower, + "upper": self.upper, + "is_test": self.is_test + } def test_check_output(self): self.check_output() @@ -300,5 +303,23 @@ def test_check_grad(self): self.check_grad(['X'], 'Out') +class RReluTrainingTest(OpTest): + def setUp(self): + self.op_type = "rrelu" + self.lower = 0.3 + self.upper = 0.3000009 + 
self.is_test = False + self.init_prams() + + +class RReluTrainingTest(OpTest): + def setUp(self): + self.op_type = "rrelu" + self.lower = 0.3 + self.upper = 0.3000009 + self.is_test = False + self.init_prams() + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 58c175a0003a9..6ace7ed387676 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -590,8 +590,8 @@ def rrelu(x, lower=1. / 8., upper=1. / 3., training=True, name=None): [[ 1.0, -2.0, -3.0, 4.0], [-5.0, 6.0, 7.0, -8.0], [ 6.0, 7.0, 8.0, 9.0]]]], 'float32') - x = paddle.to_tensor(data) - out = F.rrelu(x, 0.1, 0.3) + input_tensor = paddle.to_tensor(data) + out = F.rrelu(input_tensor, 0.1, 0.3) #[[[[-0.20000899 3. -0.8810822 5. ] # [ 3. -0.55175185 5. -1.0776101 ] # [-1.0680687 -1.9896201 8. 9. ]] diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index d32c403a010c1..33b5e5b4e7fc7 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -478,9 +478,9 @@ class RReLU(Layer): [[ 1.0, -2.0, -3.0, 4.0], [-5.0, 6.0, 7.0, -8.0], [ 6.0, 7.0, 8.0, 9.0]]]], 'float64') - x = paddle.to_tensor(data) - m = paddle.nn.RReLU(0.1, 0.3) - out = m(x) + input_tensor = paddle.to_tensor(data) + rrelu_layer = paddle.nn.RReLU(0.1, 0.3) + output = rrelu_layer(input_tensor) #[[[[-0.20000899 3. -0.88108218 5. ] # [ 3. -0.55175185 5. -1.07761011] # [-1.06806871 -1.98962009 8. 9. ]] From a5fed7fecb3138b78d83f829f2932e2409936f71 Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Thu, 5 May 2022 16:46:12 +0000 Subject: [PATCH 11/16] =?UTF-8?q?=E7=A7=BB=E9=99=A4seed,=20=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E9=9A=8F=E6=9C=BA=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/fluid/operators/rrelu_op.cc | 10 -- paddle/phi/infermeta/unary.cc | 2 - paddle/phi/infermeta/unary.h | 2 - paddle/phi/kernels/cpu/rrelu_kernel.cc | 5 +- paddle/phi/kernels/gpu/rrelu_funcs.h | 106 --------------- paddle/phi/kernels/gpu/rrelu_kernel.cu | 123 +++++++++++++----- paddle/phi/kernels/rrelu_kernel.h | 2 - paddle/phi/ops/compat/rrelu_sig.cc | 6 +- .../fluid/tests/unittests/test_rrelu_op.py | 1 + python/paddle/nn/functional/activation.py | 23 +--- 10 files changed, 94 insertions(+), 186 deletions(-) delete mode 100644 paddle/phi/kernels/gpu/rrelu_funcs.h diff --git a/paddle/fluid/operators/rrelu_op.cc b/paddle/fluid/operators/rrelu_op.cc index a7b4f752ea734..c543a088e9d7f 100644 --- a/paddle/fluid/operators/rrelu_op.cc +++ b/paddle/fluid/operators/rrelu_op.cc @@ -47,16 +47,6 @@ class RReluOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, default false) Set to true for inference only, false " "for training. Some layers may run faster when this is true.") .SetDefault(false); - AddAttr("fix_seed", - "A flag indicating whether to use a fixed seed to generate " - "random mask. NOTE: DO NOT set this flag to true in " - "training. Setting this flag to true is only useful in " - "unittest or for debug that always the same output units " - "will be dropped.") - .SetDefault(false) - .AsExtra(); - AddAttr("seed", "Rrelu random seed.").SetDefault(0).AsExtra(); - float default_lower = 1. 
/ 8.; AddAttr("lower", "Lower bound of the uniform distribution.") .SetDefault(default_lower) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 504c709229806..c09a4b8f430aa 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1919,8 +1919,6 @@ void RReluInferMeta(const MetaTensor& x, float lower, float upper, bool is_test, - bool fix_seed, - int seed, MetaTensor* out, MetaTensor* noise) { auto x_dims = x.dims(); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 5d9f2ce06444d..1ae63e28c9a51 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -277,8 +277,6 @@ void RReluInferMeta(const MetaTensor& x, float lower, float upper, bool is_test, - bool fix_seed, - int seed, MetaTensor* out, MetaTensor* noise); diff --git a/paddle/phi/kernels/cpu/rrelu_kernel.cc b/paddle/phi/kernels/cpu/rrelu_kernel.cc index f1fdcd40da53c..4c6e30beddfa3 100644 --- a/paddle/phi/kernels/cpu/rrelu_kernel.cc +++ b/paddle/phi/kernels/cpu/rrelu_kernel.cc @@ -26,8 +26,6 @@ void RReluKernel(const Context& dev_ctx, const float lower, const float upper, bool is_test, - bool fix_seed, - int seed, DenseTensor* out, DenseTensor* noise) { const T* x_ptr = x.data(); @@ -52,8 +50,7 @@ void RReluKernel(const Context& dev_ctx, return; } - int seed_data = fix_seed ? seed : 0; - auto engine = paddle::framework::GetCPURandomEngine(seed_data); + auto engine = paddle::framework::GetCPURandomEngine(0); std::uniform_real_distribution dist(lower, upper); diff --git a/paddle/phi/kernels/gpu/rrelu_funcs.h b/paddle/phi/kernels/gpu/rrelu_funcs.h deleted file mode 100644 index f768a56225506..0000000000000 --- a/paddle/phi/kernels/gpu/rrelu_funcs.h +++ /dev/null @@ -1,106 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once - -#include -#include -#include -#include -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace phi { - -#define CUDA_NUM_THREADS 1024 - -inline static int PADDLE_GET_BLOCKS(const int N) { - return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; -} - -template -__global__ void RReluElementWiseKernel(const T *input, - T *output, - T *noise, - const T mid_val, - const float lower, - const float upper, - const bool is_test, - const int seed_data, - size_t numel) { - CUDA_KERNEL_LOOP(index, numel) { - T x = input[index]; - T zero = static_cast(0); - - if (is_test) { - if (x < zero) { - output[index] = mid_val * x; - noise[index] = mid_val; - } else { - output[index] = x; - noise[index] = 1.0; - } - - } else { - if (x < zero) { - thrust::minstd_rand rng; - rng.seed(seed_data); - thrust::uniform_real_distribution dist(lower, upper); - rng.discard(index); - T scale = static_cast(dist(rng)); - output[index] = scale * x; - noise[index] = scale; - } else { - output[index] = x; - noise[index] = 1.0; - } - } - } -} - -template -class RReluElementWiseDirectCUDAFunctor { - public: - void operator()(gpuStream_t stream, - const T *input, - T *output, - T *noise, - const float lower, - const float upper, - const bool is_test, - const int seed_data, - size_t numel); -}; - -template -void RReluElementWiseDirectCUDAFunctor::operator()(gpuStream_t stream, - const T *input, - T *output, - T *noise, - const float lower, - const float upper, - const bool is_test, - const int seed_data, - size_t numel) { - T mid_val = static_cast((lower + upper) / 2.0); - RReluElementWiseKernel<<>>( - input, output, noise, mid_val, lower, upper, is_test, seed_data, numel); -} - -template class RReluElementWiseDirectCUDAFunctor; -template class RReluElementWiseDirectCUDAFunctor; -template class RReluElementWiseDirectCUDAFunctor; -} // namespace phi diff --git a/paddle/phi/kernels/gpu/rrelu_kernel.cu b/paddle/phi/kernels/gpu/rrelu_kernel.cu index d41d55aab7153..39582d5872a70 100644 --- a/paddle/phi/kernels/gpu/rrelu_kernel.cu +++ b/paddle/phi/kernels/gpu/rrelu_kernel.cu @@ -1,53 +1,104 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -#include "paddle/phi/kernels/rrelu_kernel.h" +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/gpu/rrelu_funcs.h"
+#include "paddle/phi/kernels/funcs/distribution_helper.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
+#include "paddle/phi/kernels/rrelu_kernel.h"

 namespace phi {

+template <typename T>
+struct RReluTrainCudaFunctor {
+ public:
+  RReluTrainCudaFunctor(const T* in, T* out, T* noise)
+      : in_(in), out_(out), noise_(noise) {
+    zero_ = static_cast<T>(0);
+  }
+
+  __device__ void operator()(int64_t idx) {
+    T x = in_[idx];
+    if (x < zero_) {
+      out_[idx] = noise_[idx] * x;
+    } else {
+      out_[idx] = x;
+      noise_[idx] = 1.0;
+    }
+  }
+
+ private:
+  const T* in_;
+  T* out_;
+  T* noise_;
+  T zero_;
+};
+
+template <typename T>
+struct RReluTestCudaFunctor {
+ public:
+  RReluTestCudaFunctor(const T* in, T* out, T* noise, T mid_val)
+      : in_(in), out_(out), noise_(noise), mid_val_(mid_val) {
+    zero_ = static_cast<T>(0);
+  }
+
+  __device__ void operator()(int64_t idx) {
+    T x = in_[idx];
+    if (x < zero_) {
+      out_[idx] = mid_val_ * x;
+      noise_[idx] = mid_val_;
+    } else {
+      out_[idx] = x;
+      noise_[idx] = 1.0;
+    }
+  }
+
+ private:
+  const T* in_;
+  T* out_;
+  T* noise_;
+  T zero_;
+  T mid_val_;
+};
+
 template <typename T, typename Context>
-void RReluKernel(const Context& dev_ctx,
+void RReluKernel(const Context& ctx,
                  const DenseTensor& x,
                  const float lower,
                  const float upper,
                  bool is_test,
-                 bool fix_seed,
-                 int seed,
                  DenseTensor* out,
                  DenseTensor* noise) {
-  const T* x_ptr = x.data<T>();
-  T* o_ptr = dev_ctx.template Alloc<T>(out);
-  T* n_ptr = dev_ctx.template Alloc<T>(noise);
-
-  int numel = x.numel();
-  auto dim = x.dims();
-  RReluElementWiseDirectCUDAFunctor<T> rrelu_element_wise;
-
-  int seed_data = fix_seed ? seed : 0;
-  rrelu_element_wise(dev_ctx.stream(),
-                     x_ptr,
-                     o_ptr,
-                     n_ptr,
-                     lower,
-                     upper,
-                     is_test,
-                     seed_data,
-                     numel);
+  const T* x_data = x.data<T>();
+  T* out_data = ctx.template Alloc<T>(out);
+  T* noise_data = ctx.template Alloc<T>(noise);
+  auto size = x.numel();
+  if (size <= 0) return;
+
+  phi::funcs::ForRange<Context> for_range(ctx, size);
+  if (is_test) {
+    T mid_val = static_cast<T>((lower + upper) / 2.0);
+    RReluTestCudaFunctor<T> functor(x_data, out_data, noise_data, mid_val);
+    for_range(functor);
+  } else {
+    using MT = typename kps::details::MPTypeTrait<T>::Type;
+    funcs::uniform_distribution<MT> dist;
+    funcs::uniform_real_transform<MT> trans(lower, upper);
+    funcs::distribution_and_transform<T>(ctx, noise, dist, trans);
+    RReluTrainCudaFunctor<T> functor(x_data, out_data, noise_data);
+    for_range(functor);
+  }
 }

 }  // namespace phi
diff --git a/paddle/phi/kernels/rrelu_kernel.h b/paddle/phi/kernels/rrelu_kernel.h
index 50807ffc93f6b..8deb52daaae13 100644
--- a/paddle/phi/kernels/rrelu_kernel.h
+++ b/paddle/phi/kernels/rrelu_kernel.h
@@ -24,8 +24,6 @@ void RReluKernel(const Context& dev_ctx,
                  const float lower,
                  const float upper,
                  bool is_test,
-                 bool fix_seed,
-                 int seed,
                  DenseTensor* out,
                  DenseTensor* noise);
 }  // namespace phi
diff --git a/paddle/phi/ops/compat/rrelu_sig.cc b/paddle/phi/ops/compat/rrelu_sig.cc
index 5f669e6210bf6..00cd705a24076 100644
--- a/paddle/phi/ops/compat/rrelu_sig.cc
+++ b/paddle/phi/ops/compat/rrelu_sig.cc
@@ -17,10 +17,8 @@ namespace phi {

 KernelSignature RReluOpArgumentMapping(const ArgumentMappingContext& ctx) {
-  return KernelSignature("rrelu",
-                         {"X"},
-                         {"lower", "upper", "is_test", "fix_seed", "seed"},
-                         {"Out", "Noise"});
+  return KernelSignature(
+      "rrelu", {"X"}, {"lower", "upper", "is_test"}, {"Out", "Noise"});
 }

 KernelSignature RReluGradGradOpArgumentMapping(
diff --git 
a/python/paddle/fluid/tests/unittests/test_rrelu_op.py b/python/paddle/fluid/tests/unittests/test_rrelu_op.py index cf78f4e5e249e..9d33ce085b7f7 100644 --- a/python/paddle/fluid/tests/unittests/test_rrelu_op.py +++ b/python/paddle/fluid/tests/unittests/test_rrelu_op.py @@ -187,6 +187,7 @@ def test_dygraph(self): paddle.enable_static() def test_error_functional(self): + paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): # The input type must be Variable. self.assertRaises( diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 6ace7ed387676..3234ce44645bc 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -22,7 +22,7 @@ import warnings from ...fluid.layer_helper import LayerHelper -from ...fluid.framework import convert_np_dtype_to_dtype_, default_main_program +from ...fluid.framework import convert_np_dtype_to_dtype_ from ...fluid.framework import _in_legacy_dygraph, in_dygraph_mode, _non_static_mode from ...fluid.data_feeder import check_variable_and_dtype, check_dtype import paddle @@ -625,33 +625,16 @@ def rrelu(x, lower=1. / 8., upper=1. / 3., training=True, name=None): upper)) is_test = not training - seed = None if _in_legacy_dygraph(): - if default_main_program().random_seed != 0: - seed = default_main_program().random_seed out, noise = _C_ops.rrelu(x, 'lower', lower, 'upper', upper, 'is_test', - is_test, 'fix_seed', seed is not None, 'seed', - seed if seed is not None else 0) + is_test) return out - def get_attrs(prog, lower, upper, is_test, seed): - if (seed is None or seed == 0) and prog.random_seed != 0: - seed = prog.random_seed - attrs = { - 'lower': lower, - 'upper': upper, - 'is_test': is_test, - 'fix_seed': seed is not None, - 'seed': seed if seed is not None else 0, - } - return attrs - helper = LayerHelper('rrelu', **locals()) out = helper.create_variable_for_type_inference(x.dtype) noise = helper.create_variable_for_type_inference(dtype=x.dtype) - attrs = get_attrs(helper.main_program, lower, upper, is_test, seed) - + attrs = {'lower': lower, 'upper': upper, 'is_test': is_test} helper.append_op( type='rrelu', inputs={"X": x}, From 916b5b8fc23dc5e73060fcbcb7215279aa0d7ac4 Mon Sep 17 00:00:00 2001 From: thunder95 <290844930@qq.com> Date: Fri, 6 May 2022 14:06:48 +0000 Subject: [PATCH 12/16] update en doc for rrelu --- python/paddle/nn/functional/activation.py | 39 +++++++++++------------ python/paddle/nn/layer/activation.py | 35 ++++++++++---------- 2 files changed, 35 insertions(+), 39 deletions(-) diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 3234ce44645bc..173703587b37c 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -552,24 +552,24 @@ def rrelu(x, lower=1. / 8., upper=1. / 3., training=True, name=None): """ rrelu activation. - .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`: - https://arxiv.org/abs/1505.00853 + `Empirical Evaluation of Rectified Activations in Convolutional Network`: https://arxiv.org/abs/1505.00853 .. math:: + \text{RReLU}(x) = - \begin{cases} - x & \text{if } x \geq 0 \\ - ax & \text{ otherwise } - \end{cases} + \begin{cases} + x & \text{if } x \geq 0 \\ + ax & \text{ otherwise } + \end{cases} - where :math:`a` is randomly sampled from uniform distribution - :math:`\mathcal{U}(\text{lower}, \text{upper})`. 
+    where :math:`x` is the input tensor,
+    :math:`a` is randomly sampled from uniform distribution in range (:math:`lower`, :math:`upper`),

     Parameters:
-        x (Tensor): The input Tensor with data type float 16 float32, float64.
-        lower (float, optional): The lower bound of uniform distribution. Default: :math:`\frac{1}{8}`.
-        upper (float, optional): The upper bound of uniform distribution. Default: :math:`\frac{1}{3}`.
-        training (bool, optional): Current is training mode or others. Default is True.
+        x (Tensor): The input Tensor with data type float16, float32, float64.
+        lower (float, optional): The lower bound of uniform distribution. Default: 0.125.
+        upper (float, optional): The upper bound of uniform distribution. Default: 0.333.
+        training (bool, optional): Whether it is in training mode. Default is True.
         name (str, optional): Name for the operation (optional, default is None).
             For more information, please refer to :ref:`api_guide_Name`.

@@ -582,15 +582,14 @@ def rrelu(x, lower=1. / 8., upper=1. / 3., training=True, name=None):

             import paddle
             import paddle.nn.functional as F
-            import numpy as np

-            data = np.array([[[[-2.0,  3.0, -4.0,  5.0],
-                               [ 3.0, -4.0,  5.0, -6.0],
-                               [-7.0, -8.0,  8.0,  9.0]],
-                              [[ 1.0, -2.0, -3.0,  4.0],
-                               [-5.0,  6.0,  7.0, -8.0],
-                               [ 6.0,  7.0,  8.0,  9.0]]]], 'float32')
-            input_tensor = paddle.to_tensor(data)
+            input_tensor = paddle.to_tensor([[[[-2.0,  3.0, -4.0,  5.0],
+                                               [ 3.0, -4.0,  5.0, -6.0],
+                                               [-7.0, -8.0,  8.0,  9.0]],
+                                              [[ 1.0, -2.0, -3.0,  4.0],
+                                               [-5.0,  6.0,  7.0, -8.0],
+                                               [ 6.0,  7.0,  8.0,  9.0]]]], dtype='float32')
+
             out = F.rrelu(input_tensor, 0.1, 0.3)
             #[[[[-0.20000899  3.         -0.8810822   5.        ]
             #   [ 3.         -0.55175185  5.         -1.0776101 ]
diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py
index 33b5e5b4e7fc7..18006da3538e5 100644
--- a/python/paddle/nn/layer/activation.py
+++ b/python/paddle/nn/layer/activation.py
@@ -440,22 +440,22 @@ class RReLU(Layer):
     """
     rrelu activation.

-    `Empirical Evaluation of Rectified Activations in Convolutional Network`:
-    https://arxiv.org/abs/1505.00853
+    `Empirical Evaluation of Rectified Activations in Convolutional Network`: https://arxiv.org/abs/1505.00853

     .. math::
+
         \text{RReLU}(x) =
-            \begin{cases}
-                x & \text{if } x \geq 0 \\
-                ax & \text{ otherwise }
-            \end{cases}
+        \begin{cases}
+            x & \text{if } x \geq 0 \\
+            ax & \text{ otherwise }
+        \end{cases}

-    where :math:`a` is randomly sampled from uniform distribution
-    :math:`\mathcal{U}(\text{lower}, \text{upper})`.
+    where :math:`x` is the input tensor,
+    :math:`a` is randomly sampled from uniform distribution in range (:math:`lower`, :math:`upper`),

     Parameters:
-        lower (float, optional): The lower bound of uniform distribution. Default: :math:`\frac{1}{8}`.
-        upper (float, optional): The upper bound of uniform distribution. Default: :math:`\frac{1}{3}`.
+        lower (float, optional): The lower bound of uniform distribution. Default: 0.125.
+        upper (float, optional): The upper bound of uniform distribution. Default: 0.333.
         name (str, optional): Name for the operation (optional, default is None).
             For more information, please refer to :ref:`api_guide_Name`. 
@@ -468,17 +468,14 @@ class RReLU(Layer):
             :name: RReLU-example

             import paddle
-            import numpy as np
-            paddle.set_default_dtype("float64")
+            input_tensor = paddle.to_tensor([[[[-2.0,  3.0, -4.0,  5.0],
+                                               [ 3.0, -4.0,  5.0, -6.0],
+                                               [-7.0, -8.0,  8.0,  9.0]],
+                                              [[ 1.0, -2.0, -3.0,  4.0],
+                                               [-5.0,  6.0,  7.0, -8.0],
+                                               [ 6.0,  7.0,  8.0,  9.0]]]], dtype='float32')
-            data = np.array([[[[-2.0,  3.0, -4.0,  5.0],
-                               [ 3.0, -4.0,  5.0, -6.0],
-                               [-7.0, -8.0,  8.0,  9.0]],
-                              [[ 1.0, -2.0, -3.0,  4.0],
-                               [-5.0,  6.0,  7.0, -8.0],
-                               [ 6.0,  7.0,  8.0,  9.0]]]], 'float64')
-            input_tensor = paddle.to_tensor(data)
             rrelu_layer = paddle.nn.RReLU(0.1, 0.3)
             output = rrelu_layer(input_tensor)
             #[[[[-0.20000899  3.         -0.88108218  5.        ]

From 726a5d5ce154770e009c163028ee797b8d6e91f8 Mon Sep 17 00:00:00 2001
From: thunder95 <290844930@qq.com>
Date: Tue, 10 May 2022 08:39:45 +0000
Subject: [PATCH 13/16] fix rrelu en docs, test=document_fix

---
 python/paddle/nn/functional/activation.py | 32 ++++++++++++++++++-----
 1 file changed, 26 insertions(+), 6 deletions(-)

diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py
index 173703587b37c..dd9ab69fdf824 100644
--- a/python/paddle/nn/functional/activation.py
+++ b/python/paddle/nn/functional/activation.py
@@ -552,19 +552,39 @@ def rrelu(x, lower=1. / 8., upper=1. / 3., training=True, name=None):
     """
     rrelu activation.

-    `Empirical Evaluation of Rectified Activations in Convolutional Network`: https://arxiv.org/abs/1505.00853
+    Applies the randomized leaky rectified linear unit function, as described in the paper:
+    `Empirical Evaluation of Rectified Activations in Convolutional Network <https://arxiv.org/abs/1505.00853>`
+
+    During training, randomly samples the negative slope for activation values as described below:

     .. math::

-        \text{RReLU}(x) =
-        \begin{cases}
-            x & \text{if } x \geq 0 \\
-            ax & \text{ otherwise }
-        \end{cases}
+        rrelu(x)=
+            \left\{
+                \begin{array}{rcl}
+                    x, & & if \ x >= 0 \\
+                    a * x, & & otherwise \\
+                \end{array}
+            \right.

     where :math:`x` is the input tensor,
     :math:`a` is randomly sampled from uniform distribution in range (:math:`lower`, :math:`upper`),

+    In the test phase, the negative slope will take the average value of :math:`lower` and :math:`upper`:
+
+    .. math::
+
+        rrelu(x)=
+            \left\{
+                \begin{array}{rcl}
+                    x, & & if \ x >= 0 \\
+                    (lower + upper) * 0.5 * x, & & otherwise \\
+                \end{array}
+            \right.
+
+    where :math:`x` is the input tensor,
+    :math:`lower` and :math:`upper` are the bounds of uniform distribution.
+
     Parameters:
         x (Tensor): The input Tensor with data type float16, float32, float64.
         lower (float, optional): The lower bound of uniform distribution. Default: 0.125.
         upper (float, optional): The upper bound of uniform distribution. Default: 0.333.

From 4502a018ca21dde65a2099050d5694c9e3781abe Mon Sep 17 00:00:00 2001
From: thunder95 <290844930@qq.com>
Date: Tue, 10 May 2022 09:38:19 +0000
Subject: [PATCH 14/16] add paper link for en docs, test=document_fix

---
 python/paddle/nn/functional/activation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py
index dd9ab69fdf824..bc691ea72d5ca 100644
--- a/python/paddle/nn/functional/activation.py
+++ b/python/paddle/nn/functional/activation.py
@@ -553,7 +553,7 @@ def rrelu(x, lower=1. / 8., upper=1. / 3., training=True, name=None):
     rrelu activation. 
    Applies the randomized leaky rectified linear unit function, as described in the paper:
-    `Empirical Evaluation of Rectified Activations in Convolutional Network <https://arxiv.org/abs/1505.00853>`
+    `Empirical Evaluation of Rectified Activations in Convolutional Network <https://arxiv.org/abs/1505.00853>`_

     During training, randomly samples the negative slope for activation values as described below:

From 2132bf55f45ec8499c784f5e9f19f436769fc066 Mon Sep 17 00:00:00 2001
From: thunder95 <290844930@qq.com>
Date: Sat, 14 May 2022 08:31:13 +0000
Subject: [PATCH 15/16] update en doc

---
 python/paddle/nn/functional/activation.py |  3 +-
 python/paddle/nn/layer/activation.py      | 35 ++++++++++++++++++-----
 2 files changed, 30 insertions(+), 8 deletions(-)

diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py
index bc691ea72d5ca..a37ed6f91edf2 100644
--- a/python/paddle/nn/functional/activation.py
+++ b/python/paddle/nn/functional/activation.py
@@ -552,7 +552,8 @@ def rrelu(x, lower=1. / 8., upper=1. / 3., training=True, name=None):
     """
     rrelu activation.

-    Applies the randomized leaky rectified linear unit function, as described in the paper:
+    Applies the randomized leaky rectified linear unit function to improve generalization performance,
+    as described in the paper:
     `Empirical Evaluation of Rectified Activations in Convolutional Network <https://arxiv.org/abs/1505.00853>`_

     During training, randomly samples the negative slope for activation values as described below:
diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py
index 18006da3538e5..65ddf3184b33d 100644
--- a/python/paddle/nn/layer/activation.py
+++ b/python/paddle/nn/layer/activation.py
@@ -438,21 +438,42 @@ def extra_repr(self):

 class RReLU(Layer):
     """
-    rrelu activation.
+    RReLU activation layer.

-    `Empirical Evaluation of Rectified Activations in Convolutional Network`: https://arxiv.org/abs/1505.00853
+    Applies the randomized leaky rectified linear unit function to improve generalization performance,
+    as described in the paper:
+    `Empirical Evaluation of Rectified Activations in Convolutional Network <https://arxiv.org/abs/1505.00853>`_
+
+    During training, randomly samples the negative slope for activation values as described below:

     .. math::

-        \text{RReLU}(x) =
-        \begin{cases}
-            x & \text{if } x \geq 0 \\
-            ax & \text{ otherwise }
-        \end{cases}
+        RReLU(x)=
+            \left\{
+                \begin{array}{rcl}
+                    x, & & if \ x >= 0 \\
+                    a * x, & & otherwise \\
+                \end{array}
+            \right.

     where :math:`x` is the input tensor,
     :math:`a` is randomly sampled from uniform distribution in range (:math:`lower`, :math:`upper`),

+    In the test phase, the negative slope will take the average value of :math:`lower` and :math:`upper`:
+
+    .. math::
+
+        RReLU(x)=
+            \left\{
+                \begin{array}{rcl}
+                    x, & & if \ x >= 0 \\
+                    (lower + upper) * 0.5 * x, & & otherwise \\
+                \end{array}
+            \right.
+
+    where :math:`x` is the input tensor,
+    :math:`lower` and :math:`upper` are the bounds of uniform distribution.
+
     Parameters:
         lower (float, optional): The lower bound of uniform distribution. Default: 0.125.
         upper (float, optional): The upper bound of uniform distribution. Default: 0.333. 
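The two docstrings above state the train and test behavior in words and math; the following is a minimal Python sketch of the same semantics (assuming a Paddle build that already ships this rrelu op, with made-up input values used purely for illustration):

    import paddle
    import paddle.nn.functional as F

    x = paddle.to_tensor([[-1.0, 2.0], [-3.0, 4.0]])

    # Training mode: each negative element gets its own slope sampled from
    # U(lower, upper), so repeated calls generally produce different outputs.
    y_train = F.rrelu(x, lower=1. / 8., upper=1. / 3., training=True)

    # Test mode: the slope is the deterministic mid value
    # (lower + upper) / 2 = (0.125 + 0.333...) / 2, roughly 0.229,
    # so every negative element is scaled by that constant.
    y_test = F.rrelu(x, lower=1. / 8., upper=1. / 3., training=False)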
From 8f626c2713df285697dc2608f29e456bbfc8bb4b Mon Sep 17 00:00:00 2001
From: thunder95 <290844930@qq.com>
Date: Mon, 23 May 2022 03:28:21 +0000
Subject: [PATCH 16/16] add r,test=document_fix

---
 python/paddle/nn/functional/activation.py | 2 +-
 python/paddle/nn/layer/activation.py      | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py
index a37ed6f91edf2..7af7ae8ddd7ab 100644
--- a/python/paddle/nn/functional/activation.py
+++ b/python/paddle/nn/functional/activation.py
@@ -549,7 +549,7 @@ def prelu(x, weight, data_format="NCHW", name=None):


 def rrelu(x, lower=1. / 8., upper=1. / 3., training=True, name=None):
-    """
+    r"""
     rrelu activation.

     Applies the randomized leaky rectified linear unit function to improve generalization performance,
diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py
index 32c77104f53fb..1a3768e919042 100644
--- a/python/paddle/nn/layer/activation.py
+++ b/python/paddle/nn/layer/activation.py
@@ -437,7 +437,7 @@ def extra_repr(self):


 class RReLU(Layer):
-    """
+    r"""
     RReLU activation layer.

     Applies the randomized leaky rectified linear unit function to improve generalization performance,
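The new GPU kernel in this series fills the noise tensor with uniform samples first and then overwrites the entries belonging to non-negative inputs with 1. A plain NumPy restatement of that behavior (a hypothetical reference helper written for illustration, not code taken from test_rrelu_op.py) might look like:

    import numpy as np

    def rrelu_reference(x, lower=1. / 8., upper=1. / 3., training=True):
        x = np.asarray(x, dtype='float64')
        if training:
            # One slope per element, drawn from U(lower, upper).
            noise = np.random.uniform(lower, upper, size=x.shape)
        else:
            # Test phase: the deterministic mid value (lower + upper) / 2.
            noise = np.full(x.shape, (lower + upper) / 2.0)
        # Non-negative inputs pass through unchanged, i.e. slope 1.
        noise = np.where(x >= 0, 1.0, noise)
        return x * noise, noise

    out, noise = rrelu_reference([[-2.0, 3.0], [-4.0, 5.0]], training=False)
    # out is approximately [[-0.4583, 3.0], [-0.9167, 5.0]], using slope ~0.2292.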