Cherry pick final state ops (PaddlePaddle#41755)
* [Yaml]add exp yaml (PaddlePaddle#41217)

* add exp yaml

* add exp api in test case

* add determinant yaml

* fix exp op unittest

* change test class name

* modify api name

* compacted with raw api

* fix det api

* add python_api

* add test eager for determinant op

* [Yaml] Add assign yaml (PaddlePaddle#41428)

* add assign yaml

* add assign api

* add assign backward api

* add assign

* add assign yaml

* add assign

* assign yaml

* add assign raw kernel and use assign_raw in yaml

* merge develop branch

* add missing python_api

* exchange assign and assign_raw kernel name (PaddlePaddle#41625)

* exchange assign and assign_raw kernel name

* fix register error

* [Yaml]add gaussian_random yaml and test case (PaddlePaddle#41312)

* add gaussian random yaml

* add gaussian_random yaml and test case

* fix error modify of full yaml

* import in_dygraph_mode

* import _in_legacy_dygraph

* add place arg in api

* import __current_expected_place

* fix test_egr_python_api failed case

* add test case

* add cast for NormalInitializer

* fix test error

* fix test error

* rm unused check code

* fix test error in test_initializer_nn

* modify by review

* [Phi]fix split error when sections has 0 size and add test case (PaddlePaddle#41708)

* fix split error when sections has 0 size and add test case

* fix test case
MingMingShangTian committed Apr 14, 2022
1 parent 4d94eac commit 921a6fb
Showing 20 changed files with 320 additions and 47 deletions.
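All of the ops cherry-picked here are exposed through the public Python API once their final-state (eager) yaml entries are in place. A minimal smoke test, assuming a build that contains these kernels (shapes and values below are illustrative only):

    import numpy as np
    import paddle

    paddle.disable_static()
    x = paddle.to_tensor(np.random.rand(3, 3).astype('float32'))
    y = paddle.exp(x)                                    # exp yaml (#41217)
    d = paddle.linalg.det(x)                             # determinant yaml (#41217)
    z = paddle.assign(y)                                 # assign yaml (#41428)
    g = paddle.normal(mean=0.0, std=1.0, shape=[2, 3])   # backed by gaussian_random (#41312)
    print(y.shape, d.shape, z.shape, g.shape)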
2 changes: 1 addition & 1 deletion paddle/fluid/operators/strided_memcpy.h
@@ -134,7 +134,7 @@ inline void StridedMemcpyWithAxis0(
for (size_t i = 0; i < outputs->size(); ++i) {
auto out_stride = stride_numel(shape_refer[i]->dims());
auto out = outputs->at(i);
-    if (out != nullptr) {
+    if (out != nullptr && out->initialized()) {
StridedNumelCopyWithAxis<T>(dev_ctx, axis, out->data<T>(), out_stride,
input.data<T>() + input_offset, in_stride,
out_stride[axis]);
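The out->initialized() guard added above is what the split fix (PaddlePaddle#41708) relies on when one requested section is empty. A small repro in the spirit of the new test case (shapes are illustrative, not copied from the test):

    import paddle

    paddle.disable_static()
    x = paddle.rand([4, 6, 5])
    # a zero-size section along axis=1 leaves one output uninitialized;
    # the guard above skips the copy for it instead of writing into it
    a, b, c = paddle.split(x, num_or_sections=[2, 0, 4], axis=1)
    print(a.shape, b.shape, c.shape)   # [4, 2, 5] [4, 0, 5] [4, 4, 5]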
30 changes: 24 additions & 6 deletions paddle/phi/kernels/assign_kernel.cc
@@ -24,14 +24,21 @@ namespace phi {

template <typename Context>
void AssignKernel(const Context& dev_ctx,
-                  paddle::optional<const DenseTensor&> x,
+                  const DenseTensor& x,
DenseTensor* out) {
-  if (x.get_ptr()) {
-    if (!x.is_initialized()) {
+  Copy<Context>(dev_ctx, x, x.place(), false, out);
+}
+
+template <typename Context>
+void AssignRawKernel(const Context& dev_ctx,
+                     paddle::optional<const DenseTensor&> x,
+                     DenseTensor* out) {
+  if (x) {
+    if (!x->IsInitialized()) {
return;
}
auto& x_tensor = *x.get_ptr();
-    Copy<Context>(dev_ctx, x_tensor, x_tensor.place(), false, out);
+    AssignKernel<Context>(dev_ctx, x_tensor, out);
}
}

@@ -105,7 +112,13 @@ void AssignValueKernel(const Context& dev_ctx,
} // namespace phi

PD_REGISTER_GENERAL_KERNEL(
-    assign, CPU, ALL_LAYOUT, phi::AssignKernel<phi::CPUContext>, ALL_DTYPE) {
+    assign, CPU, ALL_LAYOUT, phi::AssignKernel<phi::CPUContext>, ALL_DTYPE) {}
+
+PD_REGISTER_GENERAL_KERNEL(assign_raw,
+                           CPU,
+                           ALL_LAYOUT,
+                           phi::AssignRawKernel<phi::CPUContext>,
+                           ALL_DTYPE) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_GENERAL_KERNEL(assign_array,
@@ -124,7 +137,12 @@ PD_REGISTER_KERNEL(assign_value,

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_GENERAL_KERNEL(
-    assign, GPU, ALL_LAYOUT, phi::AssignKernel<phi::GPUContext>, ALL_DTYPE) {
+    assign, GPU, ALL_LAYOUT, phi::AssignKernel<phi::GPUContext>, ALL_DTYPE) {}
+PD_REGISTER_GENERAL_KERNEL(assign_raw,
+                           GPU,
+                           ALL_LAYOUT,
+                           phi::AssignRawKernel<phi::GPUContext>,
+                           ALL_DTYPE) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_GENERAL_KERNEL(assign_array,
11 changes: 8 additions & 3 deletions paddle/phi/kernels/assign_kernel.h
@@ -21,13 +21,18 @@

namespace phi {

+template <typename Context>
+void AssignKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  DenseTensor* out);
+
// In order to be compatible with the `AsDispensable` input in the original
// assign op maker, the input parameter here needs to be dispensable, but
// this looks weird
template <typename Context>
-void AssignKernel(const Context& dev_ctx,
-                  paddle::optional<const DenseTensor&> x,
-                  DenseTensor* out);
+void AssignRawKernel(const Context& dev_ctx,
+                     paddle::optional<const DenseTensor&> x,
+                     DenseTensor* out);

template <typename Context>
void AssignArrayKernel(const Context& dev_ctx,
4 changes: 2 additions & 2 deletions paddle/phi/ops/compat/assign_sig.cc
@@ -23,10 +23,10 @@ KernelSignature AssignOpArgumentMapping(const ArgumentMappingContext& ctx) {
} else if (ctx.IsSelectedRowsInput("X")) {
return KernelSignature("assign_sr", {"X"}, {}, {"Out"});
} else {
return KernelSignature("assign", {"X"}, {}, {"Out"});
return KernelSignature("assign_raw", {"X"}, {}, {"Out"});
}
} else {
return KernelSignature("assign", {"X"}, {}, {"Out"});
return KernelSignature("assign_raw", {"X"}, {}, {"Out"});
}
}

5 changes: 4 additions & 1 deletion python/paddle/fluid/dygraph/varbase_patch_methods.py
@@ -23,7 +23,7 @@
from ..framework import convert_np_dtype_to_dtype_, _in_legacy_dygraph
from .. import core
from .. import unique_name
-from ..framework import Variable, Parameter, ParamBase, _getitem_impl_, _setitem_impl_, EagerParamBase
+from ..framework import Variable, Parameter, ParamBase, _getitem_impl_, _setitem_impl_, EagerParamBase, in_dygraph_mode
from .base import switch_to_static_graph
from .math_op_patch import monkey_patch_math_varbase
from .parallel import scale_loss
@@ -798,6 +798,9 @@ def _set_grad_ivar(self, value):

@framework.dygraph_only
def clone(self):
+        if in_dygraph_mode():
+            return _C_ops.final_state_assign(self)
+
if _in_legacy_dygraph():
output = core.VarBase()
else:
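With clone() dispatching to final_state_assign in eager mode, gradients still flow through the cloned tensor. A short check along the lines of the test_clone case updated in test_assign_op.py below:

    import paddle

    paddle.disable_static()
    x = paddle.ones([2])
    x.stop_gradient = False
    y = paddle.clone(x)          # eager mode: routed through final_state_assign
    (y ** 3).sum().backward()
    print(x.grad)                # 3 * x**2 = [3., 3.], so the clone is differentiable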
79 changes: 66 additions & 13 deletions python/paddle/fluid/initializer.py
@@ -331,29 +331,70 @@ def __call__(self, var, block=None):
["uint16", "float16", "float32", "float64"],
"guassian_random")

+        # to be compatible of fp16 initalizers
+        if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
+            out_dtype = VarDesc.VarType.FP32
+            out_var = block.create_var(
+                name=unique_name.generate(".".join(
+                    ['normal_init', var.name, 'tmp'])),
+                shape=var.shape,
+                dtype=out_dtype,
+                type=VarDesc.VarType.LOD_TENSOR,
+                persistable=False)
+        else:
+            out_dtype = var.dtype
+            out_var = var

if self._seed == 0:
self._seed = block.program.random_seed

-        if framework._non_static_mode():
+        if in_dygraph_mode():
+            place = _current_expected_place()
+            out_var = _C_ops.final_state_gaussian_random(
+                var.shape, self._mean, self._std_dev, self._seed, out_dtype,
+                place)
+            out_var._share_underline_tensor_to(var)
+
+            if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
+                var_tmp = _C_ops.final_state_cast(out_var, var.dtype)
+                var_tmp._share_underline_tensor_to(var)
+            else:
+                out_var._share_underline_tensor_to(var)
+            return None
+
+        if _in_legacy_dygraph():
out_var = _C_ops.gaussian_random(
-                'shape', var.shape, 'dtype', var.dtype, 'mean', self._mean,
+                'shape', var.shape, 'dtype', out_dtype, 'mean', self._mean,
'std', self._std_dev, 'seed', self._seed, 'use_mkldnn', False)
-            out_var._share_underline_tensor_to(var)
+
+            if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
+                var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype,
+                                      'out_dtype', var.dtype)
+                var_tmp._share_underline_tensor_to(var)
+            else:
+                out_var._share_underline_tensor_to(var)
return None
else:
op = block.append_op(
type="gaussian_random",
outputs={"Out": var},
outputs={"Out": out_var},
attrs={
"shape": var.shape,
"dtype": var.dtype,
"dtype": out_dtype,
"mean": self._mean,
"std": self._std_dev,
"seed": self._seed,
"use_mkldnn": False
},
stop_gradient=True)

+            if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
+                block.append_op(
+                    type="cast",
+                    inputs={"X": out_var},
+                    outputs={"Out": var},
+                    attrs={"in_dtype": out_var.dtype,
+                           "out_dtype": var.dtype})
var.op = op
return op

@@ -566,10 +607,16 @@ def __call__(self, var, block=None):
-limit, 'max', limit, 'seed',
self._seed, 'dtype', out_dtype)
else:
-            std = np.sqrt(2.0 / float(fan_in + fan_out))
-            out_var = _C_ops.gaussian_random(
-                'shape', out_var.shape, 'dtype', out_dtype, 'mean', 0.0,
-                'std', std, 'seed', self._seed)
+            std = math.sqrt(2.0 / float(fan_in + fan_out))
+
+            if in_dygraph_mode():
+                place = _current_expected_place()
+                out_var = _C_ops.final_state_gaussian_random(
+                    out_var.shape, 0.0, std, self._seed, out_dtype, place)
+            else:
+                out_var = _C_ops.gaussian_random(
+                    'shape', out_var.shape, 'dtype', out_dtype, 'mean', 0.0,
+                    'std', std, 'seed', self._seed)

if var.dtype == VarDesc.VarType.FP16 or (
var.dtype == VarDesc.VarType.BF16 and not self._uniform):
@@ -719,10 +766,16 @@ def __call__(self, var, block=None):
self._seed, 'dtype',
int(out_dtype))
else:
-            std = np.sqrt(2.0 / float(fan_in))
-            out_var = _C_ops.gaussian_random(
-                'shape', out_var.shape, 'dtype',
-                int(out_dtype), 'mean', 0.0, 'std', std, 'seed', self._seed)
+            std = math.sqrt(2.0 / float(fan_in))
+            if in_dygraph_mode():
+                place = _current_expected_place()
+                out_var = _C_ops.final_state_gaussian_random(
+                    out_var.shape, 0.0, std, self._seed, out_dtype, place)
+            else:
+                out_var = _C_ops.gaussian_random(
+                    'shape', out_var.shape, 'dtype',
+                    int(out_dtype), 'mean', 0.0, 'std', std, 'seed',
+                    self._seed)

if var.dtype == VarDesc.VarType.FP16 or (
var.dtype == VarDesc.VarType.BF16 and not self._uniform):
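For float16/bfloat16 parameters the initializer now samples into a float32 temporary and casts the result back into the parameter. A hedged sketch of what that enables (the create_parameter call and values are illustrative, not taken from this diff):

    import paddle
    from paddle.nn.initializer import Normal

    paddle.disable_static()
    w = paddle.create_parameter(
        shape=[4, 4],
        dtype='float16',
        default_initializer=Normal(mean=0.0, std=0.02))
    print(w.dtype)   # paddle.float16, filled via an fp32 gaussian sample plus a cast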
11 changes: 10 additions & 1 deletion python/paddle/fluid/layers/nn.py
@@ -28,6 +28,7 @@
from paddle.fluid.framework import _in_legacy_dygraph
from ..initializer import Normal, Constant, NumpyArrayInitializer
from ..framework import Variable, OpProtoHolder, _non_static_mode, dygraph_only, _dygraph_tracer, default_main_program, _varbase_creator, static_only, _global_flags, _in_legacy_dygraph, in_dygraph_mode
+from ..framework import _current_expected_place
from .. import dygraph_utils
from ..param_attr import ParamAttr
from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_
@@ -10964,7 +10965,15 @@ def gaussian_random(shape,
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)

-    if _non_static_mode():
+    if in_dygraph_mode():
+        shape = utils.convert_shape_to_list(shape)
+        place = _current_expected_place()
+        return _C_ops.final_state_gaussian_random(shape,
+                                                   float(mean),
+                                                   float(std), seed, dtype,
+                                                   place)
+
+    if _in_legacy_dygraph():
shape = utils.convert_shape_to_list(shape)
return _C_ops.gaussian_random('shape', shape, 'mean',
float(mean), 'std',
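In eager mode the fluid entry point above now forwards straight to final_state_gaussian_random with the current expected place. The same kernel is reachable from the public API, for example (values illustrative):

    import paddle

    paddle.disable_static()
    x = paddle.fluid.layers.gaussian_random(
        shape=[2, 3], mean=0.0, std=1.0, dtype='float32')
    y = paddle.normal(mean=0.0, std=1.0, shape=[2, 3])   # public wrapper over the same op
    print(x.shape, y.shape)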
15 changes: 9 additions & 6 deletions python/paddle/fluid/layers/tensor.py
@@ -622,12 +622,15 @@ def assign(input, output=None):
# after this api.
if isinstance(input, (Variable, core.VarBase)):
if _non_static_mode():
-            if output is None:
-                if _in_legacy_dygraph():
-                    output = core.VarBase()
-                else:
-                    output = core.eager.Tensor()
-            _C_ops.assign(input, output)
+            if in_dygraph_mode() and output is None:
+                output = _C_ops.final_state_assign(input)
+            else:
+                if output is None:
+                    if _in_legacy_dygraph():
+                        output = core.VarBase()
+                    else:
+                        output = core.eager.Tensor()
+                _C_ops.assign(input, output)
else:
check_dtype(input.dtype, 'input', [
'float16', 'uint16', 'float32', 'float64', 'int32', 'int64',
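At the Python level the assign behaviour is unchanged: with output=None, eager mode returns the result of final_state_assign directly, while an explicit output tensor is still written in place through _C_ops.assign. A short usage sketch:

    import numpy as np
    import paddle

    paddle.disable_static()
    x = paddle.to_tensor(np.random.random((100, 10)).astype('float64'))
    y = paddle.assign(x)           # output is None: a new tensor is returned
    z = paddle.zeros_like(x)
    paddle.assign(x, output=z)     # an existing tensor is filled in place
    print(bool((y == x).all()), bool((z == x).all()))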
3 changes: 2 additions & 1 deletion python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -51,7 +51,8 @@ def setUp(self):
self.op_type = "exp"
self.init_dtype()
self.init_kernel_type()
-        self.check_eager = False
+        self.check_eager = True
+        self.python_api = paddle.exp

np.random.seed(2049)
x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
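With python_api set, the op test also exercises the eager final-state kernel. The same check can be done directly; since the derivative of exp is exp itself, the gradient from a sum reduction should match the forward result:

    import numpy as np
    import paddle

    paddle.disable_static()
    x = paddle.to_tensor(np.random.uniform(0.1, 1, [11, 17]).astype('float32'))
    x.stop_gradient = False
    y = paddle.exp(x)
    y.sum().backward()
    print(np.allclose(x.grad.numpy(), y.numpy()))   # True: d/dx exp(x) == exp(x)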
12 changes: 8 additions & 4 deletions python/paddle/fluid/tests/unittests/test_assign_op.py
@@ -27,30 +27,32 @@

class TestAssignOp(op_test.OpTest):
def setUp(self):
+        self.python_api = paddle.assign
self.op_type = "assign"
x = np.random.random(size=(100, 10)).astype('float64')
self.inputs = {'X': x}
self.outputs = {'Out': x}

def test_forward(self):
-        self.check_output()
+        self.check_output(check_eager=True)

def test_backward(self):
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_eager=True)


class TestAssignFP16Op(op_test.OpTest):
def setUp(self):
+        self.python_api = paddle.assign
self.op_type = "assign"
x = np.random.random(size=(100, 10)).astype('float16')
self.inputs = {'X': x}
self.outputs = {'Out': x}

def test_forward(self):
-        self.check_output()
+        self.check_output(check_eager=True)

def test_backward(self):
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_eager=True)


class TestAssignOpWithLoDTensorArray(unittest.TestCase):
@@ -171,6 +173,8 @@ def test_assign_BasicTypes(self):

def test_clone(self):
paddle.disable_static()
+        self.python_api = paddle.clone
+
x = paddle.ones([2])
x.stop_gradient = False
clone_x = paddle.clone(x)
10 changes: 8 additions & 2 deletions python/paddle/fluid/tests/unittests/test_determinant_op.py
@@ -22,21 +22,23 @@
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.tensor as tensor
+from paddle.fluid.framework import _test_eager_guard

paddle.enable_static()


class TestDeterminantOp(OpTest):
def setUp(self):
+        self.python_api = paddle.linalg.det
self.init_data()
self.op_type = "determinant"
self.outputs = {'Out': self.target}

def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)

def test_check_grad(self):
-        self.check_grad(['Input'], ['Out'])
+        self.check_grad(['Input'], ['Out'], check_eager=True)

def init_data(self):
np.random.seed(0)
@@ -89,6 +91,10 @@ def test_api_dygraph(self):
self.assertEqual(np.allclose(out.numpy(), out_ref, rtol=1e-03), True)
paddle.enable_static()

+    def test_eager(self):
+        with _test_eager_guard():
+            self.test_api_dygraph()


class TestSlogDeterminantOp(OpTest):
def setUp(self):
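The determinant path can be checked the same way outside the unit test, assuming the tolerance used above (rtol=1e-03):

    import numpy as np
    import paddle
    from paddle.fluid.framework import _test_eager_guard

    paddle.disable_static()
    x_np = np.random.rand(3, 3, 3).astype('float32')
    with _test_eager_guard():
        out = paddle.linalg.det(paddle.to_tensor(x_np))
    print(np.allclose(out.numpy(), np.linalg.det(x_np), rtol=1e-03))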
3 changes: 0 additions & 3 deletions python/paddle/fluid/tests/unittests/test_egr_python_api.py
@@ -251,9 +251,6 @@ def constructor(self, place):
self.assertTrue(egr_tensor12.place._equals(paddle.fluid.CPUPlace()))
self.assertTrue(np.array_equal(egr_tensor12.numpy(), x))

-        egr_tensor13 = paddle.randn([2, 2])
-        self.assertTrue("eager_tmp" in egr_tensor13.name)

with self.assertRaisesRegexp(
ValueError, "The shape of Parameter should not be None"):
eager_param = EagerParamBase(shape=None, dtype="float32")
