Skip to content

Commit

Permalink
add more unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
Asthestarsfalll committed Apr 7, 2022
1 parent aa24a86 commit 2bb7885
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 4 deletions.
67 changes: 64 additions & 3 deletions python/paddle/fluid/tests/unittests/test_frac_api.py
Expand Up @@ -20,19 +20,26 @@
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
from paddle.fluid.framework import _test_eager_guard


def ref_frac(x):
return x - np.trunc(x)


class TestFracAPI(unittest.TestCase):
"""Test Frac API"""

def set_dtype(self):
self.dtype = 'float64'

def setUp(self):
self.x_np = np.random.uniform(-3, 3, [2, 3]).astype('float64')
self.set_dtype()
self.x_np = np.random.uniform(-3, 3, [2, 3]).astype(self.dtype)
self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
else paddle.CPUPlace()

def test_static_api(self):
def test_api_static(self):
paddle.enable_static()
with program_guard(Program()):
input = fluid.data('X', self.x_np.shape, self.x_np.dtype)
Expand All @@ -45,13 +52,67 @@ def test_static_api(self):
out_ref = ref_frac(self.x_np)
self.assertTrue(np.allclose(out_ref, res))

def test__dygraph_api(self):
def test_api_dygraph(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
out = paddle.frac(x)
out_ref = ref_frac(self.x_np)
self.assertTrue(np.allclose(out_ref, out.numpy()))

def test_api_eager(self):
paddle.disable_static(self.place)
with _test_eager_guard():
x_tensor = paddle.to_tensor(self.x_np)
out = paddle.frac(x_tensor)
out_ref = ref_frac(self.x_np)
self.assertTrue(np.allclose(out_ref, out.numpy()))
paddle.enable_static()

def test_api_eager_dygraph(self):
with _test_eager_guard():
self.test_api_dygraph()


class TestFracInt32(TestFracAPI):
"""Test Frac API with data type int32"""

def set_dtype(self):
self.dtype = 'int32'


class TestFracInt64(TestFracAPI):
"""Test Frac API with data type int64"""

def set_dtype(self):
self.dtype = 'int64'


class TestFracFloat32(TestFracAPI):
"""Test Frac API with data type float32"""

def set_dtype(self):
self.dtype = 'float32'


class TestFracError(unittest.TestCase):
"""Test Frac Error"""

def setUp(self):
self.x_np = np.random.uniform(-3, 3, [2, 3]).astype('int16')
self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
else paddle.CPUPlace()

def test_static_error(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('X', [5, 5], 'bool')
self.assertRaises(TypeError, paddle.frac, x)

def test_dygraph_error(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np, dtype='int16')
self.assertRaises(TypeError, paddle.frac, x)


if __name__ == '__main__':
unittest.main()
5 changes: 4 additions & 1 deletion python/paddle/tensor/math.py
Expand Up @@ -3960,9 +3960,12 @@ def frac(x, name=None):
op_type = 'elementwise_sub'
axis = -1
act = None
if x.dtype not in [paddle.int32, paddle.int64, paddle.float32, paddle.float64]:
raise TypeError(
"The data type of input must be one of ['int32', 'int64', 'float32', 'float64'], but got {}".format(x.dtype))
if in_dygraph_mode():
y = _C_ops.final_state_trunc(x)
return _C_ops.final_state_substract(x, y)
return _C_ops.final_state_subtract(x, y)
else:
if _in_legacy_dygraph():
y = _C_ops.trunc(x)
Expand Down

0 comments on commit 2bb7885

Please sign in to comment.