Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 committed Aug 23, 2022
1 parent db8a233 commit e932c68
Showing 1 changed file with 45 additions and 1 deletion.
46 changes: 45 additions & 1 deletion python/paddle/fluid/tests/unittests/test_inplace.py
Expand Up @@ -483,7 +483,16 @@ def inplace_api_processing(self, var):
return var.subtract_(input_var_2)


class TestDygraphInplaceRemainder(TestDygraphInplaceAdd):
class TestDygraphInplaceRemainder(unittest.TestCase):

def setUp(self):
self.init_data()
self.set_np_compare_func()

def init_data(self):
self.input_var_numpy = np.random.rand(2, 3, 4)
self.dtype = "float32"
self.input_var_numpy_2 = np.random.rand(2, 3, 4).astype(self.dtype)

def non_inplace_api_processing(self, var):
input_var_2 = paddle.to_tensor(self.input_var_numpy_2)
Expand All @@ -493,6 +502,41 @@ def inplace_api_processing(self, var):
input_var_2 = paddle.to_tensor(self.input_var_numpy_2)
return var.remainder_(input_var_2)

def set_np_compare_func(self):
self.np_compare = np.array_equal

def func_test_inplace_api(self):
var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
inplace_var = self.inplace_api_processing(var)
self.assertTrue(id(var) == id(inplace_var))

inplace_var[0] = 2.
np.testing.assert_array_equal(var.numpy(), inplace_var.numpy())

def test_inplace_api(self):
with _test_eager_guard():
self.func_test_inplace_api()
self.func_test_inplace_api()

def func_test_forward_version(self):
with paddle.fluid.dygraph.guard():
var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
self.assertEqual(var.inplace_version, 0)

inplace_var = self.inplace_api_processing(var)
self.assertEqual(var.inplace_version, 1)

inplace_var[0] = 2.
self.assertEqual(var.inplace_version, 2)

inplace_var = self.inplace_api_processing(inplace_var)
self.assertEqual(var.inplace_version, 3)

def test_forward_version(self):
with _test_eager_guard():
self.func_test_forward_version()
self.func_test_forward_version()


class TestLossIsInplaceVar(unittest.TestCase):

Expand Down

0 comments on commit e932c68

Please sign in to comment.