Skip to content

Commit

Permalink
[Yaml]add gaussian_random yaml and test case (#41312)
Browse files Browse the repository at this point in the history
* add guassian random yaml

* add gaussian_random yaml and test case

* fix error modify of full yaml

* import in_dygraph_mode

* import _in_legacy_dygraph

* add place arg in api

* import __current_expected_place

* fix test_egr_python_api failed case

* add test case

* add cast for NormalInitializer

* fix test error

* fix test error

* rm unsed check code

* fix test error in test_initializer_nn

* modify by review
  • Loading branch information
MingMingShangTian committed Apr 12, 2022
1 parent 18f569c commit 4819ab4
Show file tree
Hide file tree
Showing 8 changed files with 170 additions and 18 deletions.
75 changes: 64 additions & 11 deletions python/paddle/fluid/initializer.py
Expand Up @@ -331,29 +331,70 @@ def __call__(self, var, block=None):
["uint16", "float16", "float32", "float64"],
"guassian_random")

# to be compatible of fp16 initalizers
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
out_dtype = VarDesc.VarType.FP32
out_var = block.create_var(
name=unique_name.generate(".".join(
['normal_init', var.name, 'tmp'])),
shape=var.shape,
dtype=out_dtype,
type=VarDesc.VarType.LOD_TENSOR,
persistable=False)
else:
out_dtype = var.dtype
out_var = var

if self._seed == 0:
self._seed = block.program.random_seed

if framework._non_static_mode():
if in_dygraph_mode():
place = _current_expected_place()
out_var = _C_ops.final_state_gaussian_random(
var.shape, self._mean, self._std_dev, self._seed, out_dtype,
place)
out_var._share_underline_tensor_to(var)

if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
var_tmp = _C_ops.final_state_cast(out_var, var.dtype)
var_tmp._share_underline_tensor_to(var)
else:
out_var._share_underline_tensor_to(var)
return None

if _in_legacy_dygraph():
out_var = _C_ops.gaussian_random(
'shape', var.shape, 'dtype', var.dtype, 'mean', self._mean,
'shape', var.shape, 'dtype', out_dtype, 'mean', self._mean,
'std', self._std_dev, 'seed', self._seed, 'use_mkldnn', False)
out_var._share_underline_tensor_to(var)

if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype,
'out_dtype', var.dtype)
var_tmp._share_underline_tensor_to(var)
else:
out_var._share_underline_tensor_to(var)
return None
else:
op = block.append_op(
type="gaussian_random",
outputs={"Out": var},
outputs={"Out": out_var},
attrs={
"shape": var.shape,
"dtype": var.dtype,
"dtype": out_dtype,
"mean": self._mean,
"std": self._std_dev,
"seed": self._seed,
"use_mkldnn": False
},
stop_gradient=True)

if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
block.append_op(
type="cast",
inputs={"X": out_var},
outputs={"Out": var},
attrs={"in_dtype": out_var.dtype,
"out_dtype": var.dtype})
var.op = op
return op

Expand Down Expand Up @@ -567,9 +608,15 @@ def __call__(self, var, block=None):
self._seed, 'dtype', out_dtype)
else:
std = math.sqrt(2.0 / float(fan_in + fan_out))
out_var = _C_ops.gaussian_random(
'shape', out_var.shape, 'dtype', out_dtype, 'mean', 0.0,
'std', std, 'seed', self._seed)

if in_dygraph_mode():
place = _current_expected_place()
out_var = _C_ops.final_state_gaussian_random(
out_var.shape, 0.0, std, self._seed, out_dtype, place)
else:
out_var = _C_ops.gaussian_random(
'shape', out_var.shape, 'dtype', out_dtype, 'mean', 0.0,
'std', std, 'seed', self._seed)

if var.dtype == VarDesc.VarType.FP16 or (
var.dtype == VarDesc.VarType.BF16 and not self._uniform):
Expand Down Expand Up @@ -720,9 +767,15 @@ def __call__(self, var, block=None):
int(out_dtype))
else:
std = math.sqrt(2.0 / float(fan_in))
out_var = _C_ops.gaussian_random(
'shape', out_var.shape, 'dtype',
int(out_dtype), 'mean', 0.0, 'std', std, 'seed', self._seed)
if in_dygraph_mode():
place = _current_expected_place()
out_var = _C_ops.final_state_gaussian_random(
out_var.shape, 0.0, std, self._seed, out_dtype, place)
else:
out_var = _C_ops.gaussian_random(
'shape', out_var.shape, 'dtype',
int(out_dtype), 'mean', 0.0, 'std', std, 'seed',
self._seed)

if var.dtype == VarDesc.VarType.FP16 or (
var.dtype == VarDesc.VarType.BF16 and not self._uniform):
Expand Down
11 changes: 10 additions & 1 deletion python/paddle/fluid/layers/nn.py
Expand Up @@ -28,6 +28,7 @@
from paddle.fluid.framework import _in_legacy_dygraph
from ..initializer import Normal, Constant, NumpyArrayInitializer
from ..framework import Variable, OpProtoHolder, _non_static_mode, dygraph_only, _dygraph_tracer, default_main_program, _varbase_creator, static_only, _global_flags, _in_legacy_dygraph, in_dygraph_mode
from ..framework import _current_expected_place
from .. import dygraph_utils
from ..param_attr import ParamAttr
from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_
Expand Down Expand Up @@ -10970,7 +10971,15 @@ def gaussian_random(shape,
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)

if _non_static_mode():
if in_dygraph_mode():
shape = utils.convert_shape_to_list(shape)
place = _current_expected_place()
return _C_ops.final_state_gaussian_random(shape,
float(mean),
float(std), seed, dtype,
place)

if _in_legacy_dygraph():
shape = utils.convert_shape_to_list(shape)
return _C_ops.gaussian_random('shape', shape, 'mean',
float(mean), 'std',
Expand Down
3 changes: 0 additions & 3 deletions python/paddle/fluid/tests/unittests/test_egr_python_api.py
Expand Up @@ -251,9 +251,6 @@ def constructor(self, place):
self.assertTrue(egr_tensor12.place._equals(paddle.fluid.CPUPlace()))
self.assertTrue(np.array_equal(egr_tensor12.numpy(), x))

egr_tensor13 = paddle.randn([2, 2])
self.assertTrue("eager_tmp" in egr_tensor13.name)

with self.assertRaisesRegexp(
ValueError, "The shape of Parameter should not be None"):
eager_param = EagerParamBase(shape=None, dtype="float32")
Expand Down
11 changes: 11 additions & 0 deletions python/paddle/fluid/tests/unittests/test_gaussian_random_op.py
Expand Up @@ -23,12 +23,14 @@
from paddle.fluid.op import Operator
from paddle.fluid.executor import Executor
from paddle.fluid.tests.unittests.op_test import OpTest, convert_uint16_to_float
from paddle.fluid.framework import _test_eager_guard
import paddle


class TestGaussianRandomOp(OpTest):
def setUp(self):
self.op_type = "gaussian_random"
self.python_api = paddle.normal
self.set_attrs()
self.inputs = {}
self.use_mkldnn = False
Expand All @@ -50,6 +52,10 @@ def set_attrs(self):
def test_check_output(self):
self.check_output_customized(self.verify_output)

def test_eager(self):
with _test_eager_guard():
self.test_check_output()

def verify_output(self, outs):
self.assertEqual(outs[0].shape, (123, 92))
hist, _ = np.histogram(outs[0], range=(-3, 5))
Expand All @@ -70,6 +76,7 @@ def verify_output(self, outs):
class TestGaussianRandomBF16Op(OpTest):
def setUp(self):
self.op_type = "gaussian_random"
self.python_api = paddle.normal
self.set_attrs()
self.inputs = {}
self.use_mkldnn = False
Expand All @@ -93,6 +100,10 @@ def test_check_output(self):
self.check_output_with_place_customized(
self.verify_output, place=core.CUDAPlace(0))

def test_eager(self):
with _test_eager_guard():
self.test_check_output()

def verify_output(self, outs):
outs = convert_uint16_to_float(outs)
self.assertEqual(outs[0].shape, (123, 92))
Expand Down
64 changes: 63 additions & 1 deletion python/paddle/fluid/tests/unittests/test_initializer.py
Expand Up @@ -244,7 +244,7 @@ def test_normal_initializer(self, dtype="float32"):
lod_level=0,
name="param",
initializer=initializer.NormalInitializer(2.3, 1.9, 123))
num_ops = 1
num_ops = 2 if (dtype == "float16" or dtype == "uint16") else 1
self.assertEqual(len(block.ops), num_ops)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'gaussian_random')
Expand Down Expand Up @@ -685,6 +685,68 @@ def test_uniform_initializer(self, dtype="float32"):
self.func_uniform_initializer()


class TestXavierInitializerDygraph(unittest.TestCase):
def func_xvarier_initializer(self, dtype="float32"):
"""
In dygraph mode, we can use initializer directly to initialize a tensor.
"""
paddle.disable_static()

tensor = paddle.zeros([1024, 1024, 16])
tensor.stop_gradient = False

xavier_ = paddle.fluid.initializer.XavierInitializer(
uniform=False, fan_in=3, fan_out=5)
xavier_(tensor)

hist, _ = output_hist(tensor.numpy())

hist2, _ = output_hist(
np.random.normal(0, np.sqrt(2.0 / (3 + 5)), [1024, 1024, 16]))

self.assertTrue(
np.allclose(
hist, hist2, rtol=0, atol=0.01),
"hist: " + str(hist) + " hist2: " + str(hist2))
paddle.enable_static()

def test_xavier_initializer(self, dtype="float32"):
with framework._test_eager_guard():
self.func_xvarier_initializer()
self.func_xvarier_initializer()


class TestMSRAInitializerDygraph(unittest.TestCase):
def func_msra_initializer(self, dtype="float32"):
"""
In dygraph mode, we can use initializer directly to initialize a tensor.
"""
paddle.disable_static()

tensor = paddle.zeros([1024, 1024, 16])
tensor.stop_gradient = False

msra_ = paddle.fluid.initializer.MSRAInitializer(
uniform=False, fan_in=4)
msra_(tensor)

hist, _ = output_hist(tensor.numpy())

hist2, _ = output_hist(
np.random.normal(0, np.sqrt(2.0 / (4)), [1024, 1024, 16]))

self.assertTrue(
np.allclose(
hist, hist2, rtol=0, atol=0.01),
"hist: " + str(hist) + " hist2: " + str(hist2))
paddle.enable_static()

def test_msra_initializer(self, dtype="float32"):
with framework._test_eager_guard():
self.func_msra_initializer()
self.func_msra_initializer()


class TesetconsistencyOfDynamicAndStaticGraph(unittest.TestCase):
def func_order(self):
paddle.set_device('cpu')
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/unittests/test_initializer_nn.py
Expand Up @@ -400,7 +400,7 @@ def test_normal_initializer(self, dtype="float32"):
lod_level=0,
name="param",
initializer=initializer.Normal(2.3, 1.9))
num_ops = 1
num_ops = 2 if dtype in ["float16", "uint16"] else 1
self.assertEqual(len(block.ops), num_ops)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'gaussian_random')
Expand Down
10 changes: 9 additions & 1 deletion python/paddle/tensor/random.py
Expand Up @@ -239,7 +239,15 @@ def gaussian(shape, mean=0.0, std=1.0, dtype=None, name=None):
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)

if paddle.in_dynamic_mode():
if in_dygraph_mode():
shape = utils.convert_shape_to_list(shape)
place = _current_expected_place()
return _C_ops.final_state_gaussian_random(shape,
float(mean),
float(std), seed, dtype,
place)

if _in_legacy_dygraph():
shape = utils.convert_shape_to_list(shape)
return _C_ops.gaussian_random('shape', shape, 'mean',
float(mean), 'std',
Expand Down
12 changes: 12 additions & 0 deletions python/paddle/utils/code_gen/api.yaml
Expand Up @@ -782,6 +782,18 @@
kernel :
func : gather_tree

- api : gaussian_random
args : (IntArray shape, float mean, float std, int seed, DataType dtype, Place place={})
output: Tensor
infer_meta :
func : GaussianRandomInferMeta
param : [shape, mean, std, seed, dtype]
kernel :
func : gaussian_random
param : [shape, mean, std, seed, dtype]
data_type : dtype
backend : place

- api : gelu
args : (Tensor x, bool approximate)
output : Tensor(out)
Expand Down

0 comments on commit 4819ab4

Please sign in to comment.