Skip to content

Commit

Permalink
add remainder_ op
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 committed Aug 19, 2022
1 parent e654f1e commit b45b89e
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 1 deletion.
3 changes: 2 additions & 1 deletion paddle/fluid/operators/elementwise/elementwise_mod_op.cc
Expand Up @@ -56,7 +56,8 @@ class ElementwiseModOpMaker : public ElementwiseOpMaker {
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(elementwise_mod,
ops::ElementwiseOp,
ops::ElementwiseModOpMaker);
ops::ElementwiseModOpMaker,
ops::ElementwiseOpInplaceInferer);

REGISTER_OP_VERSION(elementwise_mod)
.AddCheckpoint(
Expand Down
38 changes: 38 additions & 0 deletions python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py
Expand Up @@ -133,5 +133,43 @@ def test_dygraph(self):
np.testing.assert_allclose(z_expected, z.numpy(), rtol=1e-05)


class TestRemainderInplaceOp(unittest.TestCase):

def test_name(self):
with fluid.program_guard(fluid.Program()):
x = fluid.data(name="x", shape=[2, 3], dtype="int64")
y = fluid.data(name='y', shape=[2, 3], dtype='int64')

y_1 = paddle.remainder_(x, y, name='div_res')
self.assertEqual(('div_res' in y_1.name), True)

def test_dygraph(self):
with fluid.dygraph.guard():
np_x = np.array([2, 3, 8, 7]).astype('int64')
np_y = np.array([1, 5, 3, 3]).astype('int64')
x = paddle.to_tensor(np_x)
y = paddle.to_tensor(np_y)
z = paddle.remainder_(x, y)
np_z = z.numpy()
z_expected = np.array([0, 3, 2, 1])
self.assertEqual((np_z == z_expected).all(), True)

np_x = np.array([-3.3, 11.5, -2, 3.5])
np_y = np.array([-1.2, 2., 3.3, -2.3])
x = paddle.to_tensor(np_x)
y = paddle.to_tensor(np_y)
z = x % y
z_expected = np.array([-0.9, 1.5, 1.3, -1.1])
np.testing.assert_allclose(z_expected, z.numpy(), rtol=1e-05)

np_x = np.array([-3, 11, -2, 3])
np_y = np.array([-1, 2, 3, -2])
x = paddle.to_tensor(np_x, dtype="int64")
y = paddle.to_tensor(np_y, dtype="int64")
z = x % y
z_expected = np.array([0, 1, 1, -1])
np.testing.assert_allclose(z_expected, z.numpy(), rtol=1e-05)


if __name__ == '__main__':
unittest.main()
11 changes: 11 additions & 0 deletions python/paddle/fluid/tests/unittests/test_inplace.py
Expand Up @@ -483,6 +483,17 @@ def inplace_api_processing(self, var):
return var.subtract_(input_var_2)


class TestDygraphInplaceRemainder(TestDygraphInplaceAdd):

def non_inplace_api_processing(self, var):
input_var_2 = paddle.to_tensor(self.input_var_numpy_2)
return var.remainder(input_var_2)

def inplace_api_processing(self, var):
input_var_2 = paddle.to_tensor(self.input_var_numpy_2)
return var.remainder_(input_var_2)


class TestLossIsInplaceVar(unittest.TestCase):

def func_test_loss_is_inplace_var(self):
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/tensor/__init__.py
Expand Up @@ -182,6 +182,7 @@
from .math import divide # noqa: F401
from .math import floor_divide # noqa: F401
from .math import remainder # noqa: F401
from .math import remainder_ # noqa: F401
from .math import mod # noqa: F401
from .math import floor_mod # noqa: F401
from .math import multiply # noqa: F401
Expand Down Expand Up @@ -364,6 +365,7 @@
'divide',
'floor_divide',
'remainder',
'remainder_',
'mod',
'floor_mod',
'multiply',
Expand Down
17 changes: 17 additions & 0 deletions python/paddle/tensor/math.py
Expand Up @@ -779,6 +779,23 @@ def remainder(x, y, name=None):
return _elementwise_op(LayerHelper(op_type, **locals()))


def remainder_(x, y, name=None):
r"""
Inplace version of ``remainder`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_tensor_remainder`.
"""
op_type = 'elementwise_mod_'
axis = -1

out_shape = broadcast_shape(x.shape, y.shape)
if out_shape != x.shape:
raise ValueError(
"The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(
out_shape, x.shape))

return _elementwise_op_in_dygraph(x, y, axis=axis, op_name=op_type)


mod = remainder # noqa: F841
floor_mod = remainder # noqa: F841

Expand Down

0 comments on commit b45b89e

Please sign in to comment.