From a5fed7fecb3138b78d83f829f2932e2409936f71 Mon Sep 17 00:00:00 2001
From: thunder95 <290844930@qq.com>
Date: Thu, 5 May 2022 16:46:12 +0000
Subject: [PATCH] Remove seed, optimize the random function
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 paddle/fluid/operators/rrelu_op.cc          |  10 --
 paddle/phi/infermeta/unary.cc               |   2 -
 paddle/phi/infermeta/unary.h                |   2 -
 paddle/phi/kernels/cpu/rrelu_kernel.cc      |   5 +-
 paddle/phi/kernels/gpu/rrelu_funcs.h        | 106 ---------------
 paddle/phi/kernels/gpu/rrelu_kernel.cu      | 123 +++++++++++++-----
 paddle/phi/kernels/rrelu_kernel.h           |   2 -
 paddle/phi/ops/compat/rrelu_sig.cc          |   6 +-
 .../fluid/tests/unittests/test_rrelu_op.py  |   1 +
 python/paddle/nn/functional/activation.py   |  23 +---
 10 files changed, 94 insertions(+), 186 deletions(-)
 delete mode 100644 paddle/phi/kernels/gpu/rrelu_funcs.h

diff --git a/paddle/fluid/operators/rrelu_op.cc b/paddle/fluid/operators/rrelu_op.cc
index a7b4f752ea734..c543a088e9d7f 100644
--- a/paddle/fluid/operators/rrelu_op.cc
+++ b/paddle/fluid/operators/rrelu_op.cc
@@ -47,16 +47,6 @@ class RReluOpMaker : public framework::OpProtoAndCheckerMaker {
           "(bool, default false) Set to true for inference only, false "
           "for training. Some layers may run faster when this is true.")
         .SetDefault(false);
-    AddAttr<bool>("fix_seed",
-                  "A flag indicating whether to use a fixed seed to generate "
-                  "random mask. NOTE: DO NOT set this flag to true in "
-                  "training. Setting this flag to true is only useful in "
-                  "unittest or for debug that always the same output units "
-                  "will be dropped.")
-        .SetDefault(false)
-        .AsExtra();
-    AddAttr<int>("seed", "Rrelu random seed.").SetDefault(0).AsExtra();
-
     float default_lower = 1. / 8.;
     AddAttr<float>("lower", "Lower bound of the uniform distribution.")
         .SetDefault(default_lower)
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 504c709229806..c09a4b8f430aa 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -1919,8 +1919,6 @@ void RReluInferMeta(const MetaTensor& x,
                     float lower,
                     float upper,
                     bool is_test,
-                    bool fix_seed,
-                    int seed,
                     MetaTensor* out,
                     MetaTensor* noise) {
   auto x_dims = x.dims();
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index 5d9f2ce06444d..1ae63e28c9a51 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -277,8 +277,6 @@ void RReluInferMeta(const MetaTensor& x,
                     float lower,
                     float upper,
                     bool is_test,
-                    bool fix_seed,
-                    int seed,
                     MetaTensor* out,
                     MetaTensor* noise);
 
diff --git a/paddle/phi/kernels/cpu/rrelu_kernel.cc b/paddle/phi/kernels/cpu/rrelu_kernel.cc
index f1fdcd40da53c..4c6e30beddfa3 100644
--- a/paddle/phi/kernels/cpu/rrelu_kernel.cc
+++ b/paddle/phi/kernels/cpu/rrelu_kernel.cc
@@ -26,8 +26,6 @@ void RReluKernel(const Context& dev_ctx,
                  const float lower,
                  const float upper,
                  bool is_test,
-                 bool fix_seed,
-                 int seed,
                  DenseTensor* out,
                  DenseTensor* noise) {
   const T* x_ptr = x.data<T>();
@@ -52,8 +50,7 @@ void RReluKernel(const Context& dev_ctx,
     return;
   }
 
-  int seed_data = fix_seed ? seed : 0;
-  auto engine = paddle::framework::GetCPURandomEngine(seed_data);
+  auto engine = paddle::framework::GetCPURandomEngine(0);
 
   std::uniform_real_distribution<float> dist(lower, upper);
 
diff --git a/paddle/phi/kernels/gpu/rrelu_funcs.h b/paddle/phi/kernels/gpu/rrelu_funcs.h
deleted file mode 100644
index f768a56225506..0000000000000
--- a/paddle/phi/kernels/gpu/rrelu_funcs.h
+++ /dev/null
@@ -1,106 +0,0 @@
-/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include <thrust/device_vector.h>
-#include <thrust/host_vector.h>
-#include <thrust/random.h>
-#include "paddle/phi/kernels/funcs/math_function.h"
-
-namespace phi {
-
-#define CUDA_NUM_THREADS 1024
-
-inline static int PADDLE_GET_BLOCKS(const int N) {
-  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
-}
-
-template <typename T>
-__global__ void RReluElementWiseKernel(const T *input,
-                                       T *output,
-                                       T *noise,
-                                       const T mid_val,
-                                       const float lower,
-                                       const float upper,
-                                       const bool is_test,
-                                       const int seed_data,
-                                       size_t numel) {
-  CUDA_KERNEL_LOOP(index, numel) {
-    T x = input[index];
-    T zero = static_cast<T>(0);
-
-    if (is_test) {
-      if (x < zero) {
-        output[index] = mid_val * x;
-        noise[index] = mid_val;
-      } else {
-        output[index] = x;
-        noise[index] = 1.0;
-      }
-
-    } else {
-      if (x < zero) {
-        thrust::minstd_rand rng;
-        rng.seed(seed_data);
-        thrust::uniform_real_distribution<float> dist(lower, upper);
-        rng.discard(index);
-        T scale = static_cast<T>(dist(rng));
-        output[index] = scale * x;
-        noise[index] = scale;
-      } else {
-        output[index] = x;
-        noise[index] = 1.0;
-      }
-    }
-  }
-}
-
-template <typename T>
-class RReluElementWiseDirectCUDAFunctor {
- public:
-  void operator()(gpuStream_t stream,
-                  const T *input,
-                  T *output,
-                  T *noise,
-                  const float lower,
-                  const float upper,
-                  const bool is_test,
-                  const int seed_data,
-                  size_t numel);
-};
-
-template <typename T>
-void RReluElementWiseDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
-                                                      const T *input,
-                                                      T *output,
-                                                      T *noise,
-                                                      const float lower,
-                                                      const float upper,
-                                                      const bool is_test,
-                                                      const int seed_data,
-                                                      size_t numel) {
-  T mid_val = static_cast<T>((lower + upper) / 2.0);
-  RReluElementWiseKernel<T><<<PADDLE_GET_BLOCKS(numel),
-                              CUDA_NUM_THREADS,
-                              0,
-                              stream>>>(
-      input, output, noise, mid_val, lower, upper, is_test, seed_data, numel);
-}
-
-template class RReluElementWiseDirectCUDAFunctor<float>;
-template class RReluElementWiseDirectCUDAFunctor<double>;
-template class RReluElementWiseDirectCUDAFunctor<phi::dtype::float16>;
-}  // namespace phi
diff --git a/paddle/phi/kernels/gpu/rrelu_kernel.cu b/paddle/phi/kernels/gpu/rrelu_kernel.cu
index d41d55aab7153..39582d5872a70 100644
--- a/paddle/phi/kernels/gpu/rrelu_kernel.cu
+++ b/paddle/phi/kernels/gpu/rrelu_kernel.cu
@@ -1,53 +1,104 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 
-#include "paddle/phi/kernels/rrelu_kernel.h"
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/gpu/rrelu_funcs.h"
+#include "paddle/phi/kernels/funcs/distribution_helper.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
+#include "paddle/phi/kernels/rrelu_kernel.h"
 
 namespace phi {
 
+template <typename T>
+struct RReluTrainCudaFunctor {
+ public:
+  RReluTrainCudaFunctor(const T* in, T* out, T* noise)
+      : in_(in), out_(out), noise_(noise) {
+    zero_ = static_cast<T>(0);
+  }
+
+  __device__ void operator()(int64_t idx) {
+    T x = in_[idx];
+    if (x < zero_) {
+      out_[idx] = noise_[idx] * x;
+    } else {
+      out_[idx] = x;
+      noise_[idx] = 1.0;
+    }
+  }
+
+ private:
+  const T* in_;
+  T* out_;
+  T* noise_;
+  T zero_;
+};
+
+template <typename T>
+struct RReluTestCudaFunctor {
+ public:
+  RReluTestCudaFunctor(const T* in, T* out, T* noise, T mid_val)
+      : in_(in), out_(out), noise_(noise), mid_val_(mid_val) {
+    zero_ = static_cast<T>(0);
+  }
+
+  __device__ void operator()(int64_t idx) {
+    T x = in_[idx];
+    if (x < zero_) {
+      out_[idx] = mid_val_ * x;
+      noise_[idx] = mid_val_;
+    } else {
+      out_[idx] = x;
+      noise_[idx] = 1.0;
+    }
+  }
+
+ private:
+  const T* in_;
+  T* out_;
+  T* noise_;
+  T zero_;
+  T mid_val_;
+};
+
 template <typename T, typename Context>
-void RReluKernel(const Context& dev_ctx,
+void RReluKernel(const Context& ctx,
                  const DenseTensor& x,
                  const float lower,
                  const float upper,
                  bool is_test,
-                 bool fix_seed,
-                 int seed,
                  DenseTensor* out,
                  DenseTensor* noise) {
-  const T* x_ptr = x.data<T>();
-  T* o_ptr = dev_ctx.template Alloc<T>(out);
-  T* n_ptr = dev_ctx.template Alloc<T>(noise);
-
-  int numel = x.numel();
-  auto dim = x.dims();
-  RReluElementWiseDirectCUDAFunctor<T> rrelu_element_wise;
-
-  int seed_data = fix_seed ? seed : 0;
-  rrelu_element_wise(dev_ctx.stream(),
-                     x_ptr,
-                     o_ptr,
-                     n_ptr,
-                     lower,
-                     upper,
-                     is_test,
-                     seed_data,
-                     numel);
+  const T* x_data = x.data<T>();
+  T* out_data = ctx.template Alloc<T>(out);
+  T* noise_data = ctx.template Alloc<T>(noise);
+  auto size = x.numel();
+  if (size <= 0) return;
+
+  phi::funcs::ForRange<Context> for_range(ctx, size);
+  if (is_test) {
+    T mid_val = static_cast<T>((lower + upper) / 2.0);
+    RReluTestCudaFunctor<T> functor(x_data, out_data, noise_data, mid_val);
+    for_range(functor);
+  } else {
+    using MT = typename kps::details::MPTypeTrait<T>::Type;
+    funcs::uniform_distribution<MT> dist;
+    funcs::uniform_real_transform<MT> trans(lower, upper);
+    funcs::distribution_and_transform<T>(ctx, noise, dist, trans);
+    RReluTrainCudaFunctor<T> functor(x_data, out_data, noise_data);
+    for_range(functor);
+  }
 }
 }  // namespace phi
diff --git a/paddle/phi/kernels/rrelu_kernel.h b/paddle/phi/kernels/rrelu_kernel.h
index 50807ffc93f6b..8deb52daaae13 100644
--- a/paddle/phi/kernels/rrelu_kernel.h
+++ b/paddle/phi/kernels/rrelu_kernel.h
@@ -24,8 +24,6 @@ void RReluKernel(const Context& dev_ctx,
                  const float lower,
                  const float upper,
                  bool is_test,
-                 bool fix_seed,
-                 int seed,
                  DenseTensor* out,
                  DenseTensor* noise);
 }  // namespace phi
diff --git a/paddle/phi/ops/compat/rrelu_sig.cc b/paddle/phi/ops/compat/rrelu_sig.cc
index 5f669e6210bf6..00cd705a24076 100644
--- a/paddle/phi/ops/compat/rrelu_sig.cc
+++ b/paddle/phi/ops/compat/rrelu_sig.cc
@@ -17,10 +17,8 @@ namespace phi {
 
 KernelSignature RReluOpArgumentMapping(const ArgumentMappingContext& ctx) {
-  return KernelSignature("rrelu",
-                         {"X"},
-                         {"lower", "upper", "is_test", "fix_seed", "seed"},
-                         {"Out", "Noise"});
+  return KernelSignature(
+      "rrelu", {"X"}, {"lower", "upper", "is_test"}, {"Out", "Noise"});
 }
 
 KernelSignature RReluGradGradOpArgumentMapping(
diff --git a/python/paddle/fluid/tests/unittests/test_rrelu_op.py b/python/paddle/fluid/tests/unittests/test_rrelu_op.py
index cf78f4e5e249e..9d33ce085b7f7 100644
--- a/python/paddle/fluid/tests/unittests/test_rrelu_op.py
+++ b/python/paddle/fluid/tests/unittests/test_rrelu_op.py
@@ -187,6 +187,7 @@ def test_dygraph(self):
         paddle.enable_static()
 
     def test_error_functional(self):
+        paddle.enable_static()
         with paddle.static.program_guard(paddle.static.Program()):
             # The input type must be Variable.
             self.assertRaises(
diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py
index 6ace7ed387676..3234ce44645bc 100644
--- a/python/paddle/nn/functional/activation.py
+++ b/python/paddle/nn/functional/activation.py
@@ -22,7 +22,7 @@
 import warnings
 
 from ...fluid.layer_helper import LayerHelper
-from ...fluid.framework import convert_np_dtype_to_dtype_, default_main_program
+from ...fluid.framework import convert_np_dtype_to_dtype_
 from ...fluid.framework import _in_legacy_dygraph, in_dygraph_mode, _non_static_mode
 from ...fluid.data_feeder import check_variable_and_dtype, check_dtype
 import paddle
@@ -625,33 +625,16 @@ def rrelu(x, lower=1. / 8., upper=1. / 3., training=True, name=None):
                              upper))
 
     is_test = not training
-    seed = None
 
     if _in_legacy_dygraph():
-        if default_main_program().random_seed != 0:
-            seed = default_main_program().random_seed
         out, noise = _C_ops.rrelu(x, 'lower', lower, 'upper', upper, 'is_test',
-                                  is_test, 'fix_seed', seed is not None, 'seed',
-                                  seed if seed is not None else 0)
+                                  is_test)
         return out
 
-    def get_attrs(prog, lower, upper, is_test, seed):
-        if (seed is None or seed == 0) and prog.random_seed != 0:
-            seed = prog.random_seed
-        attrs = {
-            'lower': lower,
-            'upper': upper,
-            'is_test': is_test,
-            'fix_seed': seed is not None,
-            'seed': seed if seed is not None else 0,
-        }
-        return attrs
-
     helper = LayerHelper('rrelu', **locals())
     out = helper.create_variable_for_type_inference(x.dtype)
     noise = helper.create_variable_for_type_inference(dtype=x.dtype)
-    attrs = get_attrs(helper.main_program, lower, upper, is_test, seed)
-
+    attrs = {'lower': lower, 'upper': upper, 'is_test': is_test}
    helper.append_op(
         type='rrelu',
         inputs={"X": x},
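
A minimal usage sketch of the seed-free Python API this patch leaves behind. The
input tensor values below are illustrative assumptions, not taken from the patch;
only the rrelu signature and the lower/upper defaults come from the diff above.

    # Sketch: rrelu no longer reads 'fix_seed'/'seed' attributes; randomness
    # comes from the framework's global random generator.
    import paddle
    import paddle.nn.functional as F

    # Illustrative input (hypothetical values).
    x = paddle.to_tensor([[-1.0, 2.0], [3.0, -4.0]], dtype='float32')

    # Training: each negative element is scaled by its own sample drawn from
    # U(lower, upper); on GPU the noise tensor is pre-filled via
    # distribution_and_transform, then applied by RReluTrainCudaFunctor.
    out_train = F.rrelu(x, lower=1. / 8., upper=1. / 3., training=True)

    # Inference: the deterministic midpoint (lower + upper) / 2 is applied,
    # matching RReluTestCudaFunctor's mid_val.
    out_eval = F.rrelu(x, lower=1. / 8., upper=1. / 3., training=False)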