Skip to content

Commit

Permalink
add paddle.incubate.sparse.is_same_shape
Browse files Browse the repository at this point in the history
  • Loading branch information
gsq7474741 committed Aug 16, 2022
1 parent e6bf0bf commit c4c6db0
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 34 deletions.
Expand Up @@ -16,10 +16,10 @@
import unittest

import paddle
from paddle.incubate.sparse.multiary import is_same_shape
from paddle.incubate.sparse.binary import is_same_shape


class TestSparseMultiaryAPI(unittest.TestCase):
class TestSparseIsSameShapeAPI(unittest.TestCase):
"""
test paddle.incubate.sparse.is_same_shape
"""
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/incubate/sparse/__init__.py
Expand Up @@ -42,9 +42,9 @@
from .binary import divide
from .binary import multiply
from .binary import subtract
from .binary import is_same_shape

from .multiary import addmm
from .multiary import is_same_shape

from . import nn

Expand Down
32 changes: 32 additions & 0 deletions python/paddle/incubate/sparse/binary.py
Expand Up @@ -399,3 +399,35 @@ def divide(x, y, name=None):
if y.dtype != x.dtype:
y = _C_ops.final_state_sparse_cast(y, None, x.dtype)
return _C_ops.final_state_sparse_divide(x, y)


@dygraph_only
def is_same_shape(x, y):
"""
Check whether x.shape equal to y.shape.
Args:
x (Tensor): The input tensor. It can be DenseTensor/SparseCooTensor/SparseCsrTensor.
y (Tensor): The input tensor. It can be DenseTensor/SparseCooTensor/SparseCsrTensor.
Returns:
bool: True for same shape and False for different shape.
Examples:
.. code-block:: python
import paddle
x = paddle.rand([2, 3, 8])
y = paddle.rand([2, 3, 8])
y = y.to_sparse_csr()
z = paddle.rand([2, 5])
paddle.incubate.sparse.is_same_shape(x, y)
# True
paddle.incubate.sparse.is_same_shape(x, z)
# False
"""
return x.is_same_shape(y)
31 changes: 0 additions & 31 deletions python/paddle/incubate/sparse/multiary.py
Expand Up @@ -79,34 +79,3 @@ def addmm(input, x, y, beta=1.0, alpha=1.0, name=None):
"""
return _C_ops.final_state_sparse_addmm(input, x, y, alpha, beta)


@dygraph_only
def is_same_shape(x, y):
"""
Check whether x.shape equal to y.shape.
Args:
x (Tensor): The input tensor. It can be DenseTensor/SparseCooTensor/SparseCsrTensor.
y (Tensor): The input tensor. It can be DenseTensor/SparseCooTensor/SparseCsrTensor.
Returns:
bool: True for same shape and False for different shape.
Examples:
.. code-block:: python
import paddle
x = paddle.rand([2, 3, 8])
y = paddle.rand([2, 3, 8])
z = paddle.rand([2, 5])
paddle.incubate.sparse.is_same_shape(x, y)
# True
paddle.incubate.sparse.is_same_shape(x, z)
# False
"""
return x.is_same_shape(y)

0 comments on commit c4c6db0

Please sign in to comment.