Remove seed attributes, optimize the random function
thunder95 committed May 5, 2022
1 parent 28cd511 commit a5fed7f
Showing 10 changed files with 94 additions and 186 deletions.
10 changes: 0 additions & 10 deletions paddle/fluid/operators/rrelu_op.cc
@@ -47,16 +47,6 @@ class RReluOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddAttr<bool>("fix_seed",
"A flag indicating whether to use a fixed seed to generate "
"random mask. NOTE: DO NOT set this flag to true in "
"training. Setting this flag to true is only useful in "
"unittest or for debug that always the same output units "
"will be dropped.")
.SetDefault(false)
.AsExtra();
AddAttr<int>("seed", "Rrelu random seed.").SetDefault(0).AsExtra();

float default_lower = 1. / 8.;
AddAttr<float>("lower", "Lower bound of the uniform distribution.")
.SetDefault(default_lower)
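For context on these attributes: during training, RReLU scales each negative input by a slope drawn uniformly from [lower, upper] (defaults 1/8 and 1/3), and at inference it uses the fixed midpoint (lower + upper) / 2, which is what the rewritten GPU kernel below implements. A minimal NumPy sketch of that semantics (an illustrative reference only, not the kernel code):

import numpy as np

def rrelu_reference(x, lower=1. / 8., upper=1. / 3., training=True):
    # Negative inputs are scaled by a random slope (training) or the fixed
    # midpoint (inference); the per-element slopes are returned as `noise`.
    if training:
        noise = np.random.uniform(lower, upper, size=x.shape).astype(x.dtype)
    else:
        noise = np.full_like(x, (lower + upper) / 2.0)
    noise = np.where(x >= 0, np.ones_like(x), noise)
    return x * noise, noise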
2 changes: 0 additions & 2 deletions paddle/phi/infermeta/unary.cc
@@ -1919,8 +1919,6 @@ void RReluInferMeta(const MetaTensor& x,
float lower,
float upper,
bool is_test,
bool fix_seed,
int seed,
MetaTensor* out,
MetaTensor* noise) {
auto x_dims = x.dims();
2 changes: 0 additions & 2 deletions paddle/phi/infermeta/unary.h
@@ -277,8 +277,6 @@ void RReluInferMeta(const MetaTensor& x,
float lower,
float upper,
bool is_test,
bool fix_seed,
int seed,
MetaTensor* out,
MetaTensor* noise);

5 changes: 1 addition & 4 deletions paddle/phi/kernels/cpu/rrelu_kernel.cc
@@ -26,8 +26,6 @@ void RReluKernel(const Context& dev_ctx,
const float lower,
const float upper,
bool is_test,
bool fix_seed,
int seed,
DenseTensor* out,
DenseTensor* noise) {
const T* x_ptr = x.data<T>();
@@ -52,8 +50,7 @@ void RReluKernel(const Context& dev_ctx,
return;
}

int seed_data = fix_seed ? seed : 0;
auto engine = paddle::framework::GetCPURandomEngine(seed_data);
auto engine = paddle::framework::GetCPURandomEngine(0);

std::uniform_real_distribution<float> dist(lower, upper);

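With the fix_seed/seed attributes gone, the engine always comes from GetCPURandomEngine(0), i.e. the framework's default generator, so reproducibility is presumably governed by the global seed rather than a per-op attribute. A hedged usage sketch under that assumption:

import paddle
import paddle.nn.functional as F

x = paddle.uniform([4, 8], min=-1., max=1.)

paddle.seed(2022)  # global seed, replacing the removed per-op 'seed' attribute
out1 = F.rrelu(x, lower=1. / 8., upper=1. / 3., training=True)

paddle.seed(2022)  # re-seeding should reproduce the same random slopes
out2 = F.rrelu(x, lower=1. / 8., upper=1. / 3., training=True)
# out1 and out2 are expected to match if the default engine honors paddle.seed.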
106 changes: 0 additions & 106 deletions paddle/phi/kernels/gpu/rrelu_funcs.h

This file was deleted.

123 changes: 87 additions & 36 deletions paddle/phi/kernels/gpu/rrelu_kernel.cu
@@ -1,53 +1,104 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/rrelu_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpu/rrelu_funcs.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/rrelu_kernel.h"

namespace phi {

template <typename T>
struct RReluTrainCudaFunctor {
public:
RReluTrainCudaFunctor(const T* in, T* out, T* noise)
: in_(in), out_(out), noise_(noise) {
zero_ = static_cast<T>(0);
}

__device__ void operator()(int64_t idx) {
T x = in_[idx];
if (x < zero_) {
out_[idx] = noise_[idx] * x;
} else {
out_[idx] = x;
noise_[idx] = 1.0;
}
}

private:
const T* in_;
T* out_;
T* noise_;
T zero_;
};

template <typename T>
struct RReluTestCudaFunctor {
public:
RReluTestCudaFunctor(const T* in, T* out, T* noise, T mid_val)
: in_(in), out_(out), noise_(noise), mid_val_(mid_val) {
zero_ = static_cast<T>(0);
}

__device__ void operator()(int64_t idx) {
T x = in_[idx];
if (x < zero_) {
out_[idx] = mid_val_ * x;
noise_[idx] = mid_val_;
} else {
out_[idx] = x;
noise_[idx] = 1.0;
}
}

private:
const T* in_;
T* out_;
T* noise_;
T zero_;
T mid_val_;
};

template <typename T, typename Context>
void RReluKernel(const Context& dev_ctx,
void RReluKernel(const Context& ctx,
const DenseTensor& x,
const float lower,
const float upper,
bool is_test,
bool fix_seed,
int seed,
DenseTensor* out,
DenseTensor* noise) {
const T* x_ptr = x.data<T>();
T* o_ptr = dev_ctx.template Alloc<T>(out);
T* n_ptr = dev_ctx.template Alloc<T>(noise);

int numel = x.numel();
auto dim = x.dims();
RReluElementWiseDirectCUDAFunctor<T> rrelu_element_wise;

int seed_data = fix_seed ? seed : 0;
rrelu_element_wise(dev_ctx.stream(),
x_ptr,
o_ptr,
n_ptr,
lower,
upper,
is_test,
seed_data,
numel);
const T* x_data = x.data<T>();
T* out_data = ctx.template Alloc<T>(out);
T* noise_data = ctx.template Alloc<T>(noise);
auto size = x.numel();
if (size <= 0) return;

phi::funcs::ForRange<Context> for_range(ctx, size);
if (is_test) {
T mid_val = static_cast<T>((lower + upper) / 2.0);
RReluTestCudaFunctor<T> functor(x_data, out_data, noise_data, mid_val);
for_range(functor);
} else {
using MT = typename kps::details::MPTypeTrait<T>::Type;
funcs::uniform_distribution<MT> dist;
funcs::uniform_real_transform<MT> trans(lower, upper);
funcs::distribution_and_transform<T>(ctx, noise, dist, trans);
RReluTrainCudaFunctor<T> functor(x_data, out_data, noise_data);
for_range(functor);
}
}

} // namespace phi
2 changes: 0 additions & 2 deletions paddle/phi/kernels/rrelu_kernel.h
@@ -24,8 +24,6 @@ void RReluKernel(const Context& dev_ctx,
const float lower,
const float upper,
bool is_test,
bool fix_seed,
int seed,
DenseTensor* out,
DenseTensor* noise);
} // namespace phi
6 changes: 2 additions & 4 deletions paddle/phi/ops/compat/rrelu_sig.cc
@@ -17,10 +17,8 @@
namespace phi {

KernelSignature RReluOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("rrelu",
{"X"},
{"lower", "upper", "is_test", "fix_seed", "seed"},
{"Out", "Noise"});
return KernelSignature(
"rrelu", {"X"}, {"lower", "upper", "is_test"}, {"Out", "Noise"});
}

KernelSignature RReluGradGradOpArgumentMapping(
1 change: 1 addition & 0 deletions python/paddle/fluid/tests/unittests/test_rrelu_op.py
@@ -187,6 +187,7 @@ def test_dygraph(self):
paddle.enable_static()

def test_error_functional(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(
23 changes: 3 additions & 20 deletions python/paddle/nn/functional/activation.py
@@ -22,7 +22,7 @@

import warnings
from ...fluid.layer_helper import LayerHelper
from ...fluid.framework import convert_np_dtype_to_dtype_, default_main_program
from ...fluid.framework import convert_np_dtype_to_dtype_
from ...fluid.framework import _in_legacy_dygraph, in_dygraph_mode, _non_static_mode
from ...fluid.data_feeder import check_variable_and_dtype, check_dtype
import paddle
@@ -625,33 +625,16 @@ def rrelu(x, lower=1. / 8., upper=1. / 3., training=True, name=None):
upper))

is_test = not training
seed = None

if _in_legacy_dygraph():
if default_main_program().random_seed != 0:
seed = default_main_program().random_seed
out, noise = _C_ops.rrelu(x, 'lower', lower, 'upper', upper, 'is_test',
is_test, 'fix_seed', seed is not None, 'seed',
seed if seed is not None else 0)
is_test)
return out

def get_attrs(prog, lower, upper, is_test, seed):
if (seed is None or seed == 0) and prog.random_seed != 0:
seed = prog.random_seed
attrs = {
'lower': lower,
'upper': upper,
'is_test': is_test,
'fix_seed': seed is not None,
'seed': seed if seed is not None else 0,
}
return attrs

helper = LayerHelper('rrelu', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
noise = helper.create_variable_for_type_inference(dtype=x.dtype)
attrs = get_attrs(helper.main_program, lower, upper, is_test, seed)

attrs = {'lower': lower, 'upper': upper, 'is_test': is_test}
helper.append_op(
type='rrelu',
inputs={"X": x},
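After this simplification the functional API only forwards lower, upper and is_test (with is_test = not training). A short dygraph usage sketch of the two modes (values shown are approximate):

import paddle
import paddle.nn.functional as F

x = paddle.to_tensor([[-2.0, -1.0], [0.5, 3.0]])

# Training: each negative entry is scaled by its own slope drawn from [1/8, 1/3].
y_train = F.rrelu(x, lower=1. / 8., upper=1. / 3., training=True)

# Inference: negative entries are scaled by the midpoint (1/8 + 1/3) / 2 ≈ 0.229,
# so -2.0 maps to roughly -0.458 and non-negative entries pass through unchanged.
y_eval = F.rrelu(x, lower=1. / 8., upper=1. / 3., training=False)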
